Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
efd6bc06
"git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "cc164478f3031cc414f4938c3ec35e113063fded"
Unverified
Commit
efd6bc06
authored
Feb 16, 2023
by
Philip Meier
Committed by
GitHub
Feb 16, 2023
Browse files
make fill defaultdict an implementation detail (#7258)
parent
b7892d3a
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
29 additions
and
38 deletions
+29
-38
torchvision/prototype/transforms/_geometry.py
torchvision/prototype/transforms/_geometry.py
+3
-2
torchvision/transforms/v2/_geometry.py
torchvision/transforms/v2/_geometry.py
+23
-20
torchvision/transforms/v2/_transform.py
torchvision/transforms/v2/_transform.py
+3
-16
No files found.
torchvision/prototype/transforms/_geometry.py
View file @
efd6bc06
...
...
@@ -22,7 +22,8 @@ class FixedSizeCrop(Transform):
self
.
crop_height
=
size
[
0
]
self
.
crop_width
=
size
[
1
]
self
.
fill
=
_setup_fill_arg
(
fill
)
self
.
fill
=
fill
self
.
_fill
=
_setup_fill_arg
(
fill
)
self
.
padding_mode
=
padding_mode
...
...
@@ -118,7 +119,7 @@ class FixedSizeCrop(Transform):
)
if
params
[
"needs_pad"
]:
fill
=
self
.
fill
[
type
(
inpt
)]
fill
=
self
.
_
fill
[
type
(
inpt
)]
inpt
=
F
.
pad
(
inpt
,
params
[
"padding"
],
fill
=
fill
,
padding_mode
=
self
.
padding_mode
)
return
inpt
torchvision/transforms/v2/_geometry.py
View file @
efd6bc06
...
...
@@ -255,9 +255,7 @@ class Pad(Transform):
params
=
super
().
_extract_params_for_v1_transform
()
if
not
(
params
[
"fill"
]
is
None
or
isinstance
(
params
[
"fill"
],
(
int
,
float
))):
raise
ValueError
(
f
"
{
type
(
self
.
__name__
)
}
() can only be scripted for a scalar `fill`, but got
{
self
.
fill
}
for images."
)
raise
ValueError
(
f
"
{
type
(
self
).
__name__
}
() can only be scripted for a scalar `fill`, but got
{
self
.
fill
}
."
)
return
params
...
...
@@ -276,11 +274,12 @@ class Pad(Transform):
if
not
isinstance
(
padding
,
int
):
padding
=
list
(
padding
)
self
.
padding
=
padding
self
.
fill
=
_setup_fill_arg
(
fill
)
self
.
fill
=
fill
self
.
_fill
=
_setup_fill_arg
(
fill
)
self
.
padding_mode
=
padding_mode
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
fill
=
self
.
fill
[
type
(
inpt
)]
fill
=
self
.
_
fill
[
type
(
inpt
)]
return
F
.
pad
(
inpt
,
padding
=
self
.
padding
,
fill
=
fill
,
padding_mode
=
self
.
padding_mode
)
# type: ignore[arg-type]
...
...
@@ -293,7 +292,8 @@ class RandomZoomOut(_RandomApplyTransform):
)
->
None
:
super
().
__init__
(
p
=
p
)
self
.
fill
=
_setup_fill_arg
(
fill
)
self
.
fill
=
fill
self
.
_fill
=
_setup_fill_arg
(
fill
)
_check_sequence_input
(
side_range
,
"side_range"
,
req_sizes
=
(
2
,))
...
...
@@ -318,7 +318,7 @@ class RandomZoomOut(_RandomApplyTransform):
return
dict
(
padding
=
padding
)
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
fill
=
self
.
fill
[
type
(
inpt
)]
fill
=
self
.
_
fill
[
type
(
inpt
)]
return
F
.
pad
(
inpt
,
**
params
,
fill
=
fill
)
...
...
@@ -338,7 +338,8 @@ class RandomRotation(Transform):
self
.
interpolation
=
_check_interpolation
(
interpolation
)
self
.
expand
=
expand
self
.
fill
=
_setup_fill_arg
(
fill
)
self
.
fill
=
fill
self
.
_fill
=
_setup_fill_arg
(
fill
)
if
center
is
not
None
:
_check_sequence_input
(
center
,
"center"
,
req_sizes
=
(
2
,))
...
...
@@ -350,7 +351,7 @@ class RandomRotation(Transform):
return
dict
(
angle
=
angle
)
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
fill
=
self
.
fill
[
type
(
inpt
)]
fill
=
self
.
_
fill
[
type
(
inpt
)]
return
F
.
rotate
(
inpt
,
**
params
,
...
...
@@ -395,7 +396,8 @@ class RandomAffine(Transform):
self
.
shear
=
shear
self
.
interpolation
=
_check_interpolation
(
interpolation
)
self
.
fill
=
_setup_fill_arg
(
fill
)
self
.
fill
=
fill
self
.
_fill
=
_setup_fill_arg
(
fill
)
if
center
is
not
None
:
_check_sequence_input
(
center
,
"center"
,
req_sizes
=
(
2
,))
...
...
@@ -430,7 +432,7 @@ class RandomAffine(Transform):
return
dict
(
angle
=
angle
,
translate
=
translate
,
scale
=
scale
,
shear
=
shear
)
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
fill
=
self
.
fill
[
type
(
inpt
)]
fill
=
self
.
_
fill
[
type
(
inpt
)]
return
F
.
affine
(
inpt
,
**
params
,
...
...
@@ -447,9 +449,7 @@ class RandomCrop(Transform):
params
=
super
().
_extract_params_for_v1_transform
()
if
not
(
params
[
"fill"
]
is
None
or
isinstance
(
params
[
"fill"
],
(
int
,
float
))):
raise
ValueError
(
f
"
{
type
(
self
.
__name__
)
}
() can only be scripted for a scalar `fill`, but got
{
self
.
fill
}
for images."
)
raise
ValueError
(
f
"
{
type
(
self
).
__name__
}
() can only be scripted for a scalar `fill`, but got
{
self
.
fill
}
."
)
padding
=
self
.
padding
if
padding
is
not
None
:
...
...
@@ -478,7 +478,8 @@ class RandomCrop(Transform):
self
.
padding
=
F
.
_geometry
.
_parse_pad_padding
(
padding
)
if
padding
else
None
# type: ignore[arg-type]
self
.
pad_if_needed
=
pad_if_needed
self
.
fill
=
_setup_fill_arg
(
fill
)
self
.
fill
=
fill
self
.
_fill
=
_setup_fill_arg
(
fill
)
self
.
padding_mode
=
padding_mode
def
_get_params
(
self
,
flat_inputs
:
List
[
Any
])
->
Dict
[
str
,
Any
]:
...
...
@@ -541,7 +542,7 @@ class RandomCrop(Transform):
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
if
params
[
"needs_pad"
]:
fill
=
self
.
fill
[
type
(
inpt
)]
fill
=
self
.
_
fill
[
type
(
inpt
)]
inpt
=
F
.
pad
(
inpt
,
padding
=
params
[
"padding"
],
fill
=
fill
,
padding_mode
=
self
.
padding_mode
)
if
params
[
"needs_crop"
]:
...
...
@@ -567,7 +568,8 @@ class RandomPerspective(_RandomApplyTransform):
self
.
distortion_scale
=
distortion_scale
self
.
interpolation
=
_check_interpolation
(
interpolation
)
self
.
fill
=
_setup_fill_arg
(
fill
)
self
.
fill
=
fill
self
.
_fill
=
_setup_fill_arg
(
fill
)
def
_get_params
(
self
,
flat_inputs
:
List
[
Any
])
->
Dict
[
str
,
Any
]:
height
,
width
=
query_spatial_size
(
flat_inputs
)
...
...
@@ -600,7 +602,7 @@ class RandomPerspective(_RandomApplyTransform):
return
dict
(
coefficients
=
perspective_coeffs
)
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
fill
=
self
.
fill
[
type
(
inpt
)]
fill
=
self
.
_
fill
[
type
(
inpt
)]
return
F
.
perspective
(
inpt
,
None
,
...
...
@@ -626,7 +628,8 @@ class ElasticTransform(Transform):
self
.
sigma
=
_setup_float_or_seq
(
sigma
,
"sigma"
,
2
)
self
.
interpolation
=
_check_interpolation
(
interpolation
)
self
.
fill
=
_setup_fill_arg
(
fill
)
self
.
fill
=
fill
self
.
_fill
=
_setup_fill_arg
(
fill
)
def
_get_params
(
self
,
flat_inputs
:
List
[
Any
])
->
Dict
[
str
,
Any
]:
size
=
list
(
query_spatial_size
(
flat_inputs
))
...
...
@@ -652,7 +655,7 @@ class ElasticTransform(Transform):
return
dict
(
displacement
=
displacement
)
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
fill
=
self
.
fill
[
type
(
inpt
)]
fill
=
self
.
_
fill
[
type
(
inpt
)]
return
F
.
elastic
(
inpt
,
**
params
,
...
...
torchvision/transforms/v2/_transform.py
View file @
efd6bc06
...
...
@@ -108,30 +108,17 @@ class Transform(nn.Module):
def
_extract_params_for_v1_transform
(
self
)
->
Dict
[
str
,
Any
]:
# This method is called by `__prepare_scriptable__` to instantiate the equivalent v1 transform from the current
# v2 transform instance. It does two things:
# 1. Extract all available public attributes that are specific to that transform and not `nn.Module` in general
# 2. If available handle the `fill` attribute for v1 compatibility (see below for details)
# v2 transform instance. It extracts all available public attributes that are specific to that transform and
# not `nn.Module` in general.
# Overwrite this method on the v2 transform class if the above is not sufficient. For example, this might happen
# if the v2 transform introduced new parameters that are not support by the v1 transform.
common_attrs
=
nn
.
Module
().
__dict__
.
keys
()
params
=
{
return
{
attr
:
value
for
attr
,
value
in
self
.
__dict__
.
items
()
if
not
attr
.
startswith
(
"_"
)
and
attr
not
in
common_attrs
}
# transforms v2 has a more complex handling for the `fill` parameter than v1. By default, the input is parsed
# with `prototype.transforms._utils._setup_fill_arg()`, which returns a defaultdict that holds the fill value
# for the different datapoint types. Below we extract the value for tensors and return that together with the
# other params.
# This is needed for `Pad`, `ElasticTransform`, `RandomAffine`, `RandomCrop`, `RandomPerspective` and
# `RandomRotation`
if
"fill"
in
params
:
fill_type_defaultdict
=
params
.
pop
(
"fill"
)
params
[
"fill"
]
=
fill_type_defaultdict
[
torch
.
Tensor
]
return
params
def
__prepare_scriptable__
(
self
)
->
nn
.
Module
:
# This method is called early on when `torch.jit.script`'ing an `nn.Module` instance. If it succeeds, the return
# value is used for scripting over the original object that should have been scripted. Since the v1 transforms
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment