Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
297e2b87
"tests/python/vscode:/vscode.git/clone" did not exist on "b1dd592da748f8a604391058679032184bbfa2dd"
Unverified
Commit
297e2b87
authored
Mar 18, 2022
by
Philip Meier
Committed by
GitHub
Mar 18, 2022
Browse files
add random apply transform base class (#5639)
parent
39772ece
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
33 additions
and
53 deletions
+33
-53
torchvision/prototype/transforms/_augment.py
torchvision/prototype/transforms/_augment.py
+3
-7
torchvision/prototype/transforms/_container.py
torchvision/prototype/transforms/_container.py
+7
-11
torchvision/prototype/transforms/_geometry.py
torchvision/prototype/transforms/_geometry.py
+5
-35
torchvision/prototype/transforms/_transform.py
torchvision/prototype/transforms/_transform.py
+18
-0
No files found.
torchvision/prototype/transforms/_augment.py
View file @
297e2b87
...
@@ -7,10 +7,11 @@ import torch
...
@@ -7,10 +7,11 @@ import torch
from
torchvision.prototype
import
features
from
torchvision.prototype
import
features
from
torchvision.prototype.transforms
import
Transform
,
functional
as
F
from
torchvision.prototype.transforms
import
Transform
,
functional
as
F
from
._transform
import
_RandomApplyTransform
from
._utils
import
query_image
,
get_image_dimensions
,
has_all
,
has_any
,
is_simple_tensor
from
._utils
import
query_image
,
get_image_dimensions
,
has_all
,
has_any
,
is_simple_tensor
class
RandomErasing
(
Transform
):
class
RandomErasing
(
_RandomApply
Transform
):
def
__init__
(
def
__init__
(
self
,
self
,
p
:
float
=
0.5
,
p
:
float
=
0.5
,
...
@@ -18,7 +19,7 @@ class RandomErasing(Transform):
...
@@ -18,7 +19,7 @@ class RandomErasing(Transform):
ratio
:
Tuple
[
float
,
float
]
=
(
0.3
,
3.3
),
ratio
:
Tuple
[
float
,
float
]
=
(
0.3
,
3.3
),
value
:
float
=
0
,
value
:
float
=
0
,
):
):
super
().
__init__
()
super
().
__init__
(
p
=
p
)
if
not
isinstance
(
value
,
(
numbers
.
Number
,
str
,
tuple
,
list
)):
if
not
isinstance
(
value
,
(
numbers
.
Number
,
str
,
tuple
,
list
)):
raise
TypeError
(
"Argument value should be either a number or str or a sequence"
)
raise
TypeError
(
"Argument value should be either a number or str or a sequence"
)
if
isinstance
(
value
,
str
)
and
value
!=
"random"
:
if
isinstance
(
value
,
str
)
and
value
!=
"random"
:
...
@@ -31,9 +32,6 @@ class RandomErasing(Transform):
...
@@ -31,9 +32,6 @@ class RandomErasing(Transform):
warnings
.
warn
(
"Scale and ratio should be of kind (min, max)"
)
warnings
.
warn
(
"Scale and ratio should be of kind (min, max)"
)
if
scale
[
0
]
<
0
or
scale
[
1
]
>
1
:
if
scale
[
0
]
<
0
or
scale
[
1
]
>
1
:
raise
ValueError
(
"Scale should be between 0 and 1"
)
raise
ValueError
(
"Scale should be between 0 and 1"
)
if
p
<
0
or
p
>
1
:
raise
ValueError
(
"Random erasing probability should be between 0 and 1"
)
self
.
p
=
p
self
.
scale
=
scale
self
.
scale
=
scale
self
.
ratio
=
ratio
self
.
ratio
=
ratio
self
.
value
=
value
self
.
value
=
value
...
@@ -99,8 +97,6 @@ class RandomErasing(Transform):
...
@@ -99,8 +97,6 @@ class RandomErasing(Transform):
sample
=
inputs
if
len
(
inputs
)
>
1
else
inputs
[
0
]
sample
=
inputs
if
len
(
inputs
)
>
1
else
inputs
[
0
]
if
has_any
(
sample
,
features
.
BoundingBox
,
features
.
SegmentationMask
):
if
has_any
(
sample
,
features
.
BoundingBox
,
features
.
SegmentationMask
):
raise
TypeError
(
f
"BoundingBox'es and SegmentationMask's are not supported by
{
type
(
self
).
__name__
}
()"
)
raise
TypeError
(
f
"BoundingBox'es and SegmentationMask's are not supported by
{
type
(
self
).
__name__
}
()"
)
elif
torch
.
rand
(
1
)
>=
self
.
p
:
return
sample
return
super
().
forward
(
sample
)
return
super
().
forward
(
sample
)
...
...
torchvision/prototype/transforms/_container.py
View file @
297e2b87
from
typing
import
Any
,
Optional
,
List
from
typing
import
Any
,
Optional
,
List
,
Dict
import
torch
import
torch
from
torchvision.prototype.transforms
import
Transform
from
._transform
import
Transform
from
._transform
import
_RandomApply
Transform
class
Compose
(
Transform
):
class
Compose
(
Transform
):
...
@@ -19,18 +20,13 @@ class Compose(Transform):
...
@@ -19,18 +20,13 @@ class Compose(Transform):
return
sample
return
sample
class
RandomApply
(
Transform
):
class
RandomApply
(
_RandomApply
Transform
):
def
__init__
(
self
,
transform
:
Transform
,
*
,
p
:
float
=
0.5
)
->
None
:
def
__init__
(
self
,
transform
:
Transform
,
*
,
p
:
float
=
0.5
)
->
None
:
super
().
__init__
()
super
().
__init__
(
p
=
p
)
self
.
transform
=
transform
self
.
transform
=
transform
self
.
p
=
p
def
forward
(
self
,
*
inputs
:
Any
)
->
Any
:
sample
=
inputs
if
len
(
inputs
)
>
1
else
inputs
[
0
]
if
float
(
torch
.
rand
(()))
<
self
.
p
:
return
sample
return
self
.
transform
(
sample
)
def
_transform
(
self
,
input
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
return
self
.
transform
(
input
)
def
extra_repr
(
self
)
->
str
:
def
extra_repr
(
self
)
->
str
:
return
f
"p=
{
self
.
p
}
"
return
f
"p=
{
self
.
p
}
"
...
...
torchvision/prototype/transforms/_geometry.py
View file @
297e2b87
...
@@ -12,21 +12,11 @@ from torchvision.transforms.functional import pil_to_tensor
...
@@ -12,21 +12,11 @@ from torchvision.transforms.functional import pil_to_tensor
from
torchvision.transforms.transforms
import
_setup_size
,
_interpolation_modes_from_int
from
torchvision.transforms.transforms
import
_setup_size
,
_interpolation_modes_from_int
from
typing_extensions
import
Literal
from
typing_extensions
import
Literal
from
._transform
import
_RandomApplyTransform
from
._utils
import
query_image
,
get_image_dimensions
,
has_any
,
is_simple_tensor
from
._utils
import
query_image
,
get_image_dimensions
,
has_any
,
is_simple_tensor
class
RandomHorizontalFlip
(
Transform
):
class
RandomHorizontalFlip
(
_RandomApplyTransform
):
def
__init__
(
self
,
p
:
float
=
0.5
)
->
None
:
super
().
__init__
()
self
.
p
=
p
def
forward
(
self
,
*
inputs
:
Any
)
->
Any
:
sample
=
inputs
if
len
(
inputs
)
>
1
else
inputs
[
0
]
if
torch
.
rand
(
1
)
>=
self
.
p
:
return
sample
return
super
().
forward
(
sample
)
def
_transform
(
self
,
input
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
def
_transform
(
self
,
input
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
if
isinstance
(
input
,
features
.
Image
):
if
isinstance
(
input
,
features
.
Image
):
output
=
F
.
horizontal_flip_image_tensor
(
input
)
output
=
F
.
horizontal_flip_image_tensor
(
input
)
...
@@ -45,18 +35,7 @@ class RandomHorizontalFlip(Transform):
...
@@ -45,18 +35,7 @@ class RandomHorizontalFlip(Transform):
return
input
return
input
class
RandomVerticalFlip
(
Transform
):
class
RandomVerticalFlip
(
_RandomApplyTransform
):
def
__init__
(
self
,
p
:
float
=
0.5
)
->
None
:
super
().
__init__
()
self
.
p
=
p
def
forward
(
self
,
*
inputs
:
Any
)
->
Any
:
sample
=
inputs
if
len
(
inputs
)
>
1
else
inputs
[
0
]
if
torch
.
rand
(
1
)
>
self
.
p
:
return
sample
return
super
().
forward
(
sample
)
def
_transform
(
self
,
input
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
def
_transform
(
self
,
input
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
if
isinstance
(
input
,
features
.
Image
):
if
isinstance
(
input
,
features
.
Image
):
output
=
F
.
vertical_flip_image_tensor
(
input
)
output
=
F
.
vertical_flip_image_tensor
(
input
)
...
@@ -371,11 +350,11 @@ class Pad(Transform):
...
@@ -371,11 +350,11 @@ class Pad(Transform):
return
input
return
input
class
RandomZoomOut
(
Transform
):
class
RandomZoomOut
(
_RandomApply
Transform
):
def
__init__
(
def
__init__
(
self
,
fill
:
Union
[
float
,
Sequence
[
float
]]
=
0.0
,
side_range
:
Tuple
[
float
,
float
]
=
(
1.0
,
4.0
),
p
:
float
=
0.5
self
,
fill
:
Union
[
float
,
Sequence
[
float
]]
=
0.0
,
side_range
:
Tuple
[
float
,
float
]
=
(
1.0
,
4.0
),
p
:
float
=
0.5
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
(
p
=
p
)
if
fill
is
None
:
if
fill
is
None
:
fill
=
0.0
fill
=
0.0
...
@@ -385,8 +364,6 @@ class RandomZoomOut(Transform):
...
@@ -385,8 +364,6 @@ class RandomZoomOut(Transform):
if
side_range
[
0
]
<
1.0
or
side_range
[
0
]
>
side_range
[
1
]:
if
side_range
[
0
]
<
1.0
or
side_range
[
0
]
>
side_range
[
1
]:
raise
ValueError
(
f
"Invalid canvas side range provided
{
side_range
}
."
)
raise
ValueError
(
f
"Invalid canvas side range provided
{
side_range
}
."
)
self
.
p
=
p
def
_get_params
(
self
,
sample
:
Any
)
->
Dict
[
str
,
Any
]:
def
_get_params
(
self
,
sample
:
Any
)
->
Dict
[
str
,
Any
]:
image
=
query_image
(
sample
)
image
=
query_image
(
sample
)
orig_c
,
orig_h
,
orig_w
=
get_image_dimensions
(
image
)
orig_c
,
orig_h
,
orig_w
=
get_image_dimensions
(
image
)
...
@@ -411,10 +388,3 @@ class RandomZoomOut(Transform):
...
@@ -411,10 +388,3 @@ class RandomZoomOut(Transform):
def
_transform
(
self
,
input
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
def
_transform
(
self
,
input
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
transform
=
Pad
(
**
params
,
padding_mode
=
"constant"
)
transform
=
Pad
(
**
params
,
padding_mode
=
"constant"
)
return
transform
(
input
)
return
transform
(
input
)
def
forward
(
self
,
*
inputs
:
Any
)
->
Any
:
sample
=
inputs
if
len
(
inputs
)
>
1
else
inputs
[
0
]
if
torch
.
rand
(
1
)
>=
self
.
p
:
return
sample
return
super
().
forward
(
sample
)
torchvision/prototype/transforms/_transform.py
View file @
297e2b87
...
@@ -2,6 +2,7 @@ import enum
...
@@ -2,6 +2,7 @@ import enum
import
functools
import
functools
from
typing
import
Any
,
Dict
from
typing
import
Any
,
Dict
import
torch
from
torch
import
nn
from
torch
import
nn
from
torchvision.prototype.utils._internal
import
apply_recursively
from
torchvision.prototype.utils._internal
import
apply_recursively
from
torchvision.utils
import
_log_api_usage_once
from
torchvision.utils
import
_log_api_usage_once
...
@@ -34,3 +35,20 @@ class Transform(nn.Module):
...
@@ -34,3 +35,20 @@ class Transform(nn.Module):
extra
.
append
(
f
"
{
name
}
=
{
value
}
"
)
extra
.
append
(
f
"
{
name
}
=
{
value
}
"
)
return
", "
.
join
(
extra
)
return
", "
.
join
(
extra
)
class
_RandomApplyTransform
(
Transform
):
def
__init__
(
self
,
*
,
p
:
float
=
0.5
)
->
None
:
if
not
(
0.0
<=
p
<=
1.0
):
raise
ValueError
(
"`p` should be a floating point value in the interval [0.0, 1.0]."
)
super
().
__init__
()
self
.
p
=
p
def
forward
(
self
,
*
inputs
:
Any
)
->
Any
:
sample
=
inputs
if
len
(
inputs
)
>
1
else
inputs
[
0
]
if
torch
.
rand
(
1
)
>=
self
.
p
:
return
sample
return
super
().
forward
(
sample
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment