Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
efd6bc06
Unverified
Commit
efd6bc06
authored
Feb 16, 2023
by
Philip Meier
Committed by
GitHub
Feb 16, 2023
Browse files
make fill defaultdict an implementation detail (#7258)
parent
b7892d3a
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
29 additions
and
38 deletions
+29
-38
torchvision/prototype/transforms/_geometry.py
torchvision/prototype/transforms/_geometry.py
+3
-2
torchvision/transforms/v2/_geometry.py
torchvision/transforms/v2/_geometry.py
+23
-20
torchvision/transforms/v2/_transform.py
torchvision/transforms/v2/_transform.py
+3
-16
No files found.
torchvision/prototype/transforms/_geometry.py
View file @
efd6bc06
...
@@ -22,7 +22,8 @@ class FixedSizeCrop(Transform):
...
@@ -22,7 +22,8 @@ class FixedSizeCrop(Transform):
self
.
crop_height
=
size
[
0
]
self
.
crop_height
=
size
[
0
]
self
.
crop_width
=
size
[
1
]
self
.
crop_width
=
size
[
1
]
self
.
fill
=
_setup_fill_arg
(
fill
)
self
.
fill
=
fill
self
.
_fill
=
_setup_fill_arg
(
fill
)
self
.
padding_mode
=
padding_mode
self
.
padding_mode
=
padding_mode
...
@@ -118,7 +119,7 @@ class FixedSizeCrop(Transform):
...
@@ -118,7 +119,7 @@ class FixedSizeCrop(Transform):
)
)
if
params
[
"needs_pad"
]:
if
params
[
"needs_pad"
]:
fill
=
self
.
fill
[
type
(
inpt
)]
fill
=
self
.
_
fill
[
type
(
inpt
)]
inpt
=
F
.
pad
(
inpt
,
params
[
"padding"
],
fill
=
fill
,
padding_mode
=
self
.
padding_mode
)
inpt
=
F
.
pad
(
inpt
,
params
[
"padding"
],
fill
=
fill
,
padding_mode
=
self
.
padding_mode
)
return
inpt
return
inpt
torchvision/transforms/v2/_geometry.py
View file @
efd6bc06
...
@@ -255,9 +255,7 @@ class Pad(Transform):
...
@@ -255,9 +255,7 @@ class Pad(Transform):
params
=
super
().
_extract_params_for_v1_transform
()
params
=
super
().
_extract_params_for_v1_transform
()
if
not
(
params
[
"fill"
]
is
None
or
isinstance
(
params
[
"fill"
],
(
int
,
float
))):
if
not
(
params
[
"fill"
]
is
None
or
isinstance
(
params
[
"fill"
],
(
int
,
float
))):
raise
ValueError
(
raise
ValueError
(
f
"
{
type
(
self
).
__name__
}
() can only be scripted for a scalar `fill`, but got
{
self
.
fill
}
."
)
f
"
{
type
(
self
.
__name__
)
}
() can only be scripted for a scalar `fill`, but got
{
self
.
fill
}
for images."
)
return
params
return
params
...
@@ -276,11 +274,12 @@ class Pad(Transform):
...
@@ -276,11 +274,12 @@ class Pad(Transform):
if
not
isinstance
(
padding
,
int
):
if
not
isinstance
(
padding
,
int
):
padding
=
list
(
padding
)
padding
=
list
(
padding
)
self
.
padding
=
padding
self
.
padding
=
padding
self
.
fill
=
_setup_fill_arg
(
fill
)
self
.
fill
=
fill
self
.
_fill
=
_setup_fill_arg
(
fill
)
self
.
padding_mode
=
padding_mode
self
.
padding_mode
=
padding_mode
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
fill
=
self
.
fill
[
type
(
inpt
)]
fill
=
self
.
_
fill
[
type
(
inpt
)]
return
F
.
pad
(
inpt
,
padding
=
self
.
padding
,
fill
=
fill
,
padding_mode
=
self
.
padding_mode
)
# type: ignore[arg-type]
return
F
.
pad
(
inpt
,
padding
=
self
.
padding
,
fill
=
fill
,
padding_mode
=
self
.
padding_mode
)
# type: ignore[arg-type]
...
@@ -293,7 +292,8 @@ class RandomZoomOut(_RandomApplyTransform):
...
@@ -293,7 +292,8 @@ class RandomZoomOut(_RandomApplyTransform):
)
->
None
:
)
->
None
:
super
().
__init__
(
p
=
p
)
super
().
__init__
(
p
=
p
)
self
.
fill
=
_setup_fill_arg
(
fill
)
self
.
fill
=
fill
self
.
_fill
=
_setup_fill_arg
(
fill
)
_check_sequence_input
(
side_range
,
"side_range"
,
req_sizes
=
(
2
,))
_check_sequence_input
(
side_range
,
"side_range"
,
req_sizes
=
(
2
,))
...
@@ -318,7 +318,7 @@ class RandomZoomOut(_RandomApplyTransform):
...
@@ -318,7 +318,7 @@ class RandomZoomOut(_RandomApplyTransform):
return
dict
(
padding
=
padding
)
return
dict
(
padding
=
padding
)
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
fill
=
self
.
fill
[
type
(
inpt
)]
fill
=
self
.
_
fill
[
type
(
inpt
)]
return
F
.
pad
(
inpt
,
**
params
,
fill
=
fill
)
return
F
.
pad
(
inpt
,
**
params
,
fill
=
fill
)
...
@@ -338,7 +338,8 @@ class RandomRotation(Transform):
...
@@ -338,7 +338,8 @@ class RandomRotation(Transform):
self
.
interpolation
=
_check_interpolation
(
interpolation
)
self
.
interpolation
=
_check_interpolation
(
interpolation
)
self
.
expand
=
expand
self
.
expand
=
expand
self
.
fill
=
_setup_fill_arg
(
fill
)
self
.
fill
=
fill
self
.
_fill
=
_setup_fill_arg
(
fill
)
if
center
is
not
None
:
if
center
is
not
None
:
_check_sequence_input
(
center
,
"center"
,
req_sizes
=
(
2
,))
_check_sequence_input
(
center
,
"center"
,
req_sizes
=
(
2
,))
...
@@ -350,7 +351,7 @@ class RandomRotation(Transform):
...
@@ -350,7 +351,7 @@ class RandomRotation(Transform):
return
dict
(
angle
=
angle
)
return
dict
(
angle
=
angle
)
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
fill
=
self
.
fill
[
type
(
inpt
)]
fill
=
self
.
_
fill
[
type
(
inpt
)]
return
F
.
rotate
(
return
F
.
rotate
(
inpt
,
inpt
,
**
params
,
**
params
,
...
@@ -395,7 +396,8 @@ class RandomAffine(Transform):
...
@@ -395,7 +396,8 @@ class RandomAffine(Transform):
self
.
shear
=
shear
self
.
shear
=
shear
self
.
interpolation
=
_check_interpolation
(
interpolation
)
self
.
interpolation
=
_check_interpolation
(
interpolation
)
self
.
fill
=
_setup_fill_arg
(
fill
)
self
.
fill
=
fill
self
.
_fill
=
_setup_fill_arg
(
fill
)
if
center
is
not
None
:
if
center
is
not
None
:
_check_sequence_input
(
center
,
"center"
,
req_sizes
=
(
2
,))
_check_sequence_input
(
center
,
"center"
,
req_sizes
=
(
2
,))
...
@@ -430,7 +432,7 @@ class RandomAffine(Transform):
...
@@ -430,7 +432,7 @@ class RandomAffine(Transform):
return
dict
(
angle
=
angle
,
translate
=
translate
,
scale
=
scale
,
shear
=
shear
)
return
dict
(
angle
=
angle
,
translate
=
translate
,
scale
=
scale
,
shear
=
shear
)
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
fill
=
self
.
fill
[
type
(
inpt
)]
fill
=
self
.
_
fill
[
type
(
inpt
)]
return
F
.
affine
(
return
F
.
affine
(
inpt
,
inpt
,
**
params
,
**
params
,
...
@@ -447,9 +449,7 @@ class RandomCrop(Transform):
...
@@ -447,9 +449,7 @@ class RandomCrop(Transform):
params
=
super
().
_extract_params_for_v1_transform
()
params
=
super
().
_extract_params_for_v1_transform
()
if
not
(
params
[
"fill"
]
is
None
or
isinstance
(
params
[
"fill"
],
(
int
,
float
))):
if
not
(
params
[
"fill"
]
is
None
or
isinstance
(
params
[
"fill"
],
(
int
,
float
))):
raise
ValueError
(
raise
ValueError
(
f
"
{
type
(
self
).
__name__
}
() can only be scripted for a scalar `fill`, but got
{
self
.
fill
}
."
)
f
"
{
type
(
self
.
__name__
)
}
() can only be scripted for a scalar `fill`, but got
{
self
.
fill
}
for images."
)
padding
=
self
.
padding
padding
=
self
.
padding
if
padding
is
not
None
:
if
padding
is
not
None
:
...
@@ -478,7 +478,8 @@ class RandomCrop(Transform):
...
@@ -478,7 +478,8 @@ class RandomCrop(Transform):
self
.
padding
=
F
.
_geometry
.
_parse_pad_padding
(
padding
)
if
padding
else
None
# type: ignore[arg-type]
self
.
padding
=
F
.
_geometry
.
_parse_pad_padding
(
padding
)
if
padding
else
None
# type: ignore[arg-type]
self
.
pad_if_needed
=
pad_if_needed
self
.
pad_if_needed
=
pad_if_needed
self
.
fill
=
_setup_fill_arg
(
fill
)
self
.
fill
=
fill
self
.
_fill
=
_setup_fill_arg
(
fill
)
self
.
padding_mode
=
padding_mode
self
.
padding_mode
=
padding_mode
def
_get_params
(
self
,
flat_inputs
:
List
[
Any
])
->
Dict
[
str
,
Any
]:
def
_get_params
(
self
,
flat_inputs
:
List
[
Any
])
->
Dict
[
str
,
Any
]:
...
@@ -541,7 +542,7 @@ class RandomCrop(Transform):
...
@@ -541,7 +542,7 @@ class RandomCrop(Transform):
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
if
params
[
"needs_pad"
]:
if
params
[
"needs_pad"
]:
fill
=
self
.
fill
[
type
(
inpt
)]
fill
=
self
.
_
fill
[
type
(
inpt
)]
inpt
=
F
.
pad
(
inpt
,
padding
=
params
[
"padding"
],
fill
=
fill
,
padding_mode
=
self
.
padding_mode
)
inpt
=
F
.
pad
(
inpt
,
padding
=
params
[
"padding"
],
fill
=
fill
,
padding_mode
=
self
.
padding_mode
)
if
params
[
"needs_crop"
]:
if
params
[
"needs_crop"
]:
...
@@ -567,7 +568,8 @@ class RandomPerspective(_RandomApplyTransform):
...
@@ -567,7 +568,8 @@ class RandomPerspective(_RandomApplyTransform):
self
.
distortion_scale
=
distortion_scale
self
.
distortion_scale
=
distortion_scale
self
.
interpolation
=
_check_interpolation
(
interpolation
)
self
.
interpolation
=
_check_interpolation
(
interpolation
)
self
.
fill
=
_setup_fill_arg
(
fill
)
self
.
fill
=
fill
self
.
_fill
=
_setup_fill_arg
(
fill
)
def
_get_params
(
self
,
flat_inputs
:
List
[
Any
])
->
Dict
[
str
,
Any
]:
def
_get_params
(
self
,
flat_inputs
:
List
[
Any
])
->
Dict
[
str
,
Any
]:
height
,
width
=
query_spatial_size
(
flat_inputs
)
height
,
width
=
query_spatial_size
(
flat_inputs
)
...
@@ -600,7 +602,7 @@ class RandomPerspective(_RandomApplyTransform):
...
@@ -600,7 +602,7 @@ class RandomPerspective(_RandomApplyTransform):
return
dict
(
coefficients
=
perspective_coeffs
)
return
dict
(
coefficients
=
perspective_coeffs
)
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
fill
=
self
.
fill
[
type
(
inpt
)]
fill
=
self
.
_
fill
[
type
(
inpt
)]
return
F
.
perspective
(
return
F
.
perspective
(
inpt
,
inpt
,
None
,
None
,
...
@@ -626,7 +628,8 @@ class ElasticTransform(Transform):
...
@@ -626,7 +628,8 @@ class ElasticTransform(Transform):
self
.
sigma
=
_setup_float_or_seq
(
sigma
,
"sigma"
,
2
)
self
.
sigma
=
_setup_float_or_seq
(
sigma
,
"sigma"
,
2
)
self
.
interpolation
=
_check_interpolation
(
interpolation
)
self
.
interpolation
=
_check_interpolation
(
interpolation
)
self
.
fill
=
_setup_fill_arg
(
fill
)
self
.
fill
=
fill
self
.
_fill
=
_setup_fill_arg
(
fill
)
def
_get_params
(
self
,
flat_inputs
:
List
[
Any
])
->
Dict
[
str
,
Any
]:
def
_get_params
(
self
,
flat_inputs
:
List
[
Any
])
->
Dict
[
str
,
Any
]:
size
=
list
(
query_spatial_size
(
flat_inputs
))
size
=
list
(
query_spatial_size
(
flat_inputs
))
...
@@ -652,7 +655,7 @@ class ElasticTransform(Transform):
...
@@ -652,7 +655,7 @@ class ElasticTransform(Transform):
return
dict
(
displacement
=
displacement
)
return
dict
(
displacement
=
displacement
)
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
fill
=
self
.
fill
[
type
(
inpt
)]
fill
=
self
.
_
fill
[
type
(
inpt
)]
return
F
.
elastic
(
return
F
.
elastic
(
inpt
,
inpt
,
**
params
,
**
params
,
...
...
torchvision/transforms/v2/_transform.py
View file @
efd6bc06
...
@@ -108,30 +108,17 @@ class Transform(nn.Module):
...
@@ -108,30 +108,17 @@ class Transform(nn.Module):
def
_extract_params_for_v1_transform
(
self
)
->
Dict
[
str
,
Any
]:
def
_extract_params_for_v1_transform
(
self
)
->
Dict
[
str
,
Any
]:
# This method is called by `__prepare_scriptable__` to instantiate the equivalent v1 transform from the current
# This method is called by `__prepare_scriptable__` to instantiate the equivalent v1 transform from the current
# v2 transform instance. It does two things:
# v2 transform instance. It extracts all available public attributes that are specific to that transform and
# 1. Extract all available public attributes that are specific to that transform and not `nn.Module` in general
# not `nn.Module` in general.
# 2. If available handle the `fill` attribute for v1 compatibility (see below for details)
# Overwrite this method on the v2 transform class if the above is not sufficient. For example, this might happen
# Overwrite this method on the v2 transform class if the above is not sufficient. For example, this might happen
# if the v2 transform introduced new parameters that are not support by the v1 transform.
# if the v2 transform introduced new parameters that are not support by the v1 transform.
common_attrs
=
nn
.
Module
().
__dict__
.
keys
()
common_attrs
=
nn
.
Module
().
__dict__
.
keys
()
params
=
{
return
{
attr
:
value
attr
:
value
for
attr
,
value
in
self
.
__dict__
.
items
()
for
attr
,
value
in
self
.
__dict__
.
items
()
if
not
attr
.
startswith
(
"_"
)
and
attr
not
in
common_attrs
if
not
attr
.
startswith
(
"_"
)
and
attr
not
in
common_attrs
}
}
# transforms v2 has a more complex handling for the `fill` parameter than v1. By default, the input is parsed
# with `prototype.transforms._utils._setup_fill_arg()`, which returns a defaultdict that holds the fill value
# for the different datapoint types. Below we extract the value for tensors and return that together with the
# other params.
# This is needed for `Pad`, `ElasticTransform`, `RandomAffine`, `RandomCrop`, `RandomPerspective` and
# `RandomRotation`
if
"fill"
in
params
:
fill_type_defaultdict
=
params
.
pop
(
"fill"
)
params
[
"fill"
]
=
fill_type_defaultdict
[
torch
.
Tensor
]
return
params
def
__prepare_scriptable__
(
self
)
->
nn
.
Module
:
def
__prepare_scriptable__
(
self
)
->
nn
.
Module
:
# This method is called early on when `torch.jit.script`'ing an `nn.Module` instance. If it succeeds, the return
# This method is called early on when `torch.jit.script`'ing an `nn.Module` instance. If it succeeds, the return
# value is used for scripting over the original object that should have been scripted. Since the v1 transforms
# value is used for scripting over the original object that should have been scripted. Since the v1 transforms
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment