Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
a2151b96
"docs/vscode:/vscode.git/clone" did not exist on "1ccbfbb663399d3aa363af513e5a8352f3afdb35"
Unverified
Commit
a2151b96
authored
Nov 03, 2022
by
Philip Meier
Committed by
GitHub
Nov 03, 2022
Browse files
replace assert torch.allclose with torch.testing.assert_allclose (#6895)
parent
79ca506c
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
5 additions
and
5 deletions
+5
-5
test/test_architecture_ops.py
test/test_architecture_ops.py
+2
-2
test/test_ops.py
test/test_ops.py
+3
-3
No files found.
test/test_architecture_ops.py
View file @
a2151b96
...
@@ -20,7 +20,7 @@ class MaxvitTester(unittest.TestCase):
...
@@ -20,7 +20,7 @@ class MaxvitTester(unittest.TestCase):
x_hat
=
partition
(
x
,
partition_size
)
x_hat
=
partition
(
x
,
partition_size
)
x_hat
=
departition
(
x_hat
,
partition_size
,
n_partitions
,
n_partitions
)
x_hat
=
departition
(
x_hat
,
partition_size
,
n_partitions
,
n_partitions
)
assert
torch
.
all
close
(
x
,
x_hat
)
torch
.
testing
.
assert_
close
(
x
,
x_hat
)
def
test_maxvit_grid_partition
(
self
):
def
test_maxvit_grid_partition
(
self
):
input_shape
=
(
1
,
3
,
224
,
224
)
input_shape
=
(
1
,
3
,
224
,
224
)
...
@@ -39,7 +39,7 @@ class MaxvitTester(unittest.TestCase):
...
@@ -39,7 +39,7 @@ class MaxvitTester(unittest.TestCase):
x_hat
=
post_swap
(
x_hat
)
x_hat
=
post_swap
(
x_hat
)
x_hat
=
departition
(
x_hat
,
n_partitions
,
partition_size
,
partition_size
)
x_hat
=
departition
(
x_hat
,
n_partitions
,
partition_size
,
partition_size
)
assert
torch
.
all
close
(
x
,
x_hat
)
torch
.
testing
.
assert_
close
(
x
,
x_hat
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
test/test_ops.py
View file @
a2151b96
...
@@ -630,7 +630,7 @@ class TestNMS:
...
@@ -630,7 +630,7 @@ class TestNMS:
boxes
,
scores
=
self
.
_create_tensors_with_iou
(
1000
,
iou
)
boxes
,
scores
=
self
.
_create_tensors_with_iou
(
1000
,
iou
)
keep_ref
=
self
.
_reference_nms
(
boxes
,
scores
,
iou
)
keep_ref
=
self
.
_reference_nms
(
boxes
,
scores
,
iou
)
keep
=
ops
.
nms
(
boxes
,
scores
,
iou
)
keep
=
ops
.
nms
(
boxes
,
scores
,
iou
)
assert
torch
.
all
close
(
keep
,
keep_ref
)
,
err_msg
.
format
(
iou
)
torch
.
testing
.
assert_
close
(
keep
,
keep_ref
,
msg
=
err_msg
.
format
(
iou
)
)
def
test_nms_input_errors
(
self
):
def
test_nms_input_errors
(
self
):
with
pytest
.
raises
(
RuntimeError
):
with
pytest
.
raises
(
RuntimeError
):
...
@@ -661,7 +661,7 @@ class TestNMS:
...
@@ -661,7 +661,7 @@ class TestNMS:
keep
=
ops
.
nms
(
boxes
,
scores
,
iou
)
keep
=
ops
.
nms
(
boxes
,
scores
,
iou
)
qkeep
=
ops
.
nms
(
qboxes
,
qscores
,
iou
)
qkeep
=
ops
.
nms
(
qboxes
,
qscores
,
iou
)
assert
torch
.
all
close
(
qkeep
,
keep
)
,
err_msg
.
format
(
iou
)
torch
.
testing
.
assert_
close
(
qkeep
,
keep
,
msg
=
err_msg
.
format
(
iou
)
)
@
needs_cuda
@
needs_cuda
@
pytest
.
mark
.
parametrize
(
"iou"
,
(
0.2
,
0.5
,
0.8
))
@
pytest
.
mark
.
parametrize
(
"iou"
,
(
0.2
,
0.5
,
0.8
))
...
@@ -1237,7 +1237,7 @@ class TestIouBase:
...
@@ -1237,7 +1237,7 @@ class TestIouBase:
boxes2
=
gen_box
(
7
)
boxes2
=
gen_box
(
7
)
a
=
TestIouBase
.
_cartesian_product
(
boxes1
,
boxes2
,
target_fn
)
a
=
TestIouBase
.
_cartesian_product
(
boxes1
,
boxes2
,
target_fn
)
b
=
target_fn
(
boxes1
,
boxes2
)
b
=
target_fn
(
boxes1
,
boxes2
)
assert
torch
.
all
close
(
a
,
b
)
torch
.
testing
.
assert_
close
(
a
,
b
)
class
TestBoxIou
(
TestIouBase
):
class
TestBoxIou
(
TestIouBase
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment