OpenDAS / FAST-RNNT

Commit 10381dab, authored Jun 30, 2021 by Daniel Povey

    More tests...

parent 5a8c1e3a
Showing 3 changed files with 81 additions and 3 deletions (+81 / -3):

- torch_integrated_conv/integrated_conv_cpu.cpp (+1 / -1)
- torch_integrated_conv/integrated_conv_cuda_kernel.cu (+3 / -2)
- torch_integrated_conv/integrated_conv_test.py (+77 / -0)
torch_integrated_conv/integrated_conv_cpu.cpp

@@ -52,7 +52,7 @@ torch::Tensor integrated_conv_cpu(torch::Tensor input,
       for (int kh = 0; kh < kH; kh++) {
         int src_h = h + kh - kH / 2;
         for (int kw = 0; kw < kW; kw++) {
-          int src_w = h + kh - kH / 2;
+          int src_w = w + kw - kW / 2;
           scalar_t src = 0.0;
           if (static_cast<unsigned int>(src_h) < static_cast<unsigned int>(H) &&
               static_cast<unsigned int>(src_w) < static_cast<unsigned int>(W))
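The one-line CPU fix above corrects a copy-paste error: src_w was being computed from the row quantities h, kh, and kH instead of w, kw, and kW. Both index computations implement "same" padding, centering the kernel on (h, w), and the static_cast<unsigned int> comparison folds the two checks src >= 0 and src < limit into a single unsigned compare, since a negative index wraps around to a huge unsigned value. A minimal Python sketch of the same gather, with hypothetical names (gather_patch is not part of this repository), assuming zeros outside the image:

    # Hypothetical sketch of the index arithmetic in the hunk above.
    # image is a 2-D list of floats; kH and kW are odd kernel sizes.
    def gather_patch(image, h, w, kH, kW):
        H, W = len(image), len(image[0])
        patch = []
        for kh in range(kH):
            src_h = h + kh - kH // 2      # "same" padding: center the kernel row on h
            row = []
            for kw in range(kW):
                src_w = w + kw - kW // 2  # the fixed line: w/kw/kW, not h/kh/kH
                # Python equivalent of the unsigned-cast bounds trick:
                # out-of-range coordinates read as zero.
                if 0 <= src_h < H and 0 <= src_w < W:
                    row.append(image[src_h][src_w])
                else:
                    row.append(0.0)
            patch.append(row)
        return patch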
torch_integrated_conv/integrated_conv_cuda_kernel.cu

@@ -294,8 +294,9 @@ torch::Tensor integrated_conv_cuda(torch::Tensor input,
               << "; patchH,patchW=" << patchH << ","
               << patchW << ", num_blocks_patch="
               << num_blocks_patch << ", num_blocks_batch="
-              << num_blocks_batch << std::endl;
+              << num_blocks_batch << std::endl
+              << ", threads_per_opixel=" << threads_per_opixel << ", threads_per_block=" << threads_per_block;
   dim3 gridDim(C, num_blocks_patch, num_blocks_batch);
   // blockDim is scalar, just threads_per_block.
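This CUDA hunk only extends a debug printout: threads_per_opixel and threads_per_block are now reported alongside the patch and batch block counts, while the launch configuration itself, dim3 gridDim(C, num_blocks_patch, num_blocks_batch), is unchanged. The hunk does not show how those block counts are derived; a plausible sketch is the usual ceiling division of the work over a block size (an assumption, not taken from this file):

    # Hypothetical helper: ceiling division, the standard way counts like
    # num_blocks_patch are derived from a total work size and a block size.
    def num_blocks(total_items, items_per_block):
        return (total_items + items_per_block - 1) // items_per_block

For example, num_blocks(100, 32) == 4: four blocks of 32 threads cover 100 items, with 28 idle lanes in the last block.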
torch_integrated_conv/integrated_conv_test.py

+import random
 import torch
 from torch_integrated_conv import integrated_conv
...
@@ -8,6 +9,9 @@ def test_integrated_conv_zeros():
     H = 3
     W = 4
     for device in [torch.device('cpu'), torch.device('cuda:0')]:
+        if device == torch.device('cuda:0') and not torch.cuda.is_available():
+            print("Warning: CUDA not available, not testing this part.")
+            continue
         for dtype in [torch.float32, torch.float64]:
             print("device=", device, ", dtype=", dtype)
             input = torch.zeros(N, 2 * C, H, W, device=device, dtype=dtype)
...
@@ -19,3 +23,76 @@ def test_integrated_conv_zeros():
             output_ref = torch.zeros(N, C, H, W, device=device, dtype=dtype)
             output = integrated_conv(input, pos_add, pos_mul)
             assert torch.allclose(output, output_ref)
+
+
+def test_integrated_conv_compare():
+    N = 1
+    C = 2
+    H = 3
+    W = 4
+    if not torch.cuda.is_available():
+        print("Warning: CUDA not available, not testing this part.")
+        return
+    for dtype in [torch.float32, torch.float64]:
+        print("dtype=", dtype)
+        input = torch.ones(N, 2 * C, H, W, dtype=dtype)
+        device = torch.device('cuda:0')
+        input_cuda = input.to(device)
+        kH = 5
+        kW = 5
+        pos_add = torch.ones(C, kH, kW, dtype=dtype)
+        pos_mul = torch.ones(C, kH, kW, dtype=dtype)
+        pos_add_cuda = pos_add.to(device)
+        pos_mul_cuda = pos_mul.to(device)
+        output = integrated_conv(input, pos_add, pos_mul)
+        output_cuda = integrated_conv(input_cuda, pos_add_cuda, pos_mul_cuda)
+        print("output = ", output)
+        print("output_cuda = ", output_cuda)
+        assert torch.allclose(output, output_cuda.to(torch.device('cpu')))
+
+
+def test_integrated_conv_rand_compare():
+    for _ in range(30):
+        N = random.randint(1, 256)
+        C = random.randint(1, 64)
+        H = random.randint(1, 128)
+        W = random.randint(1, 128)
+        # Halve the largest dimension until the total size is small enough.
+        while N * C * H * W > 65535:
+            if N >= C and N >= H and N >= W:
+                N = N // 2
+            elif C >= H and C >= W:
+                C = C // 2
+            elif H >= W:
+                H = H // 2
+            else:
+                W = W // 2
+        if not torch.cuda.is_available():
+            print("Warning: CUDA not available, not testing this part.")
+            return
+        for dtype in [torch.float32, torch.float64]:
+            print("dtype=", dtype)
+            input = torch.ones(N, 2 * C, H, W, dtype=dtype)
+            device = torch.device('cuda:0')
+            input_cuda = input.to(device)
+            kH = random.randint(1, 10)
+            kW = random.randint(1, 10)
+            # Kernel sizes must be odd; bump even draws up by one.
+            if kH % 2 == 0:
+                kH += 1
+            if kW % 2 == 0:
+                kW += 1
+            pos_add = torch.ones(C, kH, kW, dtype=dtype)
+            pos_mul = torch.ones(C, kH, kW, dtype=dtype)
+            pos_add_cuda = pos_add.to(device)
+            pos_mul_cuda = pos_mul.to(device)
+            output = integrated_conv(input, pos_add, pos_mul)
+            output_cuda = integrated_conv(input_cuda, pos_add_cuda, pos_mul_cuda)
+            print("output = ", output)
+            print("output_cuda = ", output_cuda)
+            assert torch.allclose(output, output_cuda.to(torch.device('cpu')))
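The new random-shape test draws N, C, H, and W at random and then repeatedly halves the largest of the four until N * C * H * W fits under 65535, presumably to keep the problem within CUDA grid-size limits (65535 is the maximum for the y and z grid dimensions). Factored out as a standalone helper, the clamping logic looks like this (a hypothetical refactoring, not part of the commit):

    # Hypothetical helper mirroring the while-loop in
    # test_integrated_conv_rand_compare: halve the largest dimension
    # until the product of all four fits under the limit.
    def clamp_total_size(N, C, H, W, limit=65535):
        while N * C * H * W > limit:
            if N >= C and N >= H and N >= W:
                N //= 2
            elif C >= H and C >= W:
                C //= 2
            elif H >= W:
                H //= 2
            else:
                W //= 2
        return N, C, H, W

Since all three functions follow pytest's test_ naming convention, the file can be run with, e.g., pytest torch_integrated_conv/integrated_conv_test.py.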