OpenDAS / vision · Commit edde8255 (Unverified)
Authored Aug 01, 2023 by Nicolas Hug; committed by GitHub on Aug 01, 2023

Allow catch-all 'others' key in fill dicts. Avoid need for defaultdict. (#7779)

Parent: 312c3d32
Showing 10 changed files with 64 additions and 69 deletions (+64, -69):

    gallery/plot_transforms_v2_e2e.py               +1   -4
    references/segmentation/presets.py              +1   -3
    references/segmentation/v2_extras.py            +2   -2
    test/test_transforms_v2.py                      +1   -2
    test/test_transforms_v2_consistency.py          +3   -3
    torchvision/prototype/transforms/_geometry.py   +3   -3
    torchvision/prototype/transforms/_misc.py       +16  -2
    torchvision/transforms/v2/_auto_augment.py      +8   -8
    torchvision/transforms/v2/_geometry.py          +15  -14
    torchvision/transforms/v2/_utils.py             +14  -28
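Taken together, the change lets callers pass a plain dict with a catch-all "others" key wherever a per-type fill with a fallback used to require a defaultdict. A minimal before/after sketch (assuming the datapoints-era torchvision.transforms.v2 API used throughout this diff):

    from collections import defaultdict

    import torchvision.transforms.v2 as transforms
    from torchvision import datapoints

    # Before this commit: "255 for masks, 0 for everything else" needed a
    # defaultdict, whose lambda default_factory cannot be pickled.
    fill_old = defaultdict(lambda: 0, {datapoints.Mask: 255})

    # After this commit: a plain dict with the reserved "others" key.
    fill_new = {datapoints.Mask: 255, "others": 0}

    pad = transforms.Pad(padding=16, fill=fill_new)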
gallery/plot_transforms_v2_e2e.py  (view file @ edde8255)

@@ -10,7 +10,6 @@ well as the new ``torchvision.transforms.v2`` v2 API.
 """
 
 import pathlib
-from collections import defaultdict
 
 import PIL.Image
@@ -99,9 +98,7 @@ show(sample)
 transform = transforms.Compose(
     [
         transforms.RandomPhotometricDistort(),
-        transforms.RandomZoomOut(
-            fill=defaultdict(lambda: 0, {PIL.Image.Image: (123, 117, 104)})
-        ),
+        transforms.RandomZoomOut(fill={PIL.Image.Image: (123, 117, 104), "others": 0}),
         transforms.RandomIoUCrop(),
         transforms.RandomHorizontalFlip(),
         transforms.ToImageTensor(),
references/segmentation/presets.py  (view file @ edde8255)

-from collections import defaultdict
-
 import torch
@@ -48,7 +46,7 @@ class SegmentationPresetTrain:
         if use_v2:
             # We need a custom pad transform here, since the padding we want to perform here is fundamentally
             # different from the padding in `RandomCrop` if `pad_if_needed=True`.
-            transforms += [v2_extras.PadIfSmaller(crop_size, fill=defaultdict(lambda: 0, {datapoints.Mask: 255}))]
+            transforms += [v2_extras.PadIfSmaller(crop_size, fill={datapoints.Mask: 255, "others": 0})]
 
         transforms += [T.RandomCrop(crop_size)]
references/segmentation/v2_extras.py  (view file @ edde8255)

@@ -8,7 +8,7 @@ class PadIfSmaller(v2.Transform):
     def __init__(self, size, fill=0):
         super().__init__()
         self.size = size
-        self.fill = v2._geometry._setup_fill_arg(fill)
+        self.fill = v2._utils._setup_fill_arg(fill)
 
     def _get_params(self, sample):
         _, height, width = v2.utils.query_chw(sample)
@@ -20,7 +20,7 @@ class PadIfSmaller(v2.Transform):
         if not params["needs_padding"]:
             return inpt
 
-        fill = self.fill[type(inpt)]
+        fill = v2._utils._get_fill(self.fill, type(inpt))
         fill = v2._utils._convert_fill_arg(fill)
 
         return v2.functional.pad(inpt, padding=params["padding"], fill=fill)
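The PadIfSmaller change above is the template for consumers of these private helpers: _setup_fill_arg normalizes whatever the caller passed into a dict (keyed by type or by "others"), and _get_fill resolves the right value per input at transform time. A hedged sketch of a minimal custom transform following the same pattern; ConstantPad is an illustrative name, not part of torchvision:

    from torchvision.transforms import v2


    class ConstantPad(v2.Transform):
        """Illustrative custom transform mirroring the PadIfSmaller pattern above."""

        def __init__(self, padding, fill=0):
            super().__init__()
            self.padding = padding
            # Normalize scalar or dict fills into a dict with an "others" fallback.
            self.fill = v2._utils._setup_fill_arg(fill)

        def _transform(self, inpt, params):
            # Resolve the fill for this particular input type, then pad.
            fill = v2._utils._get_fill(self.fill, type(inpt))
            return v2.functional.pad(inpt, padding=self.padding, fill=fill)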
test/test_transforms_v2.py  (view file @ edde8255)

@@ -3,7 +3,6 @@ import pathlib
 import random
 import textwrap
 import warnings
-from collections import defaultdict
 
 import numpy as np
@@ -1475,7 +1474,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
     elif data_augmentation == "ssd":
         t = [
             transforms.RandomPhotometricDistort(p=1),
-            transforms.RandomZoomOut(fill=defaultdict(lambda: (123.0, 117.0, 104.0), {datapoints.Mask: 0}), p=1),
+            transforms.RandomZoomOut(fill={"others": (123.0, 117.0, 104.0), datapoints.Mask: 0}, p=1),
             transforms.RandomIoUCrop(),
             transforms.RandomHorizontalFlip(p=1),
             to_tensor,
test/test_transforms_v2_consistency.py  (view file @ edde8255)

@@ -4,7 +4,6 @@ import importlib.util
 import inspect
 import random
 import re
-from collections import defaultdict
 from pathlib import Path
 
 import numpy as np
@@ -30,6 +29,7 @@ from torchvision._utils import sequence_to_str
 from torchvision.transforms import functional as legacy_F
 from torchvision.transforms.v2 import functional as prototype_F
+from torchvision.transforms.v2._utils import _get_fill
 from torchvision.transforms.v2.functional import to_image_pil
 from torchvision.transforms.v2.utils import query_size
@@ -1181,7 +1181,7 @@ class PadIfSmaller(v2_transforms.Transform):
         if not params["needs_padding"]:
             return inpt
 
-        fill = self.fill[type(inpt)]
+        fill = _get_fill(self.fill, type(inpt))
         return prototype_F.pad(inpt, padding=params["padding"], fill=fill)
@@ -1243,7 +1243,7 @@ class TestRefSegTransforms:
             seg_transforms.RandomCrop(size=480),
             v2_transforms.Compose(
                 [
-                    PadIfSmaller(size=480, fill=defaultdict(lambda: 0, {datapoints.Mask: 255})),
+                    PadIfSmaller(size=480, fill={datapoints.Mask: 255, "others": 0}),
                     v2_transforms.RandomCrop(size=480),
                 ]
             ),
torchvision/prototype/transforms/_geometry.py  (view file @ edde8255)

@@ -6,7 +6,7 @@ import torch
 from torchvision import datapoints
 from torchvision.prototype.datapoints import Label, OneHotLabel
 from torchvision.transforms.v2 import functional as F, Transform
-from torchvision.transforms.v2._utils import _setup_fill_arg, _setup_size
+from torchvision.transforms.v2._utils import _get_fill, _setup_fill_arg, _setup_size
 from torchvision.transforms.v2.utils import has_any, is_simple_tensor, query_bounding_boxes, query_size
@@ -14,7 +14,7 @@ class FixedSizeCrop(Transform):
     def __init__(
         self,
         size: Union[int, Sequence[int]],
-        fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
+        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
         padding_mode: str = "constant",
     ) -> None:
         super().__init__()
@@ -119,7 +119,7 @@ class FixedSizeCrop(Transform):
             )
 
         if params["needs_pad"]:
-            fill = self._fill[type(inpt)]
+            fill = _get_fill(self._fill, type(inpt))
             inpt = F.pad(inpt, params["padding"], fill=fill, padding_mode=self.padding_mode)
 
         return inpt
torchvision/prototype/transforms/_misc.py  (view file @ edde8255)

+import functools
 import warnings
-from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union
+from collections import defaultdict
+from typing import Any, Dict, Optional, Sequence, Tuple, Type, TypeVar, Union
 
 import torch
 
 from torchvision import datapoints
 from torchvision.transforms.v2 import Transform
-from torchvision.transforms.v2._utils import _get_defaultdict
 from torchvision.transforms.v2.utils import is_simple_tensor
 
+T = TypeVar("T")
+
+
+def _default_arg(value: T) -> T:
+    return value
+
+
+def _get_defaultdict(default: T) -> Dict[Any, T]:
+    # This weird looking construct only exists, since `lambda`'s cannot be serialized by pickle.
+    # If it were possible, we could replace this with `defaultdict(lambda: default)`
+    return defaultdict(functools.partial(_default_arg, default))
+
+
 class PermuteDimensions(Transform):
     _transformed_types = (is_simple_tensor, datapoints.Image, datapoints.Video)
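The helper moved into this file exists purely for picklability: defaultdict(lambda: 0) cannot be pickled because the lambda has no importable qualified name, while functools.partial over a module-level function can. A small, self-contained demonstration of that difference (a standalone copy of the construct shown above, not an import from torchvision):

    import functools
    import pickle
    from collections import defaultdict


    def _default_arg(value):
        return value


    def _get_defaultdict(default):
        # Picklable stand-in for defaultdict(lambda: default).
        return defaultdict(functools.partial(_default_arg, default))


    try:
        pickle.dumps(defaultdict(lambda: 0, {str: 255}))
    except (pickle.PicklingError, AttributeError) as exc:
        print("lambda default_factory fails to pickle:", exc)

    restored = pickle.loads(pickle.dumps(_get_defaultdict(0)))
    print(restored["missing key"])  # falls back to the default -> 0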
torchvision/transforms/v2/_auto_augment.py  (view file @ edde8255)

@@ -11,7 +11,7 @@ from torchvision.transforms.v2 import AutoAugmentPolicy, functional as F, Interp
 from torchvision.transforms.v2.functional._geometry import _check_interpolation
 from torchvision.transforms.v2.functional._meta import get_size
 
-from ._utils import _setup_fill_arg
+from ._utils import _get_fill, _setup_fill_arg
 from .utils import check_type, is_simple_tensor
@@ -20,7 +20,7 @@ class _AutoAugmentBase(Transform):
         self,
         *,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-        fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None,
+        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
     ) -> None:
         super().__init__()
         self.interpolation = _check_interpolation(interpolation)
@@ -80,9 +80,9 @@ class _AutoAugmentBase(Transform):
         transform_id: str,
         magnitude: float,
         interpolation: Union[InterpolationMode, int],
-        fill: Dict[Type, datapoints._FillTypeJIT],
+        fill: Dict[Union[Type, str], datapoints._FillTypeJIT],
     ) -> Union[datapoints._ImageType, datapoints._VideoType]:
-        fill_ = fill[type(image)]
+        fill_ = _get_fill(fill, type(image))
 
         if transform_id == "Identity":
             return image
@@ -214,7 +214,7 @@ class AutoAugment(_AutoAugmentBase):
         self,
         policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-        fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None,
+        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
     ) -> None:
         super().__init__(interpolation=interpolation, fill=fill)
         self.policy = policy
@@ -394,7 +394,7 @@ class RandAugment(_AutoAugmentBase):
         magnitude: int = 9,
         num_magnitude_bins: int = 31,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-        fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None,
+        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
     ) -> None:
         super().__init__(interpolation=interpolation, fill=fill)
         self.num_ops = num_ops
@@ -467,7 +467,7 @@ class TrivialAugmentWide(_AutoAugmentBase):
         self,
         num_magnitude_bins: int = 31,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-        fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None,
+        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
     ):
         super().__init__(interpolation=interpolation, fill=fill)
         self.num_magnitude_bins = num_magnitude_bins
@@ -550,7 +550,7 @@ class AugMix(_AutoAugmentBase):
         alpha: float = 1.0,
         all_ops: bool = True,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-        fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None,
+        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
     ) -> None:
         super().__init__(interpolation=interpolation, fill=fill)
         self._PARAMETER_MAX = 10
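For the auto-augmentation classes the change is confined to the fill handling: the dict may now be keyed by a type or by the "others" string, and the per-image lookup goes through _get_fill instead of plain indexing. A hedged usage sketch (datapoints namespace as elsewhere in this diff):

    import torch
    import torchvision.transforms.v2 as transforms
    from torchvision import datapoints

    # Fill geometric ops on images with mid-gray; every other type keeps the
    # default fill of None via the catch-all key.
    aa = transforms.AutoAugment(fill={datapoints.Image: 128, "others": None})

    img = datapoints.Image(torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8))
    out = aa(img)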
torchvision/transforms/v2/_geometry.py  (view file @ edde8255)

@@ -17,6 +17,7 @@ from ._utils import (
     _check_padding_arg,
     _check_padding_mode_arg,
     _check_sequence_input,
+    _get_fill,
     _setup_angle,
     _setup_fill_arg,
     _setup_float_or_seq,
@@ -487,7 +488,7 @@ class Pad(Transform):
     def __init__(
         self,
         padding: Union[int, Sequence[int]],
-        fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
+        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
         padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
     ) -> None:
         super().__init__()
@@ -504,7 +505,7 @@ class Pad(Transform):
         self.padding_mode = padding_mode
 
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        fill = self._fill[type(inpt)]
+        fill = _get_fill(self._fill, type(inpt))
         return F.pad(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode)  # type: ignore[arg-type]
@@ -542,7 +543,7 @@ class RandomZoomOut(_RandomApplyTransform):
     def __init__(
         self,
-        fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
+        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
         side_range: Sequence[float] = (1.0, 4.0),
         p: float = 0.5,
     ) -> None:
@@ -574,7 +575,7 @@ class RandomZoomOut(_RandomApplyTransform):
         return dict(padding=padding)
 
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        fill = self._fill[type(inpt)]
+        fill = _get_fill(self._fill, type(inpt))
         return F.pad(inpt, **params, fill=fill)
@@ -620,7 +621,7 @@ class RandomRotation(Transform):
         interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
         expand: bool = False,
         center: Optional[List[float]] = None,
-        fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
+        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
     ) -> None:
         super().__init__()
         self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))
@@ -640,7 +641,7 @@ class RandomRotation(Transform):
         return dict(angle=angle)
 
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        fill = self._fill[type(inpt)]
+        fill = _get_fill(self._fill, type(inpt))
         return F.rotate(
             inpt,
             **params,
@@ -702,7 +703,7 @@ class RandomAffine(Transform):
         scale: Optional[Sequence[float]] = None,
         shear: Optional[Union[int, float, Sequence[float]]] = None,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-        fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
+        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
         center: Optional[List[float]] = None,
     ) -> None:
         super().__init__()
@@ -762,7 +763,7 @@ class RandomAffine(Transform):
         return dict(angle=angle, translate=translate, scale=scale, shear=shear)
 
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        fill = self._fill[type(inpt)]
+        fill = _get_fill(self._fill, type(inpt))
         return F.affine(
             inpt,
             **params,
@@ -840,7 +841,7 @@ class RandomCrop(Transform):
         size: Union[int, Sequence[int]],
         padding: Optional[Union[int, Sequence[int]]] = None,
         pad_if_needed: bool = False,
-        fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
+        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
         padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
     ) -> None:
         super().__init__()
@@ -918,7 +919,7 @@ class RandomCrop(Transform):
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         if params["needs_pad"]:
-            fill = self._fill[type(inpt)]
+            fill = _get_fill(self._fill, type(inpt))
             inpt = F.pad(inpt, padding=params["padding"], fill=fill, padding_mode=self.padding_mode)
 
         if params["needs_crop"]:
@@ -959,7 +960,7 @@ class RandomPerspective(_RandomApplyTransform):
         distortion_scale: float = 0.5,
         p: float = 0.5,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-        fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
+        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
     ) -> None:
         super().__init__(p=p)
@@ -1002,7 +1003,7 @@ class RandomPerspective(_RandomApplyTransform):
         return dict(coefficients=perspective_coeffs)
 
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        fill = self._fill[type(inpt)]
+        fill = _get_fill(self._fill, type(inpt))
         return F.perspective(
             inpt,
             None,
@@ -1061,7 +1062,7 @@ class ElasticTransform(Transform):
         alpha: Union[float, Sequence[float]] = 50.0,
         sigma: Union[float, Sequence[float]] = 5.0,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-        fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
+        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
     ) -> None:
         super().__init__()
         self.alpha = _setup_float_or_seq(alpha, "alpha", 2)
@@ -1095,7 +1096,7 @@ class ElasticTransform(Transform):
         return dict(displacement=displacement)
 
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        fill = self._fill[type(inpt)]
+        fill = _get_fill(self._fill, type(inpt))
         return F.elastic(
             inpt,
             **params,
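Each geometry transform above receives the same two mechanical edits: the widened fill annotation and the _get_fill lookup. For users this means per-type fills with a fallback now work for padding, zoom-out, rotation, affine, crop, perspective and elastic transforms without a defaultdict. A short sketch (again assuming the datapoints-era API):

    import torch
    import torchvision.transforms.v2 as transforms
    from torchvision import datapoints

    rotate = transforms.RandomRotation(
        degrees=30,
        # Pad rotated images with ImageNet-ish gray, masks with an ignore index,
        # and anything else with 0 via the catch-all key.
        fill={datapoints.Image: (123, 117, 104), datapoints.Mask: 255, "others": 0},
    )

    img = datapoints.Image(torch.randint(0, 256, (3, 64, 64), dtype=torch.uint8))
    mask = datapoints.Mask(torch.zeros(64, 64, dtype=torch.uint8))
    img_out, mask_out = rotate(img, mask)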
torchvision/transforms/v2/_utils.py  (view file @ edde8255)

 import collections.abc
-import functools
 import numbers
-from collections import defaultdict
 from contextlib import suppress
-from typing import Any, Callable, Dict, Literal, Optional, Sequence, Type, TypeVar, Union
+from typing import Any, Callable, Dict, Literal, Optional, Sequence, Type, Union
 
 import torch
@@ -29,32 +27,15 @@ def _setup_float_or_seq(arg: Union[float, Sequence[float]], name: str, req_size:
     return arg
 
 
-def _check_fill_arg(fill: Union[_FillType, Dict[Type, _FillType]]) -> None:
+def _check_fill_arg(fill: Union[_FillType, Dict[Union[Type, str], _FillType]]) -> None:
     if isinstance(fill, dict):
-        for key, value in fill.items():
-            # Check key for type
+        for value in fill.values():
             _check_fill_arg(value)
-        if isinstance(fill, defaultdict) and callable(fill.default_factory):
-            default_value = fill.default_factory()
-            _check_fill_arg(default_value)
     else:
         if fill is not None and not isinstance(fill, (numbers.Number, tuple, list)):
             raise TypeError("Got inappropriate fill arg, only Numbers, tuples, lists and dicts are allowed.")
 
 
-T = TypeVar("T")
-
-
-def _default_arg(value: T) -> T:
-    return value
-
-
-def _get_defaultdict(default: T) -> Dict[Any, T]:
-    # This weird looking construct only exists, since `lambda`'s cannot be serialized by pickle.
-    # If it were possible, we could replace this with `defaultdict(lambda: default)`
-    return defaultdict(functools.partial(_default_arg, default))
-
-
 def _convert_fill_arg(fill: datapoints._FillType) -> datapoints._FillTypeJIT:
     # Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517
     # So, we can't reassign fill to 0
@@ -68,19 +49,24 @@ def _convert_fill_arg(fill: datapoints._FillType) -> datapoints._FillTypeJIT:
     return fill  # type: ignore[return-value]
 
 
-def _setup_fill_arg(fill: Union[_FillType, Dict[Type, _FillType]]) -> Dict[Type, _FillTypeJIT]:
+def _setup_fill_arg(fill: Union[_FillType, Dict[Union[Type, str], _FillType]]) -> Dict[Union[Type, str], _FillTypeJIT]:
     _check_fill_arg(fill)
 
     if isinstance(fill, dict):
         for k, v in fill.items():
             fill[k] = _convert_fill_arg(v)
-        if isinstance(fill, defaultdict) and callable(fill.default_factory):
-            default_value = fill.default_factory()
-            sanitized_default = _convert_fill_arg(default_value)
-            fill.default_factory = functools.partial(_default_arg, sanitized_default)
         return fill  # type: ignore[return-value]
+    else:
+        return {"others": _convert_fill_arg(fill)}
 
-    return _get_defaultdict(_convert_fill_arg(fill))
+
+def _get_fill(fill_dict, inpt_type):
+    if inpt_type in fill_dict:
+        return fill_dict[inpt_type]
+    elif "others" in fill_dict:
+        return fill_dict["others"]
+    else:
+        RuntimeError("This should never happen, please open an issue on the torchvision repo if you hit this.")
 
 
 def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None:
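The two helpers at the bottom define the new contract: _setup_fill_arg always returns a dict, wrapping a scalar fill under the "others" key, and _get_fill prefers an exact type match before falling back to "others". A small walk-through of the intended behaviour (these are private helpers, so the import path may change between releases):

    from torchvision import datapoints
    from torchvision.transforms.v2._utils import _get_fill, _setup_fill_arg

    # A scalar fill is normalized into a one-entry dict under the catch-all key.
    print(_setup_fill_arg(0))                 # {'others': 0}

    # A per-type dict has its values converted in place and is returned as-is.
    fill = _setup_fill_arg({datapoints.Mask: 255, "others": 0})

    print(_get_fill(fill, datapoints.Mask))   # 255, exact type match wins
    print(_get_fill(fill, datapoints.Image))  # 0, falls back to "others"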