Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
f9966d22
Unverified
Commit
f9966d22
authored
Aug 25, 2022
by
vfdev
Committed by
GitHub
Aug 25, 2022
Browse files
[proto] Restored BC for RandomChoice and RandomOrder (#6488)
parent
020eafe1
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
14 additions
and
17 deletions
+14
-17
test/test_prototype_transforms.py
test/test_prototype_transforms.py
+8
-10
torchvision/prototype/transforms/_container.py
torchvision/prototype/transforms/_container.py
+6
-7
No files found.
test/test_prototype_transforms.py
View file @
f9966d22
...
@@ -1092,20 +1092,18 @@ class TestToTensor:
...
@@ -1092,20 +1092,18 @@ class TestToTensor:
fn
.
assert_called_once_with
(
inpt
)
fn
.
assert_called_once_with
(
inpt
)
class
TestCompose
:
class
TestContainers
:
def
test_assertions
(
self
):
@
pytest
.
mark
.
parametrize
(
"transform_cls"
,
[
transforms
.
Compose
,
transforms
.
RandomChoice
,
transforms
.
RandomOrder
])
def
test_assertions
(
self
,
transform_cls
):
with
pytest
.
raises
(
TypeError
,
match
=
"Argument transforms should be a sequence of callables"
):
with
pytest
.
raises
(
TypeError
,
match
=
"Argument transforms should be a sequence of callables"
):
transform
s
.
Compose
(
123
)
transform
_cls
(
transforms
.
RandomCrop
(
28
)
)
@
pytest
.
mark
.
parametrize
(
"transform_cls"
,
[
transforms
.
Compose
,
transforms
.
RandomChoice
,
transforms
.
RandomOrder
])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"trfms"
,
"trfms"
,
[[
transforms
.
Pad
(
2
),
transforms
.
RandomCrop
(
28
)],
[
lambda
x
:
2.0
*
x
,
transforms
.
RandomCrop
(
28
)]]
[
[
transforms
.
Pad
(
2
),
transforms
.
RandomCrop
(
28
)],
[
lambda
x
:
2.0
*
x
],
],
)
)
def
test_ctor
(
self
,
trfms
):
def
test_ctor
(
self
,
transform_cls
,
trfms
):
c
=
transform
s
.
Compose
(
trfms
)
c
=
transform
_cls
(
trfms
)
inpt
=
torch
.
rand
(
1
,
3
,
32
,
32
)
inpt
=
torch
.
rand
(
1
,
3
,
32
,
32
)
output
=
c
(
inpt
)
output
=
c
(
inpt
)
assert
isinstance
(
output
,
torch
.
Tensor
)
assert
isinstance
(
output
,
torch
.
Tensor
)
...
...
torchvision/prototype/transforms/_container.py
View file @
f9966d22
...
@@ -33,7 +33,9 @@ class RandomApply(_RandomApplyTransform):
...
@@ -33,7 +33,9 @@ class RandomApply(_RandomApplyTransform):
class
RandomChoice
(
Transform
):
class
RandomChoice
(
Transform
):
def
__init__
(
self
,
*
transforms
:
Transform
,
probabilities
:
Optional
[
List
[
float
]]
=
None
)
->
None
:
def
__init__
(
self
,
transforms
:
Sequence
[
Callable
],
probabilities
:
Optional
[
List
[
float
]]
=
None
)
->
None
:
if
not
isinstance
(
transforms
,
Sequence
):
raise
TypeError
(
"Argument transforms should be a sequence of callables"
)
if
probabilities
is
None
:
if
probabilities
is
None
:
probabilities
=
[
1
]
*
len
(
transforms
)
probabilities
=
[
1
]
*
len
(
transforms
)
elif
len
(
probabilities
)
!=
len
(
transforms
):
elif
len
(
probabilities
)
!=
len
(
transforms
):
...
@@ -45,9 +47,6 @@ class RandomChoice(Transform):
...
@@ -45,9 +47,6 @@ class RandomChoice(Transform):
super
().
__init__
()
super
().
__init__
()
self
.
transforms
=
transforms
self
.
transforms
=
transforms
for
idx
,
transform
in
enumerate
(
transforms
):
self
.
add_module
(
str
(
idx
),
transform
)
total
=
sum
(
probabilities
)
total
=
sum
(
probabilities
)
self
.
probabilities
=
[
p
/
total
for
p
in
probabilities
]
self
.
probabilities
=
[
p
/
total
for
p
in
probabilities
]
...
@@ -58,11 +57,11 @@ class RandomChoice(Transform):
...
@@ -58,11 +57,11 @@ class RandomChoice(Transform):
class
RandomOrder
(
Transform
):
class
RandomOrder
(
Transform
):
def
__init__
(
self
,
*
transforms
:
Transform
)
->
None
:
def
__init__
(
self
,
transforms
:
Sequence
[
Callable
])
->
None
:
if
not
isinstance
(
transforms
,
Sequence
):
raise
TypeError
(
"Argument transforms should be a sequence of callables"
)
super
().
__init__
()
super
().
__init__
()
self
.
transforms
=
transforms
self
.
transforms
=
transforms
for
idx
,
transform
in
enumerate
(
transforms
):
self
.
add_module
(
str
(
idx
),
transform
)
def
forward
(
self
,
*
inputs
:
Any
)
->
Any
:
def
forward
(
self
,
*
inputs
:
Any
)
->
Any
:
for
idx
in
torch
.
randperm
(
len
(
self
.
transforms
)):
for
idx
in
torch
.
randperm
(
len
(
self
.
transforms
)):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment