Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
297e2b87
Unverified
Commit
297e2b87
authored
Mar 18, 2022
by
Philip Meier
Committed by
GitHub
Mar 18, 2022
Browse files
add random apply transform base class (#5639)
parent
39772ece
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
33 additions
and
53 deletions
+33
-53
torchvision/prototype/transforms/_augment.py
torchvision/prototype/transforms/_augment.py
+3
-7
torchvision/prototype/transforms/_container.py
torchvision/prototype/transforms/_container.py
+7
-11
torchvision/prototype/transforms/_geometry.py
torchvision/prototype/transforms/_geometry.py
+5
-35
torchvision/prototype/transforms/_transform.py
torchvision/prototype/transforms/_transform.py
+18
-0
No files found.
torchvision/prototype/transforms/_augment.py
View file @
297e2b87
...
@@ -7,10 +7,11 @@ import torch
...
@@ -7,10 +7,11 @@ import torch
from
torchvision.prototype
import
features
from
torchvision.prototype
import
features
from
torchvision.prototype.transforms
import
Transform
,
functional
as
F
from
torchvision.prototype.transforms
import
Transform
,
functional
as
F
from
._transform
import
_RandomApplyTransform
from
._utils
import
query_image
,
get_image_dimensions
,
has_all
,
has_any
,
is_simple_tensor
from
._utils
import
query_image
,
get_image_dimensions
,
has_all
,
has_any
,
is_simple_tensor
class
RandomErasing
(
Transform
):
class
RandomErasing
(
_RandomApply
Transform
):
def
__init__
(
def
__init__
(
self
,
self
,
p
:
float
=
0.5
,
p
:
float
=
0.5
,
...
@@ -18,7 +19,7 @@ class RandomErasing(Transform):
...
@@ -18,7 +19,7 @@ class RandomErasing(Transform):
ratio
:
Tuple
[
float
,
float
]
=
(
0.3
,
3.3
),
ratio
:
Tuple
[
float
,
float
]
=
(
0.3
,
3.3
),
value
:
float
=
0
,
value
:
float
=
0
,
):
):
super
().
__init__
()
super
().
__init__
(
p
=
p
)
if
not
isinstance
(
value
,
(
numbers
.
Number
,
str
,
tuple
,
list
)):
if
not
isinstance
(
value
,
(
numbers
.
Number
,
str
,
tuple
,
list
)):
raise
TypeError
(
"Argument value should be either a number or str or a sequence"
)
raise
TypeError
(
"Argument value should be either a number or str or a sequence"
)
if
isinstance
(
value
,
str
)
and
value
!=
"random"
:
if
isinstance
(
value
,
str
)
and
value
!=
"random"
:
...
@@ -31,9 +32,6 @@ class RandomErasing(Transform):
...
@@ -31,9 +32,6 @@ class RandomErasing(Transform):
warnings
.
warn
(
"Scale and ratio should be of kind (min, max)"
)
warnings
.
warn
(
"Scale and ratio should be of kind (min, max)"
)
if
scale
[
0
]
<
0
or
scale
[
1
]
>
1
:
if
scale
[
0
]
<
0
or
scale
[
1
]
>
1
:
raise
ValueError
(
"Scale should be between 0 and 1"
)
raise
ValueError
(
"Scale should be between 0 and 1"
)
if
p
<
0
or
p
>
1
:
raise
ValueError
(
"Random erasing probability should be between 0 and 1"
)
self
.
p
=
p
self
.
scale
=
scale
self
.
scale
=
scale
self
.
ratio
=
ratio
self
.
ratio
=
ratio
self
.
value
=
value
self
.
value
=
value
...
@@ -99,8 +97,6 @@ class RandomErasing(Transform):
...
@@ -99,8 +97,6 @@ class RandomErasing(Transform):
sample
=
inputs
if
len
(
inputs
)
>
1
else
inputs
[
0
]
sample
=
inputs
if
len
(
inputs
)
>
1
else
inputs
[
0
]
if
has_any
(
sample
,
features
.
BoundingBox
,
features
.
SegmentationMask
):
if
has_any
(
sample
,
features
.
BoundingBox
,
features
.
SegmentationMask
):
raise
TypeError
(
f
"BoundingBox'es and SegmentationMask's are not supported by
{
type
(
self
).
__name__
}
()"
)
raise
TypeError
(
f
"BoundingBox'es and SegmentationMask's are not supported by
{
type
(
self
).
__name__
}
()"
)
elif
torch
.
rand
(
1
)
>=
self
.
p
:
return
sample
return
super
().
forward
(
sample
)
return
super
().
forward
(
sample
)
...
...
torchvision/prototype/transforms/_container.py
View file @
297e2b87
from
typing
import
Any
,
Optional
,
List
from
typing
import
Any
,
Optional
,
List
,
Dict
import
torch
import
torch
from
torchvision.prototype.transforms
import
Transform
from
._transform
import
Transform
from
._transform
import
_RandomApply
Transform
class
Compose
(
Transform
):
class
Compose
(
Transform
):
...
@@ -19,18 +20,13 @@ class Compose(Transform):
...
@@ -19,18 +20,13 @@ class Compose(Transform):
return
sample
return
sample
class
RandomApply
(
Transform
):
class
RandomApply
(
_RandomApply
Transform
):
def
__init__
(
self
,
transform
:
Transform
,
*
,
p
:
float
=
0.5
)
->
None
:
def
__init__
(
self
,
transform
:
Transform
,
*
,
p
:
float
=
0.5
)
->
None
:
super
().
__init__
()
super
().
__init__
(
p
=
p
)
self
.
transform
=
transform
self
.
transform
=
transform
self
.
p
=
p
def
forward
(
self
,
*
inputs
:
Any
)
->
Any
:
sample
=
inputs
if
len
(
inputs
)
>
1
else
inputs
[
0
]
if
float
(
torch
.
rand
(()))
<
self
.
p
:
return
sample
return
self
.
transform
(
sample
)
def
_transform
(
self
,
input
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
return
self
.
transform
(
input
)
def
extra_repr
(
self
)
->
str
:
def
extra_repr
(
self
)
->
str
:
return
f
"p=
{
self
.
p
}
"
return
f
"p=
{
self
.
p
}
"
...
...
torchvision/prototype/transforms/_geometry.py
View file @
297e2b87
...
@@ -12,21 +12,11 @@ from torchvision.transforms.functional import pil_to_tensor
...
@@ -12,21 +12,11 @@ from torchvision.transforms.functional import pil_to_tensor
from
torchvision.transforms.transforms
import
_setup_size
,
_interpolation_modes_from_int
from
torchvision.transforms.transforms
import
_setup_size
,
_interpolation_modes_from_int
from
typing_extensions
import
Literal
from
typing_extensions
import
Literal
from
._transform
import
_RandomApplyTransform
from
._utils
import
query_image
,
get_image_dimensions
,
has_any
,
is_simple_tensor
from
._utils
import
query_image
,
get_image_dimensions
,
has_any
,
is_simple_tensor
class
RandomHorizontalFlip
(
Transform
):
class
RandomHorizontalFlip
(
_RandomApplyTransform
):
def
__init__
(
self
,
p
:
float
=
0.5
)
->
None
:
super
().
__init__
()
self
.
p
=
p
def
forward
(
self
,
*
inputs
:
Any
)
->
Any
:
sample
=
inputs
if
len
(
inputs
)
>
1
else
inputs
[
0
]
if
torch
.
rand
(
1
)
>=
self
.
p
:
return
sample
return
super
().
forward
(
sample
)
def
_transform
(
self
,
input
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
def
_transform
(
self
,
input
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
if
isinstance
(
input
,
features
.
Image
):
if
isinstance
(
input
,
features
.
Image
):
output
=
F
.
horizontal_flip_image_tensor
(
input
)
output
=
F
.
horizontal_flip_image_tensor
(
input
)
...
@@ -45,18 +35,7 @@ class RandomHorizontalFlip(Transform):
...
@@ -45,18 +35,7 @@ class RandomHorizontalFlip(Transform):
return
input
return
input
class
RandomVerticalFlip
(
Transform
):
class
RandomVerticalFlip
(
_RandomApplyTransform
):
def
__init__
(
self
,
p
:
float
=
0.5
)
->
None
:
super
().
__init__
()
self
.
p
=
p
def
forward
(
self
,
*
inputs
:
Any
)
->
Any
:
sample
=
inputs
if
len
(
inputs
)
>
1
else
inputs
[
0
]
if
torch
.
rand
(
1
)
>
self
.
p
:
return
sample
return
super
().
forward
(
sample
)
def
_transform
(
self
,
input
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
def
_transform
(
self
,
input
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
if
isinstance
(
input
,
features
.
Image
):
if
isinstance
(
input
,
features
.
Image
):
output
=
F
.
vertical_flip_image_tensor
(
input
)
output
=
F
.
vertical_flip_image_tensor
(
input
)
...
@@ -371,11 +350,11 @@ class Pad(Transform):
...
@@ -371,11 +350,11 @@ class Pad(Transform):
return
input
return
input
class
RandomZoomOut
(
Transform
):
class
RandomZoomOut
(
_RandomApply
Transform
):
def
__init__
(
def
__init__
(
self
,
fill
:
Union
[
float
,
Sequence
[
float
]]
=
0.0
,
side_range
:
Tuple
[
float
,
float
]
=
(
1.0
,
4.0
),
p
:
float
=
0.5
self
,
fill
:
Union
[
float
,
Sequence
[
float
]]
=
0.0
,
side_range
:
Tuple
[
float
,
float
]
=
(
1.0
,
4.0
),
p
:
float
=
0.5
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
(
p
=
p
)
if
fill
is
None
:
if
fill
is
None
:
fill
=
0.0
fill
=
0.0
...
@@ -385,8 +364,6 @@ class RandomZoomOut(Transform):
...
@@ -385,8 +364,6 @@ class RandomZoomOut(Transform):
if
side_range
[
0
]
<
1.0
or
side_range
[
0
]
>
side_range
[
1
]:
if
side_range
[
0
]
<
1.0
or
side_range
[
0
]
>
side_range
[
1
]:
raise
ValueError
(
f
"Invalid canvas side range provided
{
side_range
}
."
)
raise
ValueError
(
f
"Invalid canvas side range provided
{
side_range
}
."
)
self
.
p
=
p
def
_get_params
(
self
,
sample
:
Any
)
->
Dict
[
str
,
Any
]:
def
_get_params
(
self
,
sample
:
Any
)
->
Dict
[
str
,
Any
]:
image
=
query_image
(
sample
)
image
=
query_image
(
sample
)
orig_c
,
orig_h
,
orig_w
=
get_image_dimensions
(
image
)
orig_c
,
orig_h
,
orig_w
=
get_image_dimensions
(
image
)
...
@@ -411,10 +388,3 @@ class RandomZoomOut(Transform):
...
@@ -411,10 +388,3 @@ class RandomZoomOut(Transform):
def
_transform
(
self
,
input
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
def
_transform
(
self
,
input
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
transform
=
Pad
(
**
params
,
padding_mode
=
"constant"
)
transform
=
Pad
(
**
params
,
padding_mode
=
"constant"
)
return
transform
(
input
)
return
transform
(
input
)
def
forward
(
self
,
*
inputs
:
Any
)
->
Any
:
sample
=
inputs
if
len
(
inputs
)
>
1
else
inputs
[
0
]
if
torch
.
rand
(
1
)
>=
self
.
p
:
return
sample
return
super
().
forward
(
sample
)
torchvision/prototype/transforms/_transform.py
View file @
297e2b87
...
@@ -2,6 +2,7 @@ import enum
...
@@ -2,6 +2,7 @@ import enum
import
functools
import
functools
from
typing
import
Any
,
Dict
from
typing
import
Any
,
Dict
import
torch
from
torch
import
nn
from
torch
import
nn
from
torchvision.prototype.utils._internal
import
apply_recursively
from
torchvision.prototype.utils._internal
import
apply_recursively
from
torchvision.utils
import
_log_api_usage_once
from
torchvision.utils
import
_log_api_usage_once
...
@@ -34,3 +35,20 @@ class Transform(nn.Module):
...
@@ -34,3 +35,20 @@ class Transform(nn.Module):
extra
.
append
(
f
"
{
name
}
=
{
value
}
"
)
extra
.
append
(
f
"
{
name
}
=
{
value
}
"
)
return
", "
.
join
(
extra
)
return
", "
.
join
(
extra
)
class
_RandomApplyTransform
(
Transform
):
def
__init__
(
self
,
*
,
p
:
float
=
0.5
)
->
None
:
if
not
(
0.0
<=
p
<=
1.0
):
raise
ValueError
(
"`p` should be a floating point value in the interval [0.0, 1.0]."
)
super
().
__init__
()
self
.
p
=
p
def
forward
(
self
,
*
inputs
:
Any
)
->
Any
:
sample
=
inputs
if
len
(
inputs
)
>
1
else
inputs
[
0
]
if
torch
.
rand
(
1
)
>=
self
.
p
:
return
sample
return
super
().
forward
(
sample
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment