Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
05a3941f
Unverified
Commit
05a3941f
authored
May 21, 2021
by
Nicolas Hug
Committed by
GitHub
May 21, 2021
Browse files
Use torch.testing.assert_close in test_datasets_samplers.py (#3874)
Co-authored-by:
Philip Meier
<
github.pmeier@posteo.de
>
parent
c4685e81
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
11 additions
and
10 deletions
+11
-10
test/test_datasets_samplers.py
test/test_datasets_samplers.py
+11
-10
No files found.
test/test_datasets_samplers.py
View file @
05a3941f
...
@@ -14,6 +14,7 @@ from torchvision.datasets.video_utils import VideoClips, unfold
...
@@ -14,6 +14,7 @@ from torchvision.datasets.video_utils import VideoClips, unfold
from
torchvision
import
get_video_backend
from
torchvision
import
get_video_backend
from
common_utils
import
get_tmp_dir
from
common_utils
import
get_tmp_dir
from
_assert_utils
import
assert_equal
@
contextlib
.
contextmanager
@
contextlib
.
contextmanager
...
@@ -47,8 +48,8 @@ class Tester(unittest.TestCase):
...
@@ -47,8 +48,8 @@ class Tester(unittest.TestCase):
indices
=
torch
.
tensor
(
list
(
iter
(
sampler
)))
indices
=
torch
.
tensor
(
list
(
iter
(
sampler
)))
videos
=
torch
.
div
(
indices
,
5
,
rounding_mode
=
'floor'
)
videos
=
torch
.
div
(
indices
,
5
,
rounding_mode
=
'floor'
)
v_idxs
,
count
=
torch
.
unique
(
videos
,
return_counts
=
True
)
v_idxs
,
count
=
torch
.
unique
(
videos
,
return_counts
=
True
)
self
.
assert
True
(
v_idxs
.
equal
(
torch
.
tensor
([
0
,
1
,
2
]))
)
assert
_equal
(
v_idxs
,
torch
.
tensor
([
0
,
1
,
2
]))
self
.
assert
True
(
count
.
equal
(
torch
.
tensor
([
3
,
3
,
3
]))
)
assert
_equal
(
count
,
torch
.
tensor
([
3
,
3
,
3
]))
def
test_random_clip_sampler_unequal
(
self
):
def
test_random_clip_sampler_unequal
(
self
):
with
get_list_of_videos
(
num_videos
=
3
,
sizes
=
[
10
,
25
,
25
])
as
video_list
:
with
get_list_of_videos
(
num_videos
=
3
,
sizes
=
[
10
,
25
,
25
])
as
video_list
:
...
@@ -64,8 +65,8 @@ class Tester(unittest.TestCase):
...
@@ -64,8 +65,8 @@ class Tester(unittest.TestCase):
indices
=
torch
.
tensor
(
indices
)
-
2
indices
=
torch
.
tensor
(
indices
)
-
2
videos
=
torch
.
div
(
indices
,
5
,
rounding_mode
=
'floor'
)
videos
=
torch
.
div
(
indices
,
5
,
rounding_mode
=
'floor'
)
v_idxs
,
count
=
torch
.
unique
(
videos
,
return_counts
=
True
)
v_idxs
,
count
=
torch
.
unique
(
videos
,
return_counts
=
True
)
self
.
assert
True
(
v_idxs
.
equal
(
torch
.
tensor
([
0
,
1
]))
)
assert
_equal
(
v_idxs
,
torch
.
tensor
([
0
,
1
]))
self
.
assert
True
(
count
.
equal
(
torch
.
tensor
([
3
,
3
]))
)
assert
_equal
(
count
,
torch
.
tensor
([
3
,
3
]))
def
test_uniform_clip_sampler
(
self
):
def
test_uniform_clip_sampler
(
self
):
with
get_list_of_videos
(
num_videos
=
3
,
sizes
=
[
25
,
25
,
25
])
as
video_list
:
with
get_list_of_videos
(
num_videos
=
3
,
sizes
=
[
25
,
25
,
25
])
as
video_list
:
...
@@ -75,9 +76,9 @@ class Tester(unittest.TestCase):
...
@@ -75,9 +76,9 @@ class Tester(unittest.TestCase):
indices
=
torch
.
tensor
(
list
(
iter
(
sampler
)))
indices
=
torch
.
tensor
(
list
(
iter
(
sampler
)))
videos
=
torch
.
div
(
indices
,
5
,
rounding_mode
=
'floor'
)
videos
=
torch
.
div
(
indices
,
5
,
rounding_mode
=
'floor'
)
v_idxs
,
count
=
torch
.
unique
(
videos
,
return_counts
=
True
)
v_idxs
,
count
=
torch
.
unique
(
videos
,
return_counts
=
True
)
self
.
assert
True
(
v_idxs
.
equal
(
torch
.
tensor
([
0
,
1
,
2
]))
)
assert
_equal
(
v_idxs
,
torch
.
tensor
([
0
,
1
,
2
]))
self
.
assert
True
(
count
.
equal
(
torch
.
tensor
([
3
,
3
,
3
]))
)
assert
_equal
(
count
,
torch
.
tensor
([
3
,
3
,
3
]))
self
.
assert
True
(
indices
.
equal
(
torch
.
tensor
([
0
,
2
,
4
,
5
,
7
,
9
,
10
,
12
,
14
]))
)
assert
_equal
(
indices
,
torch
.
tensor
([
0
,
2
,
4
,
5
,
7
,
9
,
10
,
12
,
14
]))
def
test_uniform_clip_sampler_insufficient_clips
(
self
):
def
test_uniform_clip_sampler_insufficient_clips
(
self
):
with
get_list_of_videos
(
num_videos
=
3
,
sizes
=
[
10
,
25
,
25
])
as
video_list
:
with
get_list_of_videos
(
num_videos
=
3
,
sizes
=
[
10
,
25
,
25
])
as
video_list
:
...
@@ -85,7 +86,7 @@ class Tester(unittest.TestCase):
...
@@ -85,7 +86,7 @@ class Tester(unittest.TestCase):
sampler
=
UniformClipSampler
(
video_clips
,
3
)
sampler
=
UniformClipSampler
(
video_clips
,
3
)
self
.
assertEqual
(
len
(
sampler
),
3
*
3
)
self
.
assertEqual
(
len
(
sampler
),
3
*
3
)
indices
=
torch
.
tensor
(
list
(
iter
(
sampler
)))
indices
=
torch
.
tensor
(
list
(
iter
(
sampler
)))
self
.
assert
True
(
indices
.
equal
(
torch
.
tensor
([
0
,
0
,
1
,
2
,
4
,
6
,
7
,
9
,
11
]))
)
assert
_equal
(
indices
,
torch
.
tensor
([
0
,
0
,
1
,
2
,
4
,
6
,
7
,
9
,
11
]))
def
test_distributed_sampler_and_uniform_clip_sampler
(
self
):
def
test_distributed_sampler_and_uniform_clip_sampler
(
self
):
with
get_list_of_videos
(
num_videos
=
3
,
sizes
=
[
25
,
25
,
25
])
as
video_list
:
with
get_list_of_videos
(
num_videos
=
3
,
sizes
=
[
25
,
25
,
25
])
as
video_list
:
...
@@ -100,7 +101,7 @@ class Tester(unittest.TestCase):
...
@@ -100,7 +101,7 @@ class Tester(unittest.TestCase):
)
)
indices
=
torch
.
tensor
(
list
(
iter
(
distributed_sampler_rank0
)))
indices
=
torch
.
tensor
(
list
(
iter
(
distributed_sampler_rank0
)))
self
.
assertEqual
(
len
(
distributed_sampler_rank0
),
6
)
self
.
assertEqual
(
len
(
distributed_sampler_rank0
),
6
)
self
.
assert
True
(
indices
.
equal
(
torch
.
tensor
([
0
,
2
,
4
,
10
,
12
,
14
]))
)
assert
_equal
(
indices
,
torch
.
tensor
([
0
,
2
,
4
,
10
,
12
,
14
]))
distributed_sampler_rank1
=
DistributedSampler
(
distributed_sampler_rank1
=
DistributedSampler
(
clip_sampler
,
clip_sampler
,
...
@@ -110,7 +111,7 @@ class Tester(unittest.TestCase):
...
@@ -110,7 +111,7 @@ class Tester(unittest.TestCase):
)
)
indices
=
torch
.
tensor
(
list
(
iter
(
distributed_sampler_rank1
)))
indices
=
torch
.
tensor
(
list
(
iter
(
distributed_sampler_rank1
)))
self
.
assertEqual
(
len
(
distributed_sampler_rank1
),
6
)
self
.
assertEqual
(
len
(
distributed_sampler_rank1
),
6
)
self
.
assert
True
(
indices
.
equal
(
torch
.
tensor
([
5
,
7
,
9
,
0
,
2
,
4
]))
)
assert
_equal
(
indices
,
torch
.
tensor
([
5
,
7
,
9
,
0
,
2
,
4
]))
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment