Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
efd6bc06
Unverified
Commit
efd6bc06
authored
Feb 16, 2023
by
Philip Meier
Committed by
GitHub
Feb 16, 2023
Browse files
make fill defaultdict an implementation detail (#7258)
parent
b7892d3a
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
29 additions
and
38 deletions
+29
-38
torchvision/prototype/transforms/_geometry.py
torchvision/prototype/transforms/_geometry.py
+3
-2
torchvision/transforms/v2/_geometry.py
torchvision/transforms/v2/_geometry.py
+23
-20
torchvision/transforms/v2/_transform.py
torchvision/transforms/v2/_transform.py
+3
-16
No files found.
torchvision/prototype/transforms/_geometry.py
View file @
efd6bc06
...
...
@@ -22,7 +22,8 @@ class FixedSizeCrop(Transform):
self
.
crop_height
=
size
[
0
]
self
.
crop_width
=
size
[
1
]
self
.
fill
=
_setup_fill_arg
(
fill
)
self
.
fill
=
fill
self
.
_fill
=
_setup_fill_arg
(
fill
)
self
.
padding_mode
=
padding_mode
...
...
@@ -118,7 +119,7 @@ class FixedSizeCrop(Transform):
)
if
params
[
"needs_pad"
]:
fill
=
self
.
fill
[
type
(
inpt
)]
fill
=
self
.
_
fill
[
type
(
inpt
)]
inpt
=
F
.
pad
(
inpt
,
params
[
"padding"
],
fill
=
fill
,
padding_mode
=
self
.
padding_mode
)
return
inpt
torchvision/transforms/v2/_geometry.py
View file @
efd6bc06
...
...
@@ -255,9 +255,7 @@ class Pad(Transform):
params
=
super
().
_extract_params_for_v1_transform
()
if
not
(
params
[
"fill"
]
is
None
or
isinstance
(
params
[
"fill"
],
(
int
,
float
))):
raise
ValueError
(
f
"
{
type
(
self
.
__name__
)
}
() can only be scripted for a scalar `fill`, but got
{
self
.
fill
}
for images."
)
raise
ValueError
(
f
"
{
type
(
self
).
__name__
}
() can only be scripted for a scalar `fill`, but got
{
self
.
fill
}
."
)
return
params
...
...
@@ -276,11 +274,12 @@ class Pad(Transform):
if
not
isinstance
(
padding
,
int
):
padding
=
list
(
padding
)
self
.
padding
=
padding
self
.
fill
=
_setup_fill_arg
(
fill
)
self
.
fill
=
fill
self
.
_fill
=
_setup_fill_arg
(
fill
)
self
.
padding_mode
=
padding_mode
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
fill
=
self
.
fill
[
type
(
inpt
)]
fill
=
self
.
_
fill
[
type
(
inpt
)]
return
F
.
pad
(
inpt
,
padding
=
self
.
padding
,
fill
=
fill
,
padding_mode
=
self
.
padding_mode
)
# type: ignore[arg-type]
...
...
@@ -293,7 +292,8 @@ class RandomZoomOut(_RandomApplyTransform):
)
->
None
:
super
().
__init__
(
p
=
p
)
self
.
fill
=
_setup_fill_arg
(
fill
)
self
.
fill
=
fill
self
.
_fill
=
_setup_fill_arg
(
fill
)
_check_sequence_input
(
side_range
,
"side_range"
,
req_sizes
=
(
2
,))
...
...
@@ -318,7 +318,7 @@ class RandomZoomOut(_RandomApplyTransform):
return
dict
(
padding
=
padding
)
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
fill
=
self
.
fill
[
type
(
inpt
)]
fill
=
self
.
_
fill
[
type
(
inpt
)]
return
F
.
pad
(
inpt
,
**
params
,
fill
=
fill
)
...
...
@@ -338,7 +338,8 @@ class RandomRotation(Transform):
self
.
interpolation
=
_check_interpolation
(
interpolation
)
self
.
expand
=
expand
self
.
fill
=
_setup_fill_arg
(
fill
)
self
.
fill
=
fill
self
.
_fill
=
_setup_fill_arg
(
fill
)
if
center
is
not
None
:
_check_sequence_input
(
center
,
"center"
,
req_sizes
=
(
2
,))
...
...
@@ -350,7 +351,7 @@ class RandomRotation(Transform):
return
dict
(
angle
=
angle
)
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
fill
=
self
.
fill
[
type
(
inpt
)]
fill
=
self
.
_
fill
[
type
(
inpt
)]
return
F
.
rotate
(
inpt
,
**
params
,
...
...
@@ -395,7 +396,8 @@ class RandomAffine(Transform):
self
.
shear
=
shear
self
.
interpolation
=
_check_interpolation
(
interpolation
)
self
.
fill
=
_setup_fill_arg
(
fill
)
self
.
fill
=
fill
self
.
_fill
=
_setup_fill_arg
(
fill
)
if
center
is
not
None
:
_check_sequence_input
(
center
,
"center"
,
req_sizes
=
(
2
,))
...
...
@@ -430,7 +432,7 @@ class RandomAffine(Transform):
return
dict
(
angle
=
angle
,
translate
=
translate
,
scale
=
scale
,
shear
=
shear
)
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
fill
=
self
.
fill
[
type
(
inpt
)]
fill
=
self
.
_
fill
[
type
(
inpt
)]
return
F
.
affine
(
inpt
,
**
params
,
...
...
@@ -447,9 +449,7 @@ class RandomCrop(Transform):
params
=
super
().
_extract_params_for_v1_transform
()
if
not
(
params
[
"fill"
]
is
None
or
isinstance
(
params
[
"fill"
],
(
int
,
float
))):
raise
ValueError
(
f
"
{
type
(
self
.
__name__
)
}
() can only be scripted for a scalar `fill`, but got
{
self
.
fill
}
for images."
)
raise
ValueError
(
f
"
{
type
(
self
).
__name__
}
() can only be scripted for a scalar `fill`, but got
{
self
.
fill
}
."
)
padding
=
self
.
padding
if
padding
is
not
None
:
...
...
@@ -478,7 +478,8 @@ class RandomCrop(Transform):
self
.
padding
=
F
.
_geometry
.
_parse_pad_padding
(
padding
)
if
padding
else
None
# type: ignore[arg-type]
self
.
pad_if_needed
=
pad_if_needed
self
.
fill
=
_setup_fill_arg
(
fill
)
self
.
fill
=
fill
self
.
_fill
=
_setup_fill_arg
(
fill
)
self
.
padding_mode
=
padding_mode
def
_get_params
(
self
,
flat_inputs
:
List
[
Any
])
->
Dict
[
str
,
Any
]:
...
...
@@ -541,7 +542,7 @@ class RandomCrop(Transform):
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
if
params
[
"needs_pad"
]:
fill
=
self
.
fill
[
type
(
inpt
)]
fill
=
self
.
_
fill
[
type
(
inpt
)]
inpt
=
F
.
pad
(
inpt
,
padding
=
params
[
"padding"
],
fill
=
fill
,
padding_mode
=
self
.
padding_mode
)
if
params
[
"needs_crop"
]:
...
...
@@ -567,7 +568,8 @@ class RandomPerspective(_RandomApplyTransform):
self
.
distortion_scale
=
distortion_scale
self
.
interpolation
=
_check_interpolation
(
interpolation
)
self
.
fill
=
_setup_fill_arg
(
fill
)
self
.
fill
=
fill
self
.
_fill
=
_setup_fill_arg
(
fill
)
def
_get_params
(
self
,
flat_inputs
:
List
[
Any
])
->
Dict
[
str
,
Any
]:
height
,
width
=
query_spatial_size
(
flat_inputs
)
...
...
@@ -600,7 +602,7 @@ class RandomPerspective(_RandomApplyTransform):
return
dict
(
coefficients
=
perspective_coeffs
)
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
fill
=
self
.
fill
[
type
(
inpt
)]
fill
=
self
.
_
fill
[
type
(
inpt
)]
return
F
.
perspective
(
inpt
,
None
,
...
...
@@ -626,7 +628,8 @@ class ElasticTransform(Transform):
self
.
sigma
=
_setup_float_or_seq
(
sigma
,
"sigma"
,
2
)
self
.
interpolation
=
_check_interpolation
(
interpolation
)
self
.
fill
=
_setup_fill_arg
(
fill
)
self
.
fill
=
fill
self
.
_fill
=
_setup_fill_arg
(
fill
)
def
_get_params
(
self
,
flat_inputs
:
List
[
Any
])
->
Dict
[
str
,
Any
]:
size
=
list
(
query_spatial_size
(
flat_inputs
))
...
...
@@ -652,7 +655,7 @@ class ElasticTransform(Transform):
return
dict
(
displacement
=
displacement
)
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
fill
=
self
.
fill
[
type
(
inpt
)]
fill
=
self
.
_
fill
[
type
(
inpt
)]
return
F
.
elastic
(
inpt
,
**
params
,
...
...
torchvision/transforms/v2/_transform.py
View file @
efd6bc06
...
...
@@ -108,30 +108,17 @@ class Transform(nn.Module):
def
_extract_params_for_v1_transform
(
self
)
->
Dict
[
str
,
Any
]:
# This method is called by `__prepare_scriptable__` to instantiate the equivalent v1 transform from the current
# v2 transform instance. It does two things:
# 1. Extract all available public attributes that are specific to that transform and not `nn.Module` in general
# 2. If available handle the `fill` attribute for v1 compatibility (see below for details)
# v2 transform instance. It extracts all available public attributes that are specific to that transform and
# not `nn.Module` in general.
# Overwrite this method on the v2 transform class if the above is not sufficient. For example, this might happen
# if the v2 transform introduced new parameters that are not support by the v1 transform.
common_attrs
=
nn
.
Module
().
__dict__
.
keys
()
params
=
{
return
{
attr
:
value
for
attr
,
value
in
self
.
__dict__
.
items
()
if
not
attr
.
startswith
(
"_"
)
and
attr
not
in
common_attrs
}
# transforms v2 has a more complex handling for the `fill` parameter than v1. By default, the input is parsed
# with `prototype.transforms._utils._setup_fill_arg()`, which returns a defaultdict that holds the fill value
# for the different datapoint types. Below we extract the value for tensors and return that together with the
# other params.
# This is needed for `Pad`, `ElasticTransform`, `RandomAffine`, `RandomCrop`, `RandomPerspective` and
# `RandomRotation`
if
"fill"
in
params
:
fill_type_defaultdict
=
params
.
pop
(
"fill"
)
params
[
"fill"
]
=
fill_type_defaultdict
[
torch
.
Tensor
]
return
params
def
__prepare_scriptable__
(
self
)
->
nn
.
Module
:
# This method is called early on when `torch.jit.script`'ing an `nn.Module` instance. If it succeeds, the return
# value is used for scripting over the original object that should have been scripted. Since the v1 transforms
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment