Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
edde8255
Unverified
Commit
edde8255
authored
Aug 01, 2023
by
Nicolas Hug
Committed by
GitHub
Aug 01, 2023
Browse files
Allow catch-all 'others' key in fill dicts. Avoid need for defaultdict. (#7779)
parent
312c3d32
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
64 additions
and
69 deletions
+64
-69
gallery/plot_transforms_v2_e2e.py
gallery/plot_transforms_v2_e2e.py
+1
-4
references/segmentation/presets.py
references/segmentation/presets.py
+1
-3
references/segmentation/v2_extras.py
references/segmentation/v2_extras.py
+2
-2
test/test_transforms_v2.py
test/test_transforms_v2.py
+1
-2
test/test_transforms_v2_consistency.py
test/test_transforms_v2_consistency.py
+3
-3
torchvision/prototype/transforms/_geometry.py
torchvision/prototype/transforms/_geometry.py
+3
-3
torchvision/prototype/transforms/_misc.py
torchvision/prototype/transforms/_misc.py
+16
-2
torchvision/transforms/v2/_auto_augment.py
torchvision/transforms/v2/_auto_augment.py
+8
-8
torchvision/transforms/v2/_geometry.py
torchvision/transforms/v2/_geometry.py
+15
-14
torchvision/transforms/v2/_utils.py
torchvision/transforms/v2/_utils.py
+14
-28
No files found.
gallery/plot_transforms_v2_e2e.py
View file @
edde8255
...
@@ -10,7 +10,6 @@ well as the new ``torchvision.transforms.v2`` v2 API.
...
@@ -10,7 +10,6 @@ well as the new ``torchvision.transforms.v2`` v2 API.
"""
"""
import
pathlib
import
pathlib
from
collections
import
defaultdict
import
PIL.Image
import
PIL.Image
...
@@ -99,9 +98,7 @@ show(sample)
...
@@ -99,9 +98,7 @@ show(sample)
transform
=
transforms
.
Compose
(
transform
=
transforms
.
Compose
(
[
[
transforms
.
RandomPhotometricDistort
(),
transforms
.
RandomPhotometricDistort
(),
transforms
.
RandomZoomOut
(
transforms
.
RandomZoomOut
(
fill
=
{
PIL
.
Image
.
Image
:
(
123
,
117
,
104
),
"others"
:
0
}),
fill
=
defaultdict
(
lambda
:
0
,
{
PIL
.
Image
.
Image
:
(
123
,
117
,
104
)})
),
transforms
.
RandomIoUCrop
(),
transforms
.
RandomIoUCrop
(),
transforms
.
RandomHorizontalFlip
(),
transforms
.
RandomHorizontalFlip
(),
transforms
.
ToImageTensor
(),
transforms
.
ToImageTensor
(),
...
...
references/segmentation/presets.py
View file @
edde8255
from
collections
import
defaultdict
import
torch
import
torch
...
@@ -48,7 +46,7 @@ class SegmentationPresetTrain:
...
@@ -48,7 +46,7 @@ class SegmentationPresetTrain:
if
use_v2
:
if
use_v2
:
# We need a custom pad transform here, since the padding we want to perform here is fundamentally
# We need a custom pad transform here, since the padding we want to perform here is fundamentally
# different from the padding in `RandomCrop` if `pad_if_needed=True`.
# different from the padding in `RandomCrop` if `pad_if_needed=True`.
transforms
+=
[
v2_extras
.
PadIfSmaller
(
crop_size
,
fill
=
defaultdict
(
lambda
:
0
,
{
datapoints
.
Mask
:
255
}
)
)]
transforms
+=
[
v2_extras
.
PadIfSmaller
(
crop_size
,
fill
=
{
datapoints
.
Mask
:
255
,
"others"
:
0
})]
transforms
+=
[
T
.
RandomCrop
(
crop_size
)]
transforms
+=
[
T
.
RandomCrop
(
crop_size
)]
...
...
references/segmentation/v2_extras.py
View file @
edde8255
...
@@ -8,7 +8,7 @@ class PadIfSmaller(v2.Transform):
...
@@ -8,7 +8,7 @@ class PadIfSmaller(v2.Transform):
def
__init__
(
self
,
size
,
fill
=
0
):
def
__init__
(
self
,
size
,
fill
=
0
):
super
().
__init__
()
super
().
__init__
()
self
.
size
=
size
self
.
size
=
size
self
.
fill
=
v2
.
_
geometry
.
_setup_fill_arg
(
fill
)
self
.
fill
=
v2
.
_
utils
.
_setup_fill_arg
(
fill
)
def
_get_params
(
self
,
sample
):
def
_get_params
(
self
,
sample
):
_
,
height
,
width
=
v2
.
utils
.
query_chw
(
sample
)
_
,
height
,
width
=
v2
.
utils
.
query_chw
(
sample
)
...
@@ -20,7 +20,7 @@ class PadIfSmaller(v2.Transform):
...
@@ -20,7 +20,7 @@ class PadIfSmaller(v2.Transform):
if
not
params
[
"needs_padding"
]:
if
not
params
[
"needs_padding"
]:
return
inpt
return
inpt
fill
=
self
.
fill
[
type
(
inpt
)
]
fill
=
v2
.
_utils
.
_get_fill
(
self
.
fill
,
type
(
inpt
)
)
fill
=
v2
.
_utils
.
_convert_fill_arg
(
fill
)
fill
=
v2
.
_utils
.
_convert_fill_arg
(
fill
)
return
v2
.
functional
.
pad
(
inpt
,
padding
=
params
[
"padding"
],
fill
=
fill
)
return
v2
.
functional
.
pad
(
inpt
,
padding
=
params
[
"padding"
],
fill
=
fill
)
...
...
test/test_transforms_v2.py
View file @
edde8255
...
@@ -3,7 +3,6 @@ import pathlib
...
@@ -3,7 +3,6 @@ import pathlib
import
random
import
random
import
textwrap
import
textwrap
import
warnings
import
warnings
from
collections
import
defaultdict
import
numpy
as
np
import
numpy
as
np
...
@@ -1475,7 +1474,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
...
@@ -1475,7 +1474,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
elif
data_augmentation
==
"ssd"
:
elif
data_augmentation
==
"ssd"
:
t
=
[
t
=
[
transforms
.
RandomPhotometricDistort
(
p
=
1
),
transforms
.
RandomPhotometricDistort
(
p
=
1
),
transforms
.
RandomZoomOut
(
fill
=
defaultdict
(
lambda
:
(
123.0
,
117.0
,
104.0
),
{
datapoints
.
Mask
:
0
}
)
,
p
=
1
),
transforms
.
RandomZoomOut
(
fill
=
{
"others"
:
(
123.0
,
117.0
,
104.0
),
datapoints
.
Mask
:
0
},
p
=
1
),
transforms
.
RandomIoUCrop
(),
transforms
.
RandomIoUCrop
(),
transforms
.
RandomHorizontalFlip
(
p
=
1
),
transforms
.
RandomHorizontalFlip
(
p
=
1
),
to_tensor
,
to_tensor
,
...
...
test/test_transforms_v2_consistency.py
View file @
edde8255
...
@@ -4,7 +4,6 @@ import importlib.util
...
@@ -4,7 +4,6 @@ import importlib.util
import
inspect
import
inspect
import
random
import
random
import
re
import
re
from
collections
import
defaultdict
from
pathlib
import
Path
from
pathlib
import
Path
import
numpy
as
np
import
numpy
as
np
...
@@ -30,6 +29,7 @@ from torchvision._utils import sequence_to_str
...
@@ -30,6 +29,7 @@ from torchvision._utils import sequence_to_str
from
torchvision.transforms
import
functional
as
legacy_F
from
torchvision.transforms
import
functional
as
legacy_F
from
torchvision.transforms.v2
import
functional
as
prototype_F
from
torchvision.transforms.v2
import
functional
as
prototype_F
from
torchvision.transforms.v2._utils
import
_get_fill
from
torchvision.transforms.v2.functional
import
to_image_pil
from
torchvision.transforms.v2.functional
import
to_image_pil
from
torchvision.transforms.v2.utils
import
query_size
from
torchvision.transforms.v2.utils
import
query_size
...
@@ -1181,7 +1181,7 @@ class PadIfSmaller(v2_transforms.Transform):
...
@@ -1181,7 +1181,7 @@ class PadIfSmaller(v2_transforms.Transform):
if
not
params
[
"needs_padding"
]:
if
not
params
[
"needs_padding"
]:
return
inpt
return
inpt
fill
=
self
.
fill
[
type
(
inpt
)
]
fill
=
_get_fill
(
self
.
fill
,
type
(
inpt
)
)
return
prototype_F
.
pad
(
inpt
,
padding
=
params
[
"padding"
],
fill
=
fill
)
return
prototype_F
.
pad
(
inpt
,
padding
=
params
[
"padding"
],
fill
=
fill
)
...
@@ -1243,7 +1243,7 @@ class TestRefSegTransforms:
...
@@ -1243,7 +1243,7 @@ class TestRefSegTransforms:
seg_transforms
.
RandomCrop
(
size
=
480
),
seg_transforms
.
RandomCrop
(
size
=
480
),
v2_transforms
.
Compose
(
v2_transforms
.
Compose
(
[
[
PadIfSmaller
(
size
=
480
,
fill
=
defaultdict
(
lambda
:
0
,
{
datapoints
.
Mask
:
255
}
)
),
PadIfSmaller
(
size
=
480
,
fill
=
{
datapoints
.
Mask
:
255
,
"others"
:
0
}),
v2_transforms
.
RandomCrop
(
size
=
480
),
v2_transforms
.
RandomCrop
(
size
=
480
),
]
]
),
),
...
...
torchvision/prototype/transforms/_geometry.py
View file @
edde8255
...
@@ -6,7 +6,7 @@ import torch
...
@@ -6,7 +6,7 @@ import torch
from
torchvision
import
datapoints
from
torchvision
import
datapoints
from
torchvision.prototype.datapoints
import
Label
,
OneHotLabel
from
torchvision.prototype.datapoints
import
Label
,
OneHotLabel
from
torchvision.transforms.v2
import
functional
as
F
,
Transform
from
torchvision.transforms.v2
import
functional
as
F
,
Transform
from
torchvision.transforms.v2._utils
import
_setup_fill_arg
,
_setup_size
from
torchvision.transforms.v2._utils
import
_get_fill
,
_setup_fill_arg
,
_setup_size
from
torchvision.transforms.v2.utils
import
has_any
,
is_simple_tensor
,
query_bounding_boxes
,
query_size
from
torchvision.transforms.v2.utils
import
has_any
,
is_simple_tensor
,
query_bounding_boxes
,
query_size
...
@@ -14,7 +14,7 @@ class FixedSizeCrop(Transform):
...
@@ -14,7 +14,7 @@ class FixedSizeCrop(Transform):
def
__init__
(
def
__init__
(
self
,
self
,
size
:
Union
[
int
,
Sequence
[
int
]],
size
:
Union
[
int
,
Sequence
[
int
]],
fill
:
Union
[
datapoints
.
_FillType
,
Dict
[
Type
,
datapoints
.
_FillType
]]
=
0
,
fill
:
Union
[
datapoints
.
_FillType
,
Dict
[
Union
[
Type
,
str
]
,
datapoints
.
_FillType
]]
=
0
,
padding_mode
:
str
=
"constant"
,
padding_mode
:
str
=
"constant"
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
...
@@ -119,7 +119,7 @@ class FixedSizeCrop(Transform):
...
@@ -119,7 +119,7 @@ class FixedSizeCrop(Transform):
)
)
if
params
[
"needs_pad"
]:
if
params
[
"needs_pad"
]:
fill
=
self
.
_fill
[
type
(
inpt
)
]
fill
=
_get_fill
(
self
.
_fill
,
type
(
inpt
)
)
inpt
=
F
.
pad
(
inpt
,
params
[
"padding"
],
fill
=
fill
,
padding_mode
=
self
.
padding_mode
)
inpt
=
F
.
pad
(
inpt
,
params
[
"padding"
],
fill
=
fill
,
padding_mode
=
self
.
padding_mode
)
return
inpt
return
inpt
torchvision/prototype/transforms/_misc.py
View file @
edde8255
import
functools
import
warnings
import
warnings
from
typing
import
Any
,
Dict
,
Optional
,
Sequence
,
Tuple
,
Type
,
Union
from
collections
import
defaultdict
from
typing
import
Any
,
Dict
,
Optional
,
Sequence
,
Tuple
,
Type
,
TypeVar
,
Union
import
torch
import
torch
from
torchvision
import
datapoints
from
torchvision
import
datapoints
from
torchvision.transforms.v2
import
Transform
from
torchvision.transforms.v2
import
Transform
from
torchvision.transforms.v2._utils
import
_get_defaultdict
from
torchvision.transforms.v2.utils
import
is_simple_tensor
from
torchvision.transforms.v2.utils
import
is_simple_tensor
T
=
TypeVar
(
"T"
)
def
_default_arg
(
value
:
T
)
->
T
:
return
value
def
_get_defaultdict
(
default
:
T
)
->
Dict
[
Any
,
T
]:
# This weird looking construct only exists, since `lambda`'s cannot be serialized by pickle.
# If it were possible, we could replace this with `defaultdict(lambda: default)`
return
defaultdict
(
functools
.
partial
(
_default_arg
,
default
))
class
PermuteDimensions
(
Transform
):
class
PermuteDimensions
(
Transform
):
_transformed_types
=
(
is_simple_tensor
,
datapoints
.
Image
,
datapoints
.
Video
)
_transformed_types
=
(
is_simple_tensor
,
datapoints
.
Image
,
datapoints
.
Video
)
...
...
torchvision/transforms/v2/_auto_augment.py
View file @
edde8255
...
@@ -11,7 +11,7 @@ from torchvision.transforms.v2 import AutoAugmentPolicy, functional as F, Interp
...
@@ -11,7 +11,7 @@ from torchvision.transforms.v2 import AutoAugmentPolicy, functional as F, Interp
from
torchvision.transforms.v2.functional._geometry
import
_check_interpolation
from
torchvision.transforms.v2.functional._geometry
import
_check_interpolation
from
torchvision.transforms.v2.functional._meta
import
get_size
from
torchvision.transforms.v2.functional._meta
import
get_size
from
._utils
import
_setup_fill_arg
from
._utils
import
_get_fill
,
_setup_fill_arg
from
.utils
import
check_type
,
is_simple_tensor
from
.utils
import
check_type
,
is_simple_tensor
...
@@ -20,7 +20,7 @@ class _AutoAugmentBase(Transform):
...
@@ -20,7 +20,7 @@ class _AutoAugmentBase(Transform):
self
,
self
,
*
,
*
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
fill
:
Union
[
datapoints
.
_FillType
,
Dict
[
Type
,
datapoints
.
_FillType
]]
=
None
,
fill
:
Union
[
datapoints
.
_FillType
,
Dict
[
Union
[
Type
,
str
]
,
datapoints
.
_FillType
]]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
interpolation
=
_check_interpolation
(
interpolation
)
self
.
interpolation
=
_check_interpolation
(
interpolation
)
...
@@ -80,9 +80,9 @@ class _AutoAugmentBase(Transform):
...
@@ -80,9 +80,9 @@ class _AutoAugmentBase(Transform):
transform_id
:
str
,
transform_id
:
str
,
magnitude
:
float
,
magnitude
:
float
,
interpolation
:
Union
[
InterpolationMode
,
int
],
interpolation
:
Union
[
InterpolationMode
,
int
],
fill
:
Dict
[
Type
,
datapoints
.
_FillTypeJIT
],
fill
:
Dict
[
Union
[
Type
,
str
]
,
datapoints
.
_FillTypeJIT
],
)
->
Union
[
datapoints
.
_ImageType
,
datapoints
.
_VideoType
]:
)
->
Union
[
datapoints
.
_ImageType
,
datapoints
.
_VideoType
]:
fill_
=
fill
[
type
(
image
)
]
fill_
=
_get_fill
(
fill
,
type
(
image
)
)
if
transform_id
==
"Identity"
:
if
transform_id
==
"Identity"
:
return
image
return
image
...
@@ -214,7 +214,7 @@ class AutoAugment(_AutoAugmentBase):
...
@@ -214,7 +214,7 @@ class AutoAugment(_AutoAugmentBase):
self
,
self
,
policy
:
AutoAugmentPolicy
=
AutoAugmentPolicy
.
IMAGENET
,
policy
:
AutoAugmentPolicy
=
AutoAugmentPolicy
.
IMAGENET
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
fill
:
Union
[
datapoints
.
_FillType
,
Dict
[
Type
,
datapoints
.
_FillType
]]
=
None
,
fill
:
Union
[
datapoints
.
_FillType
,
Dict
[
Union
[
Type
,
str
]
,
datapoints
.
_FillType
]]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
(
interpolation
=
interpolation
,
fill
=
fill
)
super
().
__init__
(
interpolation
=
interpolation
,
fill
=
fill
)
self
.
policy
=
policy
self
.
policy
=
policy
...
@@ -394,7 +394,7 @@ class RandAugment(_AutoAugmentBase):
...
@@ -394,7 +394,7 @@ class RandAugment(_AutoAugmentBase):
magnitude
:
int
=
9
,
magnitude
:
int
=
9
,
num_magnitude_bins
:
int
=
31
,
num_magnitude_bins
:
int
=
31
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
fill
:
Union
[
datapoints
.
_FillType
,
Dict
[
Type
,
datapoints
.
_FillType
]]
=
None
,
fill
:
Union
[
datapoints
.
_FillType
,
Dict
[
Union
[
Type
,
str
]
,
datapoints
.
_FillType
]]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
(
interpolation
=
interpolation
,
fill
=
fill
)
super
().
__init__
(
interpolation
=
interpolation
,
fill
=
fill
)
self
.
num_ops
=
num_ops
self
.
num_ops
=
num_ops
...
@@ -467,7 +467,7 @@ class TrivialAugmentWide(_AutoAugmentBase):
...
@@ -467,7 +467,7 @@ class TrivialAugmentWide(_AutoAugmentBase):
self
,
self
,
num_magnitude_bins
:
int
=
31
,
num_magnitude_bins
:
int
=
31
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
fill
:
Union
[
datapoints
.
_FillType
,
Dict
[
Type
,
datapoints
.
_FillType
]]
=
None
,
fill
:
Union
[
datapoints
.
_FillType
,
Dict
[
Union
[
Type
,
str
]
,
datapoints
.
_FillType
]]
=
None
,
):
):
super
().
__init__
(
interpolation
=
interpolation
,
fill
=
fill
)
super
().
__init__
(
interpolation
=
interpolation
,
fill
=
fill
)
self
.
num_magnitude_bins
=
num_magnitude_bins
self
.
num_magnitude_bins
=
num_magnitude_bins
...
@@ -550,7 +550,7 @@ class AugMix(_AutoAugmentBase):
...
@@ -550,7 +550,7 @@ class AugMix(_AutoAugmentBase):
alpha
:
float
=
1.0
,
alpha
:
float
=
1.0
,
all_ops
:
bool
=
True
,
all_ops
:
bool
=
True
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
fill
:
Union
[
datapoints
.
_FillType
,
Dict
[
Type
,
datapoints
.
_FillType
]]
=
None
,
fill
:
Union
[
datapoints
.
_FillType
,
Dict
[
Union
[
Type
,
str
]
,
datapoints
.
_FillType
]]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
(
interpolation
=
interpolation
,
fill
=
fill
)
super
().
__init__
(
interpolation
=
interpolation
,
fill
=
fill
)
self
.
_PARAMETER_MAX
=
10
self
.
_PARAMETER_MAX
=
10
...
...
torchvision/transforms/v2/_geometry.py
View file @
edde8255
...
@@ -17,6 +17,7 @@ from ._utils import (
...
@@ -17,6 +17,7 @@ from ._utils import (
_check_padding_arg
,
_check_padding_arg
,
_check_padding_mode_arg
,
_check_padding_mode_arg
,
_check_sequence_input
,
_check_sequence_input
,
_get_fill
,
_setup_angle
,
_setup_angle
,
_setup_fill_arg
,
_setup_fill_arg
,
_setup_float_or_seq
,
_setup_float_or_seq
,
...
@@ -487,7 +488,7 @@ class Pad(Transform):
...
@@ -487,7 +488,7 @@ class Pad(Transform):
def
__init__
(
def
__init__
(
self
,
self
,
padding
:
Union
[
int
,
Sequence
[
int
]],
padding
:
Union
[
int
,
Sequence
[
int
]],
fill
:
Union
[
datapoints
.
_FillType
,
Dict
[
Type
,
datapoints
.
_FillType
]]
=
0
,
fill
:
Union
[
datapoints
.
_FillType
,
Dict
[
Union
[
Type
,
str
]
,
datapoints
.
_FillType
]]
=
0
,
padding_mode
:
Literal
[
"constant"
,
"edge"
,
"reflect"
,
"symmetric"
]
=
"constant"
,
padding_mode
:
Literal
[
"constant"
,
"edge"
,
"reflect"
,
"symmetric"
]
=
"constant"
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
...
@@ -504,7 +505,7 @@ class Pad(Transform):
...
@@ -504,7 +505,7 @@ class Pad(Transform):
self
.
padding_mode
=
padding_mode
self
.
padding_mode
=
padding_mode
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
fill
=
self
.
_fill
[
type
(
inpt
)
]
fill
=
_get_fill
(
self
.
_fill
,
type
(
inpt
)
)
return
F
.
pad
(
inpt
,
padding
=
self
.
padding
,
fill
=
fill
,
padding_mode
=
self
.
padding_mode
)
# type: ignore[arg-type]
return
F
.
pad
(
inpt
,
padding
=
self
.
padding
,
fill
=
fill
,
padding_mode
=
self
.
padding_mode
)
# type: ignore[arg-type]
...
@@ -542,7 +543,7 @@ class RandomZoomOut(_RandomApplyTransform):
...
@@ -542,7 +543,7 @@ class RandomZoomOut(_RandomApplyTransform):
def
__init__
(
def
__init__
(
self
,
self
,
fill
:
Union
[
datapoints
.
_FillType
,
Dict
[
Type
,
datapoints
.
_FillType
]]
=
0
,
fill
:
Union
[
datapoints
.
_FillType
,
Dict
[
Union
[
Type
,
str
]
,
datapoints
.
_FillType
]]
=
0
,
side_range
:
Sequence
[
float
]
=
(
1.0
,
4.0
),
side_range
:
Sequence
[
float
]
=
(
1.0
,
4.0
),
p
:
float
=
0.5
,
p
:
float
=
0.5
,
)
->
None
:
)
->
None
:
...
@@ -574,7 +575,7 @@ class RandomZoomOut(_RandomApplyTransform):
...
@@ -574,7 +575,7 @@ class RandomZoomOut(_RandomApplyTransform):
return
dict
(
padding
=
padding
)
return
dict
(
padding
=
padding
)
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
fill
=
self
.
_fill
[
type
(
inpt
)
]
fill
=
_get_fill
(
self
.
_fill
,
type
(
inpt
)
)
return
F
.
pad
(
inpt
,
**
params
,
fill
=
fill
)
return
F
.
pad
(
inpt
,
**
params
,
fill
=
fill
)
...
@@ -620,7 +621,7 @@ class RandomRotation(Transform):
...
@@ -620,7 +621,7 @@ class RandomRotation(Transform):
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
expand
:
bool
=
False
,
expand
:
bool
=
False
,
center
:
Optional
[
List
[
float
]]
=
None
,
center
:
Optional
[
List
[
float
]]
=
None
,
fill
:
Union
[
datapoints
.
_FillType
,
Dict
[
Type
,
datapoints
.
_FillType
]]
=
0
,
fill
:
Union
[
datapoints
.
_FillType
,
Dict
[
Union
[
Type
,
str
]
,
datapoints
.
_FillType
]]
=
0
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
degrees
=
_setup_angle
(
degrees
,
name
=
"degrees"
,
req_sizes
=
(
2
,))
self
.
degrees
=
_setup_angle
(
degrees
,
name
=
"degrees"
,
req_sizes
=
(
2
,))
...
@@ -640,7 +641,7 @@ class RandomRotation(Transform):
...
@@ -640,7 +641,7 @@ class RandomRotation(Transform):
return
dict
(
angle
=
angle
)
return
dict
(
angle
=
angle
)
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
fill
=
self
.
_fill
[
type
(
inpt
)
]
fill
=
_get_fill
(
self
.
_fill
,
type
(
inpt
)
)
return
F
.
rotate
(
return
F
.
rotate
(
inpt
,
inpt
,
**
params
,
**
params
,
...
@@ -702,7 +703,7 @@ class RandomAffine(Transform):
...
@@ -702,7 +703,7 @@ class RandomAffine(Transform):
scale
:
Optional
[
Sequence
[
float
]]
=
None
,
scale
:
Optional
[
Sequence
[
float
]]
=
None
,
shear
:
Optional
[
Union
[
int
,
float
,
Sequence
[
float
]]]
=
None
,
shear
:
Optional
[
Union
[
int
,
float
,
Sequence
[
float
]]]
=
None
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
fill
:
Union
[
datapoints
.
_FillType
,
Dict
[
Type
,
datapoints
.
_FillType
]]
=
0
,
fill
:
Union
[
datapoints
.
_FillType
,
Dict
[
Union
[
Type
,
str
]
,
datapoints
.
_FillType
]]
=
0
,
center
:
Optional
[
List
[
float
]]
=
None
,
center
:
Optional
[
List
[
float
]]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
...
@@ -762,7 +763,7 @@ class RandomAffine(Transform):
...
@@ -762,7 +763,7 @@ class RandomAffine(Transform):
return
dict
(
angle
=
angle
,
translate
=
translate
,
scale
=
scale
,
shear
=
shear
)
return
dict
(
angle
=
angle
,
translate
=
translate
,
scale
=
scale
,
shear
=
shear
)
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
fill
=
self
.
_fill
[
type
(
inpt
)
]
fill
=
_get_fill
(
self
.
_fill
,
type
(
inpt
)
)
return
F
.
affine
(
return
F
.
affine
(
inpt
,
inpt
,
**
params
,
**
params
,
...
@@ -840,7 +841,7 @@ class RandomCrop(Transform):
...
@@ -840,7 +841,7 @@ class RandomCrop(Transform):
size
:
Union
[
int
,
Sequence
[
int
]],
size
:
Union
[
int
,
Sequence
[
int
]],
padding
:
Optional
[
Union
[
int
,
Sequence
[
int
]]]
=
None
,
padding
:
Optional
[
Union
[
int
,
Sequence
[
int
]]]
=
None
,
pad_if_needed
:
bool
=
False
,
pad_if_needed
:
bool
=
False
,
fill
:
Union
[
datapoints
.
_FillType
,
Dict
[
Type
,
datapoints
.
_FillType
]]
=
0
,
fill
:
Union
[
datapoints
.
_FillType
,
Dict
[
Union
[
Type
,
str
]
,
datapoints
.
_FillType
]]
=
0
,
padding_mode
:
Literal
[
"constant"
,
"edge"
,
"reflect"
,
"symmetric"
]
=
"constant"
,
padding_mode
:
Literal
[
"constant"
,
"edge"
,
"reflect"
,
"symmetric"
]
=
"constant"
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
...
@@ -918,7 +919,7 @@ class RandomCrop(Transform):
...
@@ -918,7 +919,7 @@ class RandomCrop(Transform):
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
if
params
[
"needs_pad"
]:
if
params
[
"needs_pad"
]:
fill
=
self
.
_fill
[
type
(
inpt
)
]
fill
=
_get_fill
(
self
.
_fill
,
type
(
inpt
)
)
inpt
=
F
.
pad
(
inpt
,
padding
=
params
[
"padding"
],
fill
=
fill
,
padding_mode
=
self
.
padding_mode
)
inpt
=
F
.
pad
(
inpt
,
padding
=
params
[
"padding"
],
fill
=
fill
,
padding_mode
=
self
.
padding_mode
)
if
params
[
"needs_crop"
]:
if
params
[
"needs_crop"
]:
...
@@ -959,7 +960,7 @@ class RandomPerspective(_RandomApplyTransform):
...
@@ -959,7 +960,7 @@ class RandomPerspective(_RandomApplyTransform):
distortion_scale
:
float
=
0.5
,
distortion_scale
:
float
=
0.5
,
p
:
float
=
0.5
,
p
:
float
=
0.5
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
fill
:
Union
[
datapoints
.
_FillType
,
Dict
[
Type
,
datapoints
.
_FillType
]]
=
0
,
fill
:
Union
[
datapoints
.
_FillType
,
Dict
[
Union
[
Type
,
str
]
,
datapoints
.
_FillType
]]
=
0
,
)
->
None
:
)
->
None
:
super
().
__init__
(
p
=
p
)
super
().
__init__
(
p
=
p
)
...
@@ -1002,7 +1003,7 @@ class RandomPerspective(_RandomApplyTransform):
...
@@ -1002,7 +1003,7 @@ class RandomPerspective(_RandomApplyTransform):
return
dict
(
coefficients
=
perspective_coeffs
)
return
dict
(
coefficients
=
perspective_coeffs
)
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
fill
=
self
.
_fill
[
type
(
inpt
)
]
fill
=
_get_fill
(
self
.
_fill
,
type
(
inpt
)
)
return
F
.
perspective
(
return
F
.
perspective
(
inpt
,
inpt
,
None
,
None
,
...
@@ -1061,7 +1062,7 @@ class ElasticTransform(Transform):
...
@@ -1061,7 +1062,7 @@ class ElasticTransform(Transform):
alpha
:
Union
[
float
,
Sequence
[
float
]]
=
50.0
,
alpha
:
Union
[
float
,
Sequence
[
float
]]
=
50.0
,
sigma
:
Union
[
float
,
Sequence
[
float
]]
=
5.0
,
sigma
:
Union
[
float
,
Sequence
[
float
]]
=
5.0
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
fill
:
Union
[
datapoints
.
_FillType
,
Dict
[
Type
,
datapoints
.
_FillType
]]
=
0
,
fill
:
Union
[
datapoints
.
_FillType
,
Dict
[
Union
[
Type
,
str
]
,
datapoints
.
_FillType
]]
=
0
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
alpha
=
_setup_float_or_seq
(
alpha
,
"alpha"
,
2
)
self
.
alpha
=
_setup_float_or_seq
(
alpha
,
"alpha"
,
2
)
...
@@ -1095,7 +1096,7 @@ class ElasticTransform(Transform):
...
@@ -1095,7 +1096,7 @@ class ElasticTransform(Transform):
return
dict
(
displacement
=
displacement
)
return
dict
(
displacement
=
displacement
)
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
fill
=
self
.
_fill
[
type
(
inpt
)
]
fill
=
_get_fill
(
self
.
_fill
,
type
(
inpt
)
)
return
F
.
elastic
(
return
F
.
elastic
(
inpt
,
inpt
,
**
params
,
**
params
,
...
...
torchvision/transforms/v2/_utils.py
View file @
edde8255
import
collections.abc
import
collections.abc
import
functools
import
numbers
import
numbers
from
collections
import
defaultdict
from
contextlib
import
suppress
from
contextlib
import
suppress
from
typing
import
Any
,
Callable
,
Dict
,
Literal
,
Optional
,
Sequence
,
Type
,
TypeVar
,
Union
from
typing
import
Any
,
Callable
,
Dict
,
Literal
,
Optional
,
Sequence
,
Type
,
Union
import
torch
import
torch
...
@@ -29,32 +27,15 @@ def _setup_float_or_seq(arg: Union[float, Sequence[float]], name: str, req_size:
...
@@ -29,32 +27,15 @@ def _setup_float_or_seq(arg: Union[float, Sequence[float]], name: str, req_size:
return
arg
return
arg
def
_check_fill_arg
(
fill
:
Union
[
_FillType
,
Dict
[
Type
,
_FillType
]])
->
None
:
def
_check_fill_arg
(
fill
:
Union
[
_FillType
,
Dict
[
Union
[
Type
,
str
]
,
_FillType
]])
->
None
:
if
isinstance
(
fill
,
dict
):
if
isinstance
(
fill
,
dict
):
for
key
,
value
in
fill
.
items
():
for
value
in
fill
.
values
():
# Check key for type
_check_fill_arg
(
value
)
_check_fill_arg
(
value
)
if
isinstance
(
fill
,
defaultdict
)
and
callable
(
fill
.
default_factory
):
default_value
=
fill
.
default_factory
()
_check_fill_arg
(
default_value
)
else
:
else
:
if
fill
is
not
None
and
not
isinstance
(
fill
,
(
numbers
.
Number
,
tuple
,
list
)):
if
fill
is
not
None
and
not
isinstance
(
fill
,
(
numbers
.
Number
,
tuple
,
list
)):
raise
TypeError
(
"Got inappropriate fill arg, only Numbers, tuples, lists and dicts are allowed."
)
raise
TypeError
(
"Got inappropriate fill arg, only Numbers, tuples, lists and dicts are allowed."
)
T
=
TypeVar
(
"T"
)
def
_default_arg
(
value
:
T
)
->
T
:
return
value
def
_get_defaultdict
(
default
:
T
)
->
Dict
[
Any
,
T
]:
# This weird looking construct only exists, since `lambda`'s cannot be serialized by pickle.
# If it were possible, we could replace this with `defaultdict(lambda: default)`
return
defaultdict
(
functools
.
partial
(
_default_arg
,
default
))
def
_convert_fill_arg
(
fill
:
datapoints
.
_FillType
)
->
datapoints
.
_FillTypeJIT
:
def
_convert_fill_arg
(
fill
:
datapoints
.
_FillType
)
->
datapoints
.
_FillTypeJIT
:
# Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517
# Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517
# So, we can't reassign fill to 0
# So, we can't reassign fill to 0
...
@@ -68,19 +49,24 @@ def _convert_fill_arg(fill: datapoints._FillType) -> datapoints._FillTypeJIT:
...
@@ -68,19 +49,24 @@ def _convert_fill_arg(fill: datapoints._FillType) -> datapoints._FillTypeJIT:
return
fill
# type: ignore[return-value]
return
fill
# type: ignore[return-value]
def
_setup_fill_arg
(
fill
:
Union
[
_FillType
,
Dict
[
Type
,
_FillType
]])
->
Dict
[
Type
,
_FillTypeJIT
]:
def
_setup_fill_arg
(
fill
:
Union
[
_FillType
,
Dict
[
Union
[
Type
,
str
]
,
_FillType
]])
->
Dict
[
Union
[
Type
,
str
]
,
_FillTypeJIT
]:
_check_fill_arg
(
fill
)
_check_fill_arg
(
fill
)
if
isinstance
(
fill
,
dict
):
if
isinstance
(
fill
,
dict
):
for
k
,
v
in
fill
.
items
():
for
k
,
v
in
fill
.
items
():
fill
[
k
]
=
_convert_fill_arg
(
v
)
fill
[
k
]
=
_convert_fill_arg
(
v
)
if
isinstance
(
fill
,
defaultdict
)
and
callable
(
fill
.
default_factory
):
default_value
=
fill
.
default_factory
()
sanitized_default
=
_convert_fill_arg
(
default_value
)
fill
.
default_factory
=
functools
.
partial
(
_default_arg
,
sanitized_default
)
return
fill
# type: ignore[return-value]
return
fill
# type: ignore[return-value]
else
:
return
{
"others"
:
_convert_fill_arg
(
fill
)}
return
_get_defaultdict
(
_convert_fill_arg
(
fill
))
def
_get_fill
(
fill_dict
,
inpt_type
):
if
inpt_type
in
fill_dict
:
return
fill_dict
[
inpt_type
]
elif
"others"
in
fill_dict
:
return
fill_dict
[
"others"
]
else
:
RuntimeError
(
"This should never happen, please open an issue on the torchvision repo if you hit this."
)
def
_check_padding_arg
(
padding
:
Union
[
int
,
Sequence
[
int
]])
->
None
:
def
_check_padding_arg
(
padding
:
Union
[
int
,
Sequence
[
int
]])
->
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment