Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
297e2b87
Unverified
Commit
297e2b87
authored
Mar 18, 2022
by
Philip Meier
Committed by
GitHub
Mar 18, 2022
Browse files
add random apply transform base class (#5639)
parent
39772ece
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
33 additions
and
53 deletions
+33
-53
torchvision/prototype/transforms/_augment.py
torchvision/prototype/transforms/_augment.py
+3
-7
torchvision/prototype/transforms/_container.py
torchvision/prototype/transforms/_container.py
+7
-11
torchvision/prototype/transforms/_geometry.py
torchvision/prototype/transforms/_geometry.py
+5
-35
torchvision/prototype/transforms/_transform.py
torchvision/prototype/transforms/_transform.py
+18
-0
No files found.
torchvision/prototype/transforms/_augment.py
View file @
297e2b87
...
...
@@ -7,10 +7,11 @@ import torch
from
torchvision.prototype
import
features
from
torchvision.prototype.transforms
import
Transform
,
functional
as
F
from
._transform
import
_RandomApplyTransform
from
._utils
import
query_image
,
get_image_dimensions
,
has_all
,
has_any
,
is_simple_tensor
class
RandomErasing
(
Transform
):
class
RandomErasing
(
_RandomApply
Transform
):
def
__init__
(
self
,
p
:
float
=
0.5
,
...
...
@@ -18,7 +19,7 @@ class RandomErasing(Transform):
ratio
:
Tuple
[
float
,
float
]
=
(
0.3
,
3.3
),
value
:
float
=
0
,
):
super
().
__init__
()
super
().
__init__
(
p
=
p
)
if
not
isinstance
(
value
,
(
numbers
.
Number
,
str
,
tuple
,
list
)):
raise
TypeError
(
"Argument value should be either a number or str or a sequence"
)
if
isinstance
(
value
,
str
)
and
value
!=
"random"
:
...
...
@@ -31,9 +32,6 @@ class RandomErasing(Transform):
warnings
.
warn
(
"Scale and ratio should be of kind (min, max)"
)
if
scale
[
0
]
<
0
or
scale
[
1
]
>
1
:
raise
ValueError
(
"Scale should be between 0 and 1"
)
if
p
<
0
or
p
>
1
:
raise
ValueError
(
"Random erasing probability should be between 0 and 1"
)
self
.
p
=
p
self
.
scale
=
scale
self
.
ratio
=
ratio
self
.
value
=
value
...
...
@@ -99,8 +97,6 @@ class RandomErasing(Transform):
sample
=
inputs
if
len
(
inputs
)
>
1
else
inputs
[
0
]
if
has_any
(
sample
,
features
.
BoundingBox
,
features
.
SegmentationMask
):
raise
TypeError
(
f
"BoundingBox'es and SegmentationMask's are not supported by
{
type
(
self
).
__name__
}
()"
)
elif
torch
.
rand
(
1
)
>=
self
.
p
:
return
sample
return
super
().
forward
(
sample
)
...
...
torchvision/prototype/transforms/_container.py
View file @
297e2b87
from
typing
import
Any
,
Optional
,
List
from
typing
import
Any
,
Optional
,
List
,
Dict
import
torch
from
torchvision.prototype.transforms
import
Transform
from
._transform
import
Transform
from
._transform
import
_RandomApply
Transform
class
Compose
(
Transform
):
...
...
@@ -19,18 +20,13 @@ class Compose(Transform):
return
sample
class
RandomApply
(
Transform
):
class
RandomApply
(
_RandomApply
Transform
):
def
__init__
(
self
,
transform
:
Transform
,
*
,
p
:
float
=
0.5
)
->
None
:
super
().
__init__
()
super
().
__init__
(
p
=
p
)
self
.
transform
=
transform
self
.
p
=
p
def
forward
(
self
,
*
inputs
:
Any
)
->
Any
:
sample
=
inputs
if
len
(
inputs
)
>
1
else
inputs
[
0
]
if
float
(
torch
.
rand
(()))
<
self
.
p
:
return
sample
return
self
.
transform
(
sample
)
def
_transform
(
self
,
input
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
return
self
.
transform
(
input
)
def
extra_repr
(
self
)
->
str
:
return
f
"p=
{
self
.
p
}
"
...
...
torchvision/prototype/transforms/_geometry.py
View file @
297e2b87
...
...
@@ -12,21 +12,11 @@ from torchvision.transforms.functional import pil_to_tensor
from
torchvision.transforms.transforms
import
_setup_size
,
_interpolation_modes_from_int
from
typing_extensions
import
Literal
from
._transform
import
_RandomApplyTransform
from
._utils
import
query_image
,
get_image_dimensions
,
has_any
,
is_simple_tensor
class
RandomHorizontalFlip
(
Transform
):
def
__init__
(
self
,
p
:
float
=
0.5
)
->
None
:
super
().
__init__
()
self
.
p
=
p
def
forward
(
self
,
*
inputs
:
Any
)
->
Any
:
sample
=
inputs
if
len
(
inputs
)
>
1
else
inputs
[
0
]
if
torch
.
rand
(
1
)
>=
self
.
p
:
return
sample
return
super
().
forward
(
sample
)
class
RandomHorizontalFlip
(
_RandomApplyTransform
):
def
_transform
(
self
,
input
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
if
isinstance
(
input
,
features
.
Image
):
output
=
F
.
horizontal_flip_image_tensor
(
input
)
...
...
@@ -45,18 +35,7 @@ class RandomHorizontalFlip(Transform):
return
input
class
RandomVerticalFlip
(
Transform
):
def
__init__
(
self
,
p
:
float
=
0.5
)
->
None
:
super
().
__init__
()
self
.
p
=
p
def
forward
(
self
,
*
inputs
:
Any
)
->
Any
:
sample
=
inputs
if
len
(
inputs
)
>
1
else
inputs
[
0
]
if
torch
.
rand
(
1
)
>
self
.
p
:
return
sample
return
super
().
forward
(
sample
)
class
RandomVerticalFlip
(
_RandomApplyTransform
):
def
_transform
(
self
,
input
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
if
isinstance
(
input
,
features
.
Image
):
output
=
F
.
vertical_flip_image_tensor
(
input
)
...
...
@@ -371,11 +350,11 @@ class Pad(Transform):
return
input
class
RandomZoomOut
(
Transform
):
class
RandomZoomOut
(
_RandomApply
Transform
):
def
__init__
(
self
,
fill
:
Union
[
float
,
Sequence
[
float
]]
=
0.0
,
side_range
:
Tuple
[
float
,
float
]
=
(
1.0
,
4.0
),
p
:
float
=
0.5
)
->
None
:
super
().
__init__
()
super
().
__init__
(
p
=
p
)
if
fill
is
None
:
fill
=
0.0
...
...
@@ -385,8 +364,6 @@ class RandomZoomOut(Transform):
if
side_range
[
0
]
<
1.0
or
side_range
[
0
]
>
side_range
[
1
]:
raise
ValueError
(
f
"Invalid canvas side range provided
{
side_range
}
."
)
self
.
p
=
p
def
_get_params
(
self
,
sample
:
Any
)
->
Dict
[
str
,
Any
]:
image
=
query_image
(
sample
)
orig_c
,
orig_h
,
orig_w
=
get_image_dimensions
(
image
)
...
...
@@ -411,10 +388,3 @@ class RandomZoomOut(Transform):
def
_transform
(
self
,
input
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
transform
=
Pad
(
**
params
,
padding_mode
=
"constant"
)
return
transform
(
input
)
def
forward
(
self
,
*
inputs
:
Any
)
->
Any
:
sample
=
inputs
if
len
(
inputs
)
>
1
else
inputs
[
0
]
if
torch
.
rand
(
1
)
>=
self
.
p
:
return
sample
return
super
().
forward
(
sample
)
torchvision/prototype/transforms/_transform.py
View file @
297e2b87
...
...
@@ -2,6 +2,7 @@ import enum
import
functools
from
typing
import
Any
,
Dict
import
torch
from
torch
import
nn
from
torchvision.prototype.utils._internal
import
apply_recursively
from
torchvision.utils
import
_log_api_usage_once
...
...
@@ -34,3 +35,20 @@ class Transform(nn.Module):
extra
.
append
(
f
"
{
name
}
=
{
value
}
"
)
return
", "
.
join
(
extra
)
class
_RandomApplyTransform
(
Transform
):
def
__init__
(
self
,
*
,
p
:
float
=
0.5
)
->
None
:
if
not
(
0.0
<=
p
<=
1.0
):
raise
ValueError
(
"`p` should be a floating point value in the interval [0.0, 1.0]."
)
super
().
__init__
()
self
.
p
=
p
def
forward
(
self
,
*
inputs
:
Any
)
->
Any
:
sample
=
inputs
if
len
(
inputs
)
>
1
else
inputs
[
0
]
if
torch
.
rand
(
1
)
>=
self
.
p
:
return
sample
return
super
().
forward
(
sample
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment