Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
053e7ebd
Unverified
Commit
053e7ebd
authored
Mar 16, 2022
by
Philip Meier
Committed by
GitHub
Mar 16, 2022
Browse files
port Pad to prototype transforms (#5621)
* port Pad to prototype transforms * use literal
parent
00c119c8
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
64 additions
and
29 deletions
+64
-29
test/test_prototype_transforms.py
test/test_prototype_transforms.py
+1
-0
torchvision/prototype/transforms/__init__.py
torchvision/prototype/transforms/__init__.py
+1
-0
torchvision/prototype/transforms/_geometry.py
torchvision/prototype/transforms/_geometry.py
+62
-29
No files found.
test/test_prototype_transforms.py
View file @
053e7ebd
...
@@ -71,6 +71,7 @@ class TestSmoke:
...
@@ -71,6 +71,7 @@ class TestSmoke:
transforms
.
CenterCrop
([
16
,
16
]),
transforms
.
CenterCrop
([
16
,
16
]),
transforms
.
ConvertImageDtype
(),
transforms
.
ConvertImageDtype
(),
transforms
.
RandomHorizontalFlip
(),
transforms
.
RandomHorizontalFlip
(),
transforms
.
Pad
(
5
),
)
)
def
test_common
(
self
,
transform
,
input
):
def
test_common
(
self
,
transform
,
input
):
transform
(
input
)
transform
(
input
)
...
...
torchvision/prototype/transforms/__init__.py
View file @
053e7ebd
...
@@ -15,6 +15,7 @@ from ._geometry import (
...
@@ -15,6 +15,7 @@ from ._geometry import (
TenCrop
,
TenCrop
,
BatchMultiCrop
,
BatchMultiCrop
,
RandomHorizontalFlip
,
RandomHorizontalFlip
,
Pad
,
RandomZoomOut
,
RandomZoomOut
,
)
)
from
._meta
import
ConvertBoundingBoxFormat
,
ConvertImageDtype
,
ConvertImageColorSpace
from
._meta
import
ConvertBoundingBoxFormat
,
ConvertImageDtype
,
ConvertImageColorSpace
...
...
torchvision/prototype/transforms/_geometry.py
View file @
053e7ebd
import
collections.abc
import
collections.abc
import
math
import
math
import
numbers
import
warnings
import
warnings
from
typing
import
Any
,
Dict
,
List
,
Union
,
Sequence
,
Tuple
,
cast
from
typing
import
Any
,
Dict
,
List
,
Union
,
Sequence
,
Tuple
,
cast
...
@@ -9,6 +10,7 @@ from torchvision.prototype import features
...
@@ -9,6 +10,7 @@ from torchvision.prototype import features
from
torchvision.prototype.transforms
import
Transform
,
InterpolationMode
,
functional
as
F
from
torchvision.prototype.transforms
import
Transform
,
InterpolationMode
,
functional
as
F
from
torchvision.transforms.functional
import
pil_to_tensor
from
torchvision.transforms.functional
import
pil_to_tensor
from
torchvision.transforms.transforms
import
_setup_size
,
_interpolation_modes_from_int
from
torchvision.transforms.transforms
import
_setup_size
,
_interpolation_modes_from_int
from
typing_extensions
import
Literal
from
._utils
import
query_image
,
get_image_dimensions
,
has_any
,
is_simple_tensor
from
._utils
import
query_image
,
get_image_dimensions
,
has_any
,
is_simple_tensor
...
@@ -272,42 +274,31 @@ class BatchMultiCrop(Transform):
...
@@ -272,42 +274,31 @@ class BatchMultiCrop(Transform):
return
apply_recursively
(
inputs
if
len
(
inputs
)
>
1
else
inputs
[
0
])
return
apply_recursively
(
inputs
if
len
(
inputs
)
>
1
else
inputs
[
0
])
class
RandomZoomOut
(
Transform
):
class
Pad
(
Transform
):
def
__init__
(
def
__init__
(
self
,
fill
:
Union
[
float
,
Sequence
[
float
]]
=
0.0
,
side_range
:
Tuple
[
float
,
float
]
=
(
1.0
,
4.0
),
p
:
float
=
0.5
self
,
padding
:
Union
[
int
,
Sequence
[
int
]],
fill
:
Union
[
float
,
Sequence
[
float
]]
=
0.0
,
padding_mode
:
Literal
[
"constant"
,
"edge"
,
"reflect"
,
"symmetric"
]
=
"constant"
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
if
not
isinstance
(
padding
,
(
numbers
.
Number
,
tuple
,
list
)):
raise
TypeError
(
"Got inappropriate padding arg"
)
if
fill
is
None
:
if
not
isinstance
(
fill
,
(
numbers
.
Number
,
str
,
tuple
,
list
)):
fill
=
0.0
raise
TypeError
(
"Got inappropriate fill arg"
)
self
.
fill
=
fill
self
.
side_range
=
side_range
if
side_range
[
0
]
<
1.0
or
side_range
[
0
]
>
side_range
[
1
]:
raise
ValueError
(
f
"Invalid canvas side range provided
{
side_range
}
."
)
self
.
p
=
p
def
_get_params
(
self
,
sample
:
Any
)
->
Dict
[
str
,
Any
]:
image
=
query_image
(
sample
)
orig_c
,
orig_h
,
orig_w
=
get_image_dimensions
(
image
)
r
=
self
.
side_range
[
0
]
+
torch
.
rand
(
1
)
*
(
self
.
side_range
[
1
]
-
self
.
side_range
[
0
])
canvas_width
=
int
(
orig_w
*
r
)
canvas_height
=
int
(
orig_h
*
r
)
r
=
torch
.
rand
(
2
)
if
padding_mode
not
in
[
"constant"
,
"edge"
,
"reflect"
,
"symmetric"
]:
left
=
int
((
canvas_width
-
orig_w
)
*
r
[
0
])
raise
ValueError
(
"Padding mode should be either constant, edge, reflect or symmetric"
)
top
=
int
((
canvas_height
-
orig_h
)
*
r
[
1
])
right
=
canvas_width
-
(
left
+
orig_w
)
bottom
=
canvas_height
-
(
top
+
orig_h
)
padding
=
[
left
,
top
,
right
,
bottom
]
fill
=
self
.
fill
if
isinstance
(
padding
,
Sequence
)
and
len
(
padding
)
not
in
[
1
,
2
,
4
]:
if
not
isinstance
(
fill
,
collections
.
abc
.
Sequence
):
raise
ValueError
(
fill
=
[
fill
]
*
orig_c
f
"Padding must be an int or a 1, 2, or 4 element tuple, not a
{
len
(
padding
)
}
element tuple"
)
return
dict
(
padding
=
padding
,
fill
=
fill
)
self
.
padding
=
padding
self
.
fill
=
fill
self
.
padding_mode
=
padding_mode
def
_transform
(
self
,
input
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
def
_transform
(
self
,
input
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
if
isinstance
(
input
,
features
.
Image
)
or
is_simple_tensor
(
input
):
if
isinstance
(
input
,
features
.
Image
)
or
is_simple_tensor
(
input
):
...
@@ -349,6 +340,48 @@ class RandomZoomOut(Transform):
...
@@ -349,6 +340,48 @@ class RandomZoomOut(Transform):
else
:
else
:
return
input
return
input
class
RandomZoomOut
(
Transform
):
def
__init__
(
self
,
fill
:
Union
[
float
,
Sequence
[
float
]]
=
0.0
,
side_range
:
Tuple
[
float
,
float
]
=
(
1.0
,
4.0
),
p
:
float
=
0.5
)
->
None
:
super
().
__init__
()
if
fill
is
None
:
fill
=
0.0
self
.
fill
=
fill
self
.
side_range
=
side_range
if
side_range
[
0
]
<
1.0
or
side_range
[
0
]
>
side_range
[
1
]:
raise
ValueError
(
f
"Invalid canvas side range provided
{
side_range
}
."
)
self
.
p
=
p
def
_get_params
(
self
,
sample
:
Any
)
->
Dict
[
str
,
Any
]:
image
=
query_image
(
sample
)
orig_c
,
orig_h
,
orig_w
=
get_image_dimensions
(
image
)
r
=
self
.
side_range
[
0
]
+
torch
.
rand
(
1
)
*
(
self
.
side_range
[
1
]
-
self
.
side_range
[
0
])
canvas_width
=
int
(
orig_w
*
r
)
canvas_height
=
int
(
orig_h
*
r
)
r
=
torch
.
rand
(
2
)
left
=
int
((
canvas_width
-
orig_w
)
*
r
[
0
])
top
=
int
((
canvas_height
-
orig_h
)
*
r
[
1
])
right
=
canvas_width
-
(
left
+
orig_w
)
bottom
=
canvas_height
-
(
top
+
orig_h
)
padding
=
[
left
,
top
,
right
,
bottom
]
fill
=
self
.
fill
if
not
isinstance
(
fill
,
collections
.
abc
.
Sequence
):
fill
=
[
fill
]
*
orig_c
return
dict
(
padding
=
padding
,
fill
=
fill
)
def
_transform
(
self
,
input
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
transform
=
Pad
(
**
params
,
padding_mode
=
"constant"
)
return
transform
(
input
)
def
forward
(
self
,
*
inputs
:
Any
)
->
Any
:
def
forward
(
self
,
*
inputs
:
Any
)
->
Any
:
sample
=
inputs
if
len
(
inputs
)
>
1
else
inputs
[
0
]
sample
=
inputs
if
len
(
inputs
)
>
1
else
inputs
[
0
]
if
torch
.
rand
(
1
)
>=
self
.
p
:
if
torch
.
rand
(
1
)
>=
self
.
p
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment