Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
7992eb5d
Unverified
Commit
7992eb5d
authored
Mar 10, 2021
by
Nicolas Hug
Committed by
GitHub
Mar 10, 2021
Browse files
simplify _get_script_fn (#3541)
Co-authored-by:
Francisco Massa
<
fvsmassa@gmail.com
>
parent
3428a7de
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
20 deletions
+8
-20
test/test_ops.py
test/test_ops.py
+8
-20
No files found.
test/test_ops.py
View file @
7992eb5d
...
@@ -135,11 +135,8 @@ class RoIPoolTester(RoIOpTester, unittest.TestCase):
...
@@ -135,11 +135,8 @@ class RoIPoolTester(RoIOpTester, unittest.TestCase):
return
ops
.
RoIPool
((
pool_h
,
pool_w
),
spatial_scale
)(
x
,
rois
)
return
ops
.
RoIPool
((
pool_h
,
pool_w
),
spatial_scale
)(
x
,
rois
)
def
get_script_fn
(
self
,
rois
,
pool_size
):
def
get_script_fn
(
self
,
rois
,
pool_size
):
@
torch
.
jit
.
script
scriped
=
torch
.
jit
.
script
(
ops
.
roi_pool
)
def
script_fn
(
input
,
rois
,
pool_size
):
return
lambda
x
:
scriped
(
x
,
rois
,
pool_size
)
# type: (Tensor, Tensor, int) -> Tensor
return
ops
.
roi_pool
(
input
,
rois
,
pool_size
,
1.0
)[
0
]
return
lambda
x
:
script_fn
(
x
,
rois
,
pool_size
)
def
expected_fn
(
self
,
x
,
rois
,
pool_h
,
pool_w
,
spatial_scale
=
1
,
sampling_ratio
=-
1
,
def
expected_fn
(
self
,
x
,
rois
,
pool_h
,
pool_w
,
spatial_scale
=
1
,
sampling_ratio
=-
1
,
device
=
None
,
dtype
=
torch
.
float64
):
device
=
None
,
dtype
=
torch
.
float64
):
...
@@ -177,11 +174,8 @@ class PSRoIPoolTester(RoIOpTester, unittest.TestCase):
...
@@ -177,11 +174,8 @@ class PSRoIPoolTester(RoIOpTester, unittest.TestCase):
return
ops
.
PSRoIPool
((
pool_h
,
pool_w
),
1
)(
x
,
rois
)
return
ops
.
PSRoIPool
((
pool_h
,
pool_w
),
1
)(
x
,
rois
)
def
get_script_fn
(
self
,
rois
,
pool_size
):
def
get_script_fn
(
self
,
rois
,
pool_size
):
@
torch
.
jit
.
script
scriped
=
torch
.
jit
.
script
(
ops
.
ps_roi_pool
)
def
script_fn
(
input
,
rois
,
pool_size
):
return
lambda
x
:
scriped
(
x
,
rois
,
pool_size
)
# type: (Tensor, Tensor, int) -> Tensor
return
ops
.
ps_roi_pool
(
input
,
rois
,
pool_size
,
1.0
)[
0
]
return
lambda
x
:
script_fn
(
x
,
rois
,
pool_size
)
def
expected_fn
(
self
,
x
,
rois
,
pool_h
,
pool_w
,
spatial_scale
=
1
,
sampling_ratio
=-
1
,
def
expected_fn
(
self
,
x
,
rois
,
pool_h
,
pool_w
,
spatial_scale
=
1
,
sampling_ratio
=-
1
,
device
=
None
,
dtype
=
torch
.
float64
):
device
=
None
,
dtype
=
torch
.
float64
):
...
@@ -257,11 +251,8 @@ class RoIAlignTester(RoIOpTester, unittest.TestCase):
...
@@ -257,11 +251,8 @@ class RoIAlignTester(RoIOpTester, unittest.TestCase):
sampling_ratio
=
sampling_ratio
,
aligned
=
aligned
)(
x
,
rois
)
sampling_ratio
=
sampling_ratio
,
aligned
=
aligned
)(
x
,
rois
)
def
get_script_fn
(
self
,
rois
,
pool_size
):
def
get_script_fn
(
self
,
rois
,
pool_size
):
@
torch
.
jit
.
script
scriped
=
torch
.
jit
.
script
(
ops
.
roi_align
)
def
script_fn
(
input
,
rois
,
pool_size
):
return
lambda
x
:
scriped
(
x
,
rois
,
pool_size
)
# type: (Tensor, Tensor, int) -> Tensor
return
ops
.
roi_align
(
input
,
rois
,
pool_size
,
1.0
)[
0
]
return
lambda
x
:
script_fn
(
x
,
rois
,
pool_size
)
def
expected_fn
(
self
,
in_data
,
rois
,
pool_h
,
pool_w
,
spatial_scale
=
1
,
sampling_ratio
=-
1
,
aligned
=
False
,
def
expected_fn
(
self
,
in_data
,
rois
,
pool_h
,
pool_w
,
spatial_scale
=
1
,
sampling_ratio
=-
1
,
aligned
=
False
,
device
=
None
,
dtype
=
torch
.
float64
):
device
=
None
,
dtype
=
torch
.
float64
):
...
@@ -315,11 +306,8 @@ class PSRoIAlignTester(RoIOpTester, unittest.TestCase):
...
@@ -315,11 +306,8 @@ class PSRoIAlignTester(RoIOpTester, unittest.TestCase):
sampling_ratio
=
sampling_ratio
)(
x
,
rois
)
sampling_ratio
=
sampling_ratio
)(
x
,
rois
)
def
get_script_fn
(
self
,
rois
,
pool_size
):
def
get_script_fn
(
self
,
rois
,
pool_size
):
@
torch
.
jit
.
script
scriped
=
torch
.
jit
.
script
(
ops
.
ps_roi_align
)
def
script_fn
(
input
,
rois
,
pool_size
):
return
lambda
x
:
scriped
(
x
,
rois
,
pool_size
)
# type: (Tensor, Tensor, int) -> Tensor
return
ops
.
ps_roi_align
(
input
,
rois
,
pool_size
,
1.0
)[
0
]
return
lambda
x
:
script_fn
(
x
,
rois
,
pool_size
)
def
expected_fn
(
self
,
in_data
,
rois
,
pool_h
,
pool_w
,
device
,
spatial_scale
=
1
,
def
expected_fn
(
self
,
in_data
,
rois
,
pool_h
,
pool_w
,
device
,
spatial_scale
=
1
,
sampling_ratio
=-
1
,
dtype
=
torch
.
float64
):
sampling_ratio
=-
1
,
dtype
=
torch
.
float64
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment