Unverified commit d75a5241, authored Feb 09, 2023 by Philip Meier, committed by GitHub on Feb 09, 2023

allow nn.ModuleList in RandomApply (#7197)

parent 539c6e29
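In short, `RandomApply` in the prototype transforms API now accepts an `nn.ModuleList` in addition to a plain sequence of callables. A minimal usage sketch of what the commit enables, using only names that appear in the diff below (the `torchvision.prototype.transforms` import path is the one used in this commit and may differ in later releases):

    from torch import nn
    from torchvision.prototype import transforms

    # After this commit, the wrapped transforms may be passed either as a plain
    # sequence or as an nn.ModuleList; both construct an equivalent RandomApply.
    random_apply = transforms.RandomApply(
        nn.ModuleList([transforms.Resize(256), transforms.CenterCrop(224)]),
        p=0.5,
    )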
Showing 2 changed files with 34 additions and 14 deletions (+34, -14):

    test/test_prototype_transforms_consistency.py    +15 -9
    torchvision/prototype/transforms/_container.py   +19 -5
test/test_prototype_transforms_consistency.py

@@ -23,6 +23,7 @@ from prototype_common_utils import (
     make_label,
     make_segmentation_mask,
 )
+from torch import nn
 from torchvision import transforms as legacy_transforms
 from torchvision._utils import sequence_to_str
 from torchvision.prototype import datapoints, transforms as prototype_transforms

@@ -761,19 +762,24 @@ class TestContainerTransforms:
         check_call_consistency(prototype_transform, legacy_transform)

     @pytest.mark.parametrize("p", [0, 0.1, 0.5, 0.9, 1])
-    def test_random_apply(self, p):
+    @pytest.mark.parametrize("sequence_type", [list, nn.ModuleList])
+    def test_random_apply(self, p, sequence_type):
         prototype_transform = prototype_transforms.RandomApply(
-            [
-                prototype_transforms.Resize(256),
-                prototype_transforms.CenterCrop(224),
-            ],
+            sequence_type(
+                [
+                    prototype_transforms.Resize(256),
+                    prototype_transforms.CenterCrop(224),
+                ]
+            ),
             p=p,
         )
         legacy_transform = legacy_transforms.RandomApply(
-            [
-                legacy_transforms.Resize(256),
-                legacy_transforms.CenterCrop(224),
-            ],
+            sequence_type(
+                [
+                    legacy_transforms.Resize(256),
+                    legacy_transforms.CenterCrop(224),
+                ]
+            ),
             p=p,
         )
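One way to run just the updated consistency test locally (assuming a torchvision source checkout with pytest and the prototype test dependencies installed):

    import pytest

    # Select only the RandomApply consistency test from the file changed above;
    # both the list and nn.ModuleList parametrizations will be collected.
    pytest.main(["test/test_prototype_transforms_consistency.py", "-k", "test_random_apply"])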
torchvision/prototype/transforms/_container.py

 import warnings
-from typing import Any, Callable, List, Optional, Sequence
+from typing import Any, Callable, List, Optional, Sequence, Union

 import torch
+from torch import nn
+
 from torchvision.prototype.transforms import Transform

@@ -25,9 +27,13 @@ class Compose(Transform):
         return "\n".join(format_string)


-class RandomApply(Compose):
-    def __init__(self, transforms: Sequence[Callable], p: float = 0.5) -> None:
-        super().__init__(transforms)
+class RandomApply(Transform):
+    def __init__(self, transforms: Union[Sequence[Callable], nn.ModuleList], p: float = 0.5) -> None:
+        super().__init__()
+
+        if not isinstance(transforms, (Sequence, nn.ModuleList)):
+            raise TypeError("Argument transforms should be a sequence of callables or a `nn.ModuleList`")
+        self.transforms = transforms

         if not (0.0 <= p <= 1.0):
             raise ValueError("`p` should be a floating point value in the interval [0.0, 1.0].")

@@ -39,7 +45,15 @@ class RandomApply(Compose):
         if torch.rand(1) >= self.p:
             return sample

-        return super().forward(sample)
+        for transform in self.transforms:
+            sample = transform(sample)
+        return sample
+
+    def extra_repr(self) -> str:
+        format_string = []
+        for t in self.transforms:
+            format_string.append(f"{t}")
+        return "\n".join(format_string)


 class RandomChoice(Transform):
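The practical difference an nn.ModuleList makes, and presumably part of the motivation for accepting it here, is that PyTorch registers its entries as child modules, so they show up in state_dict(), parameters(), and device moves of the containing transform, whereas a plain list does not. A small, self-contained sketch of that behavior; the Scale transform is a hypothetical stand-in, not part of torchvision:

    import torch
    from torch import nn

    class Scale(nn.Module):
        """Hypothetical toy transform with a parameter, used only to make registration visible."""

        def __init__(self) -> None:
            super().__init__()
            self.weight = nn.Parameter(torch.ones(1))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x * self.weight

    container = nn.Module()
    container.registered = nn.ModuleList([Scale()])  # entries become child modules
    container.plain = [Scale()]                      # a plain list is stored but not registered

    print(list(container.state_dict()))  # ['registered.0.weight'] -- only the ModuleList entry appears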