Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
4c073b09
Unverified
Commit
4c073b09
authored
Sep 06, 2022
by
vfdev
Committed by
GitHub
Sep 06, 2022
Browse files
[proto] Fixed bug in ScaleJitter with params (#6541)
* [proto] Fixed bug in ScaleJitter with params * Updated tests
parent
74feb198
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
18 additions
and
11 deletions
+18
-11
test/test_prototype_transforms.py
test/test_prototype_transforms.py
+14
-8
torchvision/prototype/transforms/_geometry.py
torchvision/prototype/transforms/_geometry.py
+4
-3
No files found.
test/test_prototype_transforms.py
View file @
4c073b09
...
...
@@ -1263,18 +1263,24 @@ class TestScaleJitter:
scale_range
=
(
0.5
,
1.5
)
transform
=
transforms
.
ScaleJitter
(
target_size
=
target_size
,
scale_range
=
scale_range
)
sample
=
mocker
.
MagicMock
(
spec
=
features
.
Image
,
num_channels
=
3
,
image_size
=
image_size
)
params
=
transform
.
_get_params
(
sample
)
assert
"size"
in
params
size
=
params
[
"size"
]
n_samples
=
5
for
_
in
range
(
n_samples
):
assert
isinstance
(
size
,
tuple
)
and
len
(
size
)
==
2
height
,
width
=
size
params
=
transform
.
_get_params
(
sample
)
assert
"size"
in
params
size
=
params
[
"size"
]
assert
isinstance
(
size
,
tuple
)
and
len
(
size
)
==
2
height
,
width
=
size
r_min
=
min
(
target_size
[
1
]
/
image_size
[
0
],
target_size
[
0
]
/
image_size
[
1
])
*
scale_range
[
0
]
r_max
=
min
(
target_size
[
1
]
/
image_size
[
0
],
target_size
[
0
]
/
image_size
[
1
])
*
scale_range
[
1
]
assert
int
(
tar
ge
t
_size
[
0
]
*
scale_range
[
0
]
)
<=
height
<=
int
(
tar
ge
t
_size
[
0
]
*
scale_range
[
1
]
)
assert
int
(
tar
ge
t
_size
[
1
]
*
scale_range
[
0
]
)
<=
width
<=
int
(
tar
ge
t
_size
[
1
]
*
scale_range
[
1
]
)
assert
int
(
ima
ge_size
[
0
]
*
r_min
)
<=
height
<=
int
(
ima
ge_size
[
0
]
*
r_max
)
assert
int
(
ima
ge_size
[
1
]
*
r_min
)
<=
width
<=
int
(
ima
ge_size
[
1
]
*
r_max
)
def
test__transform
(
self
,
mocker
):
interpolation_sentinel
=
mocker
.
MagicMock
()
...
...
torchvision/prototype/transforms/_geometry.py
View file @
4c073b09
...
...
@@ -727,9 +727,10 @@ class ScaleJitter(Transform):
def
_get_params
(
self
,
sample
:
Any
)
->
Dict
[
str
,
Any
]:
_
,
orig_height
,
orig_width
=
query_chw
(
sample
)
r
=
self
.
scale_range
[
0
]
+
torch
.
rand
(
1
)
*
(
self
.
scale_range
[
1
]
-
self
.
scale_range
[
0
])
new_width
=
int
(
self
.
target_size
[
1
]
*
r
)
new_height
=
int
(
self
.
target_size
[
0
]
*
r
)
scale
=
self
.
scale_range
[
0
]
+
torch
.
rand
(
1
)
*
(
self
.
scale_range
[
1
]
-
self
.
scale_range
[
0
])
r
=
min
(
self
.
target_size
[
1
]
/
orig_height
,
self
.
target_size
[
0
]
/
orig_width
)
*
scale
new_width
=
int
(
orig_width
*
r
)
new_height
=
int
(
orig_height
*
r
)
return
dict
(
size
=
(
new_height
,
new_width
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment