Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
0e0a5dc7
Unverified
Commit
0e0a5dc7
authored
Feb 15, 2023
by
Philip Meier
Committed by
GitHub
Feb 15, 2023
Browse files
Support integer values for interpolation in the prototype transforms (#7248)
parent
f627b9d1
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
153 additions
and
96 deletions
+153
-96
test/test_prototype_transforms.py
test/test_prototype_transforms.py
+3
-3
test/test_prototype_transforms_consistency.py
test/test_prototype_transforms_consistency.py
+36
-10
torchvision/prototype/datapoints/_bounding_box.py
torchvision/prototype/datapoints/_bounding_box.py
+6
-6
torchvision/prototype/datapoints/_datapoint.py
torchvision/prototype/datapoints/_datapoint.py
+6
-6
torchvision/prototype/datapoints/_image.py
torchvision/prototype/datapoints/_image.py
+6
-6
torchvision/prototype/datapoints/_mask.py
torchvision/prototype/datapoints/_mask.py
+6
-6
torchvision/prototype/datapoints/_video.py
torchvision/prototype/datapoints/_video.py
+6
-6
torchvision/prototype/transforms/_augment.py
torchvision/prototype/transforms/_augment.py
+3
-2
torchvision/prototype/transforms/_auto_augment.py
torchvision/prototype/transforms/_auto_augment.py
+8
-7
torchvision/prototype/transforms/_geometry.py
torchvision/prototype/transforms/_geometry.py
+19
-18
torchvision/prototype/transforms/_presets.py
torchvision/prototype/transforms/_presets.py
+4
-2
torchvision/prototype/transforms/functional/_geometry.py
torchvision/prototype/transforms/functional/_geometry.py
+50
-24
No files found.
test/test_prototype_transforms.py
View file @
0e0a5dc7
...
...
@@ -1534,7 +1534,7 @@ class TestScaleJitter:
assert
int
(
spatial_size
[
1
]
*
r_min
)
<=
width
<=
int
(
spatial_size
[
1
]
*
r_max
)
def
test__transform
(
self
,
mocker
):
interpolation_sentinel
=
mocker
.
MagicMock
()
interpolation_sentinel
=
mocker
.
MagicMock
(
spec
=
InterpolationMode
)
antialias_sentinel
=
mocker
.
MagicMock
()
transform
=
transforms
.
ScaleJitter
(
...
...
@@ -1581,7 +1581,7 @@ class TestRandomShortestSize:
assert
shorter
in
min_size
def
test__transform
(
self
,
mocker
):
interpolation_sentinel
=
mocker
.
MagicMock
()
interpolation_sentinel
=
mocker
.
MagicMock
(
spec
=
InterpolationMode
)
antialias_sentinel
=
mocker
.
MagicMock
()
transform
=
transforms
.
RandomShortestSize
(
...
...
@@ -1945,7 +1945,7 @@ class TestRandomResize:
assert
min_size
<=
size
<
max_size
def
test__transform
(
self
,
mocker
):
interpolation_sentinel
=
mocker
.
MagicMock
()
interpolation_sentinel
=
mocker
.
MagicMock
(
spec
=
InterpolationMode
)
antialias_sentinel
=
mocker
.
MagicMock
()
transform
=
transforms
.
RandomResize
(
...
...
test/test_prototype_transforms_consistency.py
View file @
0e0a5dc7
...
...
@@ -88,6 +88,9 @@ CONSISTENCY_CONFIGS = [
ArgsKwargs
((
32
,
29
)),
ArgsKwargs
((
31
,
28
),
interpolation
=
prototype_transforms
.
InterpolationMode
.
NEAREST
),
ArgsKwargs
((
33
,
26
),
interpolation
=
prototype_transforms
.
InterpolationMode
.
BICUBIC
),
ArgsKwargs
((
30
,
27
),
interpolation
=
PIL
.
Image
.
NEAREST
),
ArgsKwargs
((
35
,
29
),
interpolation
=
PIL
.
Image
.
BILINEAR
),
ArgsKwargs
((
34
,
25
),
interpolation
=
PIL
.
Image
.
BICUBIC
),
NotScriptableArgsKwargs
(
31
,
max_size
=
32
),
ArgsKwargs
([
31
],
max_size
=
32
),
NotScriptableArgsKwargs
(
30
,
max_size
=
100
),
...
...
@@ -305,6 +308,8 @@ CONSISTENCY_CONFIGS = [
ArgsKwargs
(
25
,
ratio
=
(
0.5
,
1.5
)),
ArgsKwargs
((
31
,
28
),
interpolation
=
prototype_transforms
.
InterpolationMode
.
NEAREST
),
ArgsKwargs
((
33
,
26
),
interpolation
=
prototype_transforms
.
InterpolationMode
.
BICUBIC
),
ArgsKwargs
((
31
,
28
),
interpolation
=
PIL
.
Image
.
NEAREST
),
ArgsKwargs
((
33
,
26
),
interpolation
=
PIL
.
Image
.
BICUBIC
),
ArgsKwargs
((
29
,
32
),
antialias
=
False
),
ArgsKwargs
((
28
,
31
),
antialias
=
True
),
],
...
...
@@ -352,6 +357,8 @@ CONSISTENCY_CONFIGS = [
ArgsKwargs
(
sigma
=
(
2.5
,
3.9
)),
ArgsKwargs
(
interpolation
=
prototype_transforms
.
InterpolationMode
.
NEAREST
),
ArgsKwargs
(
interpolation
=
prototype_transforms
.
InterpolationMode
.
BICUBIC
),
ArgsKwargs
(
interpolation
=
PIL
.
Image
.
NEAREST
),
ArgsKwargs
(
interpolation
=
PIL
.
Image
.
BICUBIC
),
ArgsKwargs
(
fill
=
1
),
],
# ElasticTransform needs larger images to avoid the needed internal padding being larger than the actual image
...
...
@@ -386,6 +393,7 @@ CONSISTENCY_CONFIGS = [
ArgsKwargs
(
degrees
=
0.0
,
shear
=
(
4
,
5
,
4
,
13
)),
ArgsKwargs
(
degrees
=
(
-
20.0
,
10.0
),
translate
=
(
0.4
,
0.6
),
scale
=
(
0.3
,
0.8
),
shear
=
(
4
,
5
,
4
,
13
)),
ArgsKwargs
(
degrees
=
30.0
,
interpolation
=
prototype_transforms
.
InterpolationMode
.
NEAREST
),
ArgsKwargs
(
degrees
=
30.0
,
interpolation
=
PIL
.
Image
.
NEAREST
),
ArgsKwargs
(
degrees
=
30.0
,
fill
=
1
),
ArgsKwargs
(
degrees
=
30.0
,
fill
=
(
2
,
3
,
4
)),
ArgsKwargs
(
degrees
=
30.0
,
center
=
(
0
,
0
)),
...
...
@@ -420,6 +428,7 @@ CONSISTENCY_CONFIGS = [
ArgsKwargs
(
p
=
1
),
ArgsKwargs
(
p
=
1
,
distortion_scale
=
0.3
),
ArgsKwargs
(
p
=
1
,
distortion_scale
=
0.2
,
interpolation
=
prototype_transforms
.
InterpolationMode
.
NEAREST
),
ArgsKwargs
(
p
=
1
,
distortion_scale
=
0.2
,
interpolation
=
PIL
.
Image
.
NEAREST
),
ArgsKwargs
(
p
=
1
,
distortion_scale
=
0.1
,
fill
=
1
),
ArgsKwargs
(
p
=
1
,
distortion_scale
=
0.4
,
fill
=
(
1
,
2
,
3
)),
],
...
...
@@ -432,6 +441,7 @@ CONSISTENCY_CONFIGS = [
ArgsKwargs
(
degrees
=
30.0
),
ArgsKwargs
(
degrees
=
(
-
20.0
,
10.0
)),
ArgsKwargs
(
degrees
=
30.0
,
interpolation
=
prototype_transforms
.
InterpolationMode
.
BILINEAR
),
ArgsKwargs
(
degrees
=
30.0
,
interpolation
=
PIL
.
Image
.
BILINEAR
),
ArgsKwargs
(
degrees
=
30.0
,
expand
=
True
),
ArgsKwargs
(
degrees
=
30.0
,
center
=
(
0
,
0
)),
ArgsKwargs
(
degrees
=
30.0
,
fill
=
1
),
...
...
@@ -851,7 +861,11 @@ class TestAATransforms:
)
@
pytest
.
mark
.
parametrize
(
"interpolation"
,
[
prototype_transforms
.
InterpolationMode
.
NEAREST
,
prototype_transforms
.
InterpolationMode
.
BILINEAR
],
[
prototype_transforms
.
InterpolationMode
.
NEAREST
,
prototype_transforms
.
InterpolationMode
.
BILINEAR
,
PIL
.
Image
.
NEAREST
,
],
)
def
test_randaug
(
self
,
inpt
,
interpolation
,
mocker
):
t_ref
=
legacy_transforms
.
RandAugment
(
interpolation
=
interpolation
,
num_ops
=
1
)
...
...
@@ -889,7 +903,11 @@ class TestAATransforms:
)
@
pytest
.
mark
.
parametrize
(
"interpolation"
,
[
prototype_transforms
.
InterpolationMode
.
NEAREST
,
prototype_transforms
.
InterpolationMode
.
BILINEAR
],
[
prototype_transforms
.
InterpolationMode
.
NEAREST
,
prototype_transforms
.
InterpolationMode
.
BILINEAR
,
PIL
.
Image
.
NEAREST
,
],
)
def
test_trivial_aug
(
self
,
inpt
,
interpolation
,
mocker
):
t_ref
=
legacy_transforms
.
TrivialAugmentWide
(
interpolation
=
interpolation
)
...
...
@@ -937,7 +955,11 @@ class TestAATransforms:
)
@
pytest
.
mark
.
parametrize
(
"interpolation"
,
[
prototype_transforms
.
InterpolationMode
.
NEAREST
,
prototype_transforms
.
InterpolationMode
.
BILINEAR
],
[
prototype_transforms
.
InterpolationMode
.
NEAREST
,
prototype_transforms
.
InterpolationMode
.
BILINEAR
,
PIL
.
Image
.
NEAREST
,
],
)
def
test_augmix
(
self
,
inpt
,
interpolation
,
mocker
):
t_ref
=
legacy_transforms
.
AugMix
(
interpolation
=
interpolation
,
mixture_width
=
1
,
chain_depth
=
1
)
...
...
@@ -986,7 +1008,11 @@ class TestAATransforms:
)
@
pytest
.
mark
.
parametrize
(
"interpolation"
,
[
prototype_transforms
.
InterpolationMode
.
NEAREST
,
prototype_transforms
.
InterpolationMode
.
BILINEAR
],
[
prototype_transforms
.
InterpolationMode
.
NEAREST
,
prototype_transforms
.
InterpolationMode
.
BILINEAR
,
PIL
.
Image
.
NEAREST
,
],
)
def
test_aa
(
self
,
inpt
,
interpolation
):
aa_policy
=
legacy_transforms
.
AutoAugmentPolicy
(
"imagenet"
)
...
...
@@ -1264,13 +1290,13 @@ class TestRefSegTransforms:
(
legacy_F
.
convert_image_dtype
,
{}),
(
legacy_F
.
to_pil_image
,
{}),
(
legacy_F
.
normalize
,
{}),
(
legacy_F
.
resize
,
{}),
(
legacy_F
.
resize
,
{
"interpolation"
}),
(
legacy_F
.
pad
,
{
"padding"
,
"fill"
}),
(
legacy_F
.
crop
,
{}),
(
legacy_F
.
center_crop
,
{}),
(
legacy_F
.
resized_crop
,
{}),
(
legacy_F
.
resized_crop
,
{
"interpolation"
}),
(
legacy_F
.
hflip
,
{}),
(
legacy_F
.
perspective
,
{
"startpoints"
,
"endpoints"
,
"fill"
}),
(
legacy_F
.
perspective
,
{
"startpoints"
,
"endpoints"
,
"fill"
,
"interpolation"
}),
(
legacy_F
.
vflip
,
{}),
(
legacy_F
.
five_crop
,
{}),
(
legacy_F
.
ten_crop
,
{}),
...
...
@@ -1279,8 +1305,8 @@ class TestRefSegTransforms:
(
legacy_F
.
adjust_saturation
,
{}),
(
legacy_F
.
adjust_hue
,
{}),
(
legacy_F
.
adjust_gamma
,
{}),
(
legacy_F
.
rotate
,
{
"center"
,
"fill"
}),
(
legacy_F
.
affine
,
{
"angle"
,
"translate"
,
"center"
,
"fill"
}),
(
legacy_F
.
rotate
,
{
"center"
,
"fill"
,
"interpolation"
}),
(
legacy_F
.
affine
,
{
"angle"
,
"translate"
,
"center"
,
"fill"
,
"interpolation"
}),
(
legacy_F
.
to_grayscale
,
{}),
(
legacy_F
.
rgb_to_grayscale
,
{}),
(
legacy_F
.
to_tensor
,
{}),
...
...
@@ -1292,7 +1318,7 @@ class TestRefSegTransforms:
(
legacy_F
.
adjust_sharpness
,
{}),
(
legacy_F
.
autocontrast
,
{}),
(
legacy_F
.
equalize
,
{}),
(
legacy_F
.
elastic_transform
,
{
"fill"
}),
(
legacy_F
.
elastic_transform
,
{
"fill"
,
"interpolation"
}),
],
)
def
test_dispatcher_signature_consistency
(
legacy_dispatcher
,
name_only_params
):
...
...
torchvision/prototype/datapoints/_bounding_box.py
View file @
0e0a5dc7
...
...
@@ -76,7 +76,7 @@ class BoundingBox(Datapoint):
def
resize
(
# type: ignore[override]
self
,
size
:
List
[
int
],
interpolation
:
InterpolationMode
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
max_size
:
Optional
[
int
]
=
None
,
antialias
:
Optional
[
Union
[
str
,
bool
]]
=
"warn"
,
)
->
BoundingBox
:
...
...
@@ -107,7 +107,7 @@ class BoundingBox(Datapoint):
height
:
int
,
width
:
int
,
size
:
List
[
int
],
interpolation
:
InterpolationMode
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
antialias
:
Optional
[
Union
[
str
,
bool
]]
=
"warn"
,
)
->
BoundingBox
:
output
,
spatial_size
=
self
.
_F
.
resized_crop_bounding_box
(
...
...
@@ -133,7 +133,7 @@ class BoundingBox(Datapoint):
def
rotate
(
self
,
angle
:
float
,
interpolation
:
InterpolationMode
=
InterpolationMode
.
NEAREST
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
expand
:
bool
=
False
,
center
:
Optional
[
List
[
float
]]
=
None
,
fill
:
FillTypeJIT
=
None
,
...
...
@@ -154,7 +154,7 @@ class BoundingBox(Datapoint):
translate
:
List
[
float
],
scale
:
float
,
shear
:
List
[
float
],
interpolation
:
InterpolationMode
=
InterpolationMode
.
NEAREST
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
fill
:
FillTypeJIT
=
None
,
center
:
Optional
[
List
[
float
]]
=
None
,
)
->
BoundingBox
:
...
...
@@ -174,7 +174,7 @@ class BoundingBox(Datapoint):
self
,
startpoints
:
Optional
[
List
[
List
[
int
]]],
endpoints
:
Optional
[
List
[
List
[
int
]]],
interpolation
:
InterpolationMode
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
fill
:
FillTypeJIT
=
None
,
coefficients
:
Optional
[
List
[
float
]]
=
None
,
)
->
BoundingBox
:
...
...
@@ -191,7 +191,7 @@ class BoundingBox(Datapoint):
def
elastic
(
self
,
displacement
:
torch
.
Tensor
,
interpolation
:
InterpolationMode
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
fill
:
FillTypeJIT
=
None
,
)
->
BoundingBox
:
output
=
self
.
_F
.
elastic_bounding_box
(
...
...
torchvision/prototype/datapoints/_datapoint.py
View file @
0e0a5dc7
...
...
@@ -143,7 +143,7 @@ class Datapoint(torch.Tensor):
def
resize
(
# type: ignore[override]
self
,
size
:
List
[
int
],
interpolation
:
InterpolationMode
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
max_size
:
Optional
[
int
]
=
None
,
antialias
:
Optional
[
Union
[
str
,
bool
]]
=
"warn"
,
)
->
Datapoint
:
...
...
@@ -162,7 +162,7 @@ class Datapoint(torch.Tensor):
height
:
int
,
width
:
int
,
size
:
List
[
int
],
interpolation
:
InterpolationMode
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
antialias
:
Optional
[
Union
[
str
,
bool
]]
=
"warn"
,
)
->
Datapoint
:
return
self
...
...
@@ -178,7 +178,7 @@ class Datapoint(torch.Tensor):
def
rotate
(
self
,
angle
:
float
,
interpolation
:
InterpolationMode
=
InterpolationMode
.
NEAREST
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
expand
:
bool
=
False
,
center
:
Optional
[
List
[
float
]]
=
None
,
fill
:
FillTypeJIT
=
None
,
...
...
@@ -191,7 +191,7 @@ class Datapoint(torch.Tensor):
translate
:
List
[
float
],
scale
:
float
,
shear
:
List
[
float
],
interpolation
:
InterpolationMode
=
InterpolationMode
.
NEAREST
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
fill
:
FillTypeJIT
=
None
,
center
:
Optional
[
List
[
float
]]
=
None
,
)
->
Datapoint
:
...
...
@@ -201,7 +201,7 @@ class Datapoint(torch.Tensor):
self
,
startpoints
:
Optional
[
List
[
List
[
int
]]],
endpoints
:
Optional
[
List
[
List
[
int
]]],
interpolation
:
InterpolationMode
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
fill
:
FillTypeJIT
=
None
,
coefficients
:
Optional
[
List
[
float
]]
=
None
,
)
->
Datapoint
:
...
...
@@ -210,7 +210,7 @@ class Datapoint(torch.Tensor):
def
elastic
(
self
,
displacement
:
torch
.
Tensor
,
interpolation
:
InterpolationMode
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
fill
:
FillTypeJIT
=
None
,
)
->
Datapoint
:
return
self
...
...
torchvision/prototype/datapoints/_image.py
View file @
0e0a5dc7
...
...
@@ -62,7 +62,7 @@ class Image(Datapoint):
def
resize
(
# type: ignore[override]
self
,
size
:
List
[
int
],
interpolation
:
InterpolationMode
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
max_size
:
Optional
[
int
]
=
None
,
antialias
:
Optional
[
Union
[
str
,
bool
]]
=
"warn"
,
)
->
Image
:
...
...
@@ -86,7 +86,7 @@ class Image(Datapoint):
height
:
int
,
width
:
int
,
size
:
List
[
int
],
interpolation
:
InterpolationMode
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
antialias
:
Optional
[
Union
[
str
,
bool
]]
=
"warn"
,
)
->
Image
:
output
=
self
.
_F
.
resized_crop_image_tensor
(
...
...
@@ -113,7 +113,7 @@ class Image(Datapoint):
def
rotate
(
self
,
angle
:
float
,
interpolation
:
InterpolationMode
=
InterpolationMode
.
NEAREST
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
expand
:
bool
=
False
,
center
:
Optional
[
List
[
float
]]
=
None
,
fill
:
FillTypeJIT
=
None
,
...
...
@@ -129,7 +129,7 @@ class Image(Datapoint):
translate
:
List
[
float
],
scale
:
float
,
shear
:
List
[
float
],
interpolation
:
InterpolationMode
=
InterpolationMode
.
NEAREST
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
fill
:
FillTypeJIT
=
None
,
center
:
Optional
[
List
[
float
]]
=
None
,
)
->
Image
:
...
...
@@ -149,7 +149,7 @@ class Image(Datapoint):
self
,
startpoints
:
Optional
[
List
[
List
[
int
]]],
endpoints
:
Optional
[
List
[
List
[
int
]]],
interpolation
:
InterpolationMode
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
fill
:
FillTypeJIT
=
None
,
coefficients
:
Optional
[
List
[
float
]]
=
None
,
)
->
Image
:
...
...
@@ -166,7 +166,7 @@ class Image(Datapoint):
def
elastic
(
self
,
displacement
:
torch
.
Tensor
,
interpolation
:
InterpolationMode
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
fill
:
FillTypeJIT
=
None
,
)
->
Image
:
output
=
self
.
_F
.
elastic_image_tensor
(
...
...
torchvision/prototype/datapoints/_mask.py
View file @
0e0a5dc7
...
...
@@ -53,7 +53,7 @@ class Mask(Datapoint):
def
resize
(
# type: ignore[override]
self
,
size
:
List
[
int
],
interpolation
:
InterpolationMode
=
InterpolationMode
.
NEAREST
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
max_size
:
Optional
[
int
]
=
None
,
antialias
:
Optional
[
Union
[
str
,
bool
]]
=
"warn"
,
)
->
Mask
:
...
...
@@ -75,7 +75,7 @@ class Mask(Datapoint):
height
:
int
,
width
:
int
,
size
:
List
[
int
],
interpolation
:
InterpolationMode
=
InterpolationMode
.
NEAREST
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
antialias
:
Optional
[
Union
[
str
,
bool
]]
=
"warn"
,
)
->
Mask
:
output
=
self
.
_F
.
resized_crop_mask
(
self
.
as_subclass
(
torch
.
Tensor
),
top
,
left
,
height
,
width
,
size
=
size
)
...
...
@@ -93,7 +93,7 @@ class Mask(Datapoint):
def
rotate
(
self
,
angle
:
float
,
interpolation
:
InterpolationMode
=
InterpolationMode
.
NEAREST
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
expand
:
bool
=
False
,
center
:
Optional
[
List
[
float
]]
=
None
,
fill
:
FillTypeJIT
=
None
,
...
...
@@ -107,7 +107,7 @@ class Mask(Datapoint):
translate
:
List
[
float
],
scale
:
float
,
shear
:
List
[
float
],
interpolation
:
InterpolationMode
=
InterpolationMode
.
NEAREST
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
fill
:
FillTypeJIT
=
None
,
center
:
Optional
[
List
[
float
]]
=
None
,
)
->
Mask
:
...
...
@@ -126,7 +126,7 @@ class Mask(Datapoint):
self
,
startpoints
:
Optional
[
List
[
List
[
int
]]],
endpoints
:
Optional
[
List
[
List
[
int
]]],
interpolation
:
InterpolationMode
=
InterpolationMode
.
NEAREST
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
fill
:
FillTypeJIT
=
None
,
coefficients
:
Optional
[
List
[
float
]]
=
None
,
)
->
Mask
:
...
...
@@ -138,7 +138,7 @@ class Mask(Datapoint):
def
elastic
(
self
,
displacement
:
torch
.
Tensor
,
interpolation
:
InterpolationMode
=
InterpolationMode
.
NEAREST
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
fill
:
FillTypeJIT
=
None
,
)
->
Mask
:
output
=
self
.
_F
.
elastic_mask
(
self
.
as_subclass
(
torch
.
Tensor
),
displacement
,
fill
=
fill
)
...
...
torchvision/prototype/datapoints/_video.py
View file @
0e0a5dc7
...
...
@@ -57,7 +57,7 @@ class Video(Datapoint):
def
resize
(
# type: ignore[override]
self
,
size
:
List
[
int
],
interpolation
:
InterpolationMode
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
max_size
:
Optional
[
int
]
=
None
,
antialias
:
Optional
[
Union
[
str
,
bool
]]
=
"warn"
,
)
->
Video
:
...
...
@@ -85,7 +85,7 @@ class Video(Datapoint):
height
:
int
,
width
:
int
,
size
:
List
[
int
],
interpolation
:
InterpolationMode
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
antialias
:
Optional
[
Union
[
str
,
bool
]]
=
"warn"
,
)
->
Video
:
output
=
self
.
_F
.
resized_crop_video
(
...
...
@@ -112,7 +112,7 @@ class Video(Datapoint):
def
rotate
(
self
,
angle
:
float
,
interpolation
:
InterpolationMode
=
InterpolationMode
.
NEAREST
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
expand
:
bool
=
False
,
center
:
Optional
[
List
[
float
]]
=
None
,
fill
:
FillTypeJIT
=
None
,
...
...
@@ -128,7 +128,7 @@ class Video(Datapoint):
translate
:
List
[
float
],
scale
:
float
,
shear
:
List
[
float
],
interpolation
:
InterpolationMode
=
InterpolationMode
.
NEAREST
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
fill
:
FillTypeJIT
=
None
,
center
:
Optional
[
List
[
float
]]
=
None
,
)
->
Video
:
...
...
@@ -148,7 +148,7 @@ class Video(Datapoint):
self
,
startpoints
:
Optional
[
List
[
List
[
int
]]],
endpoints
:
Optional
[
List
[
List
[
int
]]],
interpolation
:
InterpolationMode
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
fill
:
FillTypeJIT
=
None
,
coefficients
:
Optional
[
List
[
float
]]
=
None
,
)
->
Video
:
...
...
@@ -165,7 +165,7 @@ class Video(Datapoint):
def
elastic
(
self
,
displacement
:
torch
.
Tensor
,
interpolation
:
InterpolationMode
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
fill
:
FillTypeJIT
=
None
,
)
->
Video
:
output
=
self
.
_F
.
elastic_video
(
...
...
torchvision/prototype/transforms/_augment.py
View file @
0e0a5dc7
...
...
@@ -10,6 +10,7 @@ from torchvision import transforms as _transforms
from
torchvision.ops
import
masks_to_boxes
from
torchvision.prototype
import
datapoints
from
torchvision.prototype.transforms
import
functional
as
F
,
InterpolationMode
,
Transform
from
torchvision.prototype.transforms.functional._geometry
import
_check_interpolation
from
._transform
import
_RandomApplyTransform
from
.utils
import
has_any
,
is_simple_tensor
,
query_chw
,
query_spatial_size
...
...
@@ -203,11 +204,11 @@ class SimpleCopyPaste(Transform):
def
__init__
(
self
,
blending
:
bool
=
True
,
resize_interpolation
:
InterpolationMode
=
F
.
InterpolationMode
.
BILINEAR
,
resize_interpolation
:
Union
[
int
,
InterpolationMode
]
=
F
.
InterpolationMode
.
BILINEAR
,
antialias
:
Optional
[
bool
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
resize_interpolation
=
resize_interpolation
self
.
resize_interpolation
=
_check_interpolation
(
resize_interpolation
)
self
.
blending
=
blending
self
.
antialias
=
antialias
...
...
torchvision/prototype/transforms/_auto_augment.py
View file @
0e0a5dc7
...
...
@@ -8,6 +8,7 @@ from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec
from
torchvision
import
transforms
as
_transforms
from
torchvision.prototype
import
datapoints
from
torchvision.prototype.transforms
import
AutoAugmentPolicy
,
functional
as
F
,
InterpolationMode
,
Transform
from
torchvision.prototype.transforms.functional._geometry
import
_check_interpolation
from
torchvision.prototype.transforms.functional._meta
import
get_spatial_size
from
torchvision.transforms
import
functional_tensor
as
_FT
...
...
@@ -19,11 +20,11 @@ class _AutoAugmentBase(Transform):
def
__init__
(
self
,
*
,
interpolation
:
InterpolationMode
=
InterpolationMode
.
NEAREST
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
fill
:
Union
[
datapoints
.
FillType
,
Dict
[
Type
,
datapoints
.
FillType
]]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
interpolation
=
interpolation
self
.
interpolation
=
_check_
interpolation
(
interpolation
)
self
.
fill
=
_setup_fill_arg
(
fill
)
def
_get_random_item
(
self
,
dct
:
Dict
[
str
,
Tuple
[
Callable
,
bool
]])
->
Tuple
[
str
,
Tuple
[
Callable
,
bool
]]:
...
...
@@ -79,7 +80,7 @@ class _AutoAugmentBase(Transform):
image
:
Union
[
datapoints
.
ImageType
,
datapoints
.
VideoType
],
transform_id
:
str
,
magnitude
:
float
,
interpolation
:
InterpolationMode
,
interpolation
:
Union
[
InterpolationMode
,
int
],
fill
:
Dict
[
Type
,
datapoints
.
FillTypeJIT
],
)
->
Union
[
datapoints
.
ImageType
,
datapoints
.
VideoType
]:
fill_
=
fill
[
type
(
image
)]
...
...
@@ -193,7 +194,7 @@ class AutoAugment(_AutoAugmentBase):
def
__init__
(
self
,
policy
:
AutoAugmentPolicy
=
AutoAugmentPolicy
.
IMAGENET
,
interpolation
:
InterpolationMode
=
InterpolationMode
.
NEAREST
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
fill
:
Union
[
datapoints
.
FillType
,
Dict
[
Type
,
datapoints
.
FillType
]]
=
None
,
)
->
None
:
super
().
__init__
(
interpolation
=
interpolation
,
fill
=
fill
)
...
...
@@ -350,7 +351,7 @@ class RandAugment(_AutoAugmentBase):
num_ops
:
int
=
2
,
magnitude
:
int
=
9
,
num_magnitude_bins
:
int
=
31
,
interpolation
:
InterpolationMode
=
InterpolationMode
.
NEAREST
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
fill
:
Union
[
datapoints
.
FillType
,
Dict
[
Type
,
datapoints
.
FillType
]]
=
None
,
)
->
None
:
super
().
__init__
(
interpolation
=
interpolation
,
fill
=
fill
)
...
...
@@ -403,7 +404,7 @@ class TrivialAugmentWide(_AutoAugmentBase):
def
__init__
(
self
,
num_magnitude_bins
:
int
=
31
,
interpolation
:
InterpolationMode
=
InterpolationMode
.
NEAREST
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
fill
:
Union
[
datapoints
.
FillType
,
Dict
[
Type
,
datapoints
.
FillType
]]
=
None
,
):
super
().
__init__
(
interpolation
=
interpolation
,
fill
=
fill
)
...
...
@@ -461,7 +462,7 @@ class AugMix(_AutoAugmentBase):
chain_depth
:
int
=
-
1
,
alpha
:
float
=
1.0
,
all_ops
:
bool
=
True
,
interpolation
:
InterpolationMode
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
fill
:
Union
[
datapoints
.
FillType
,
Dict
[
Type
,
datapoints
.
FillType
]]
=
None
,
)
->
None
:
super
().
__init__
(
interpolation
=
interpolation
,
fill
=
fill
)
...
...
torchvision/prototype/transforms/_geometry.py
View file @
0e0a5dc7
...
...
@@ -10,6 +10,7 @@ from torchvision import transforms as _transforms
from
torchvision.ops.boxes
import
box_iou
from
torchvision.prototype
import
datapoints
from
torchvision.prototype.transforms
import
functional
as
F
,
InterpolationMode
,
Transform
from
torchvision.prototype.transforms.functional._geometry
import
_check_interpolation
from
torchvision.transforms.functional
import
_get_perspective_coeffs
from
._transform
import
_RandomApplyTransform
...
...
@@ -45,7 +46,7 @@ class Resize(Transform):
def
__init__
(
self
,
size
:
Union
[
int
,
Sequence
[
int
]],
interpolation
:
InterpolationMode
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
max_size
:
Optional
[
int
]
=
None
,
antialias
:
Optional
[
Union
[
str
,
bool
]]
=
"warn"
,
)
->
None
:
...
...
@@ -61,7 +62,7 @@ class Resize(Transform):
)
self
.
size
=
size
self
.
interpolation
=
interpolation
self
.
interpolation
=
_check_
interpolation
(
interpolation
)
self
.
max_size
=
max_size
self
.
antialias
=
antialias
...
...
@@ -94,7 +95,7 @@ class RandomResizedCrop(Transform):
size
:
Union
[
int
,
Sequence
[
int
]],
scale
:
Tuple
[
float
,
float
]
=
(
0.08
,
1.0
),
ratio
:
Tuple
[
float
,
float
]
=
(
3.0
/
4.0
,
4.0
/
3.0
),
interpolation
:
InterpolationMode
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
antialias
:
Optional
[
Union
[
str
,
bool
]]
=
"warn"
,
)
->
None
:
super
().
__init__
()
...
...
@@ -111,7 +112,7 @@ class RandomResizedCrop(Transform):
self
.
scale
=
scale
self
.
ratio
=
ratio
self
.
interpolation
=
interpolation
self
.
interpolation
=
_check_
interpolation
(
interpolation
)
self
.
antialias
=
antialias
self
.
_log_ratio
=
torch
.
log
(
torch
.
tensor
(
self
.
ratio
))
...
...
@@ -317,14 +318,14 @@ class RandomRotation(Transform):
def
__init__
(
self
,
degrees
:
Union
[
numbers
.
Number
,
Sequence
],
interpolation
:
InterpolationMode
=
InterpolationMode
.
NEAREST
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
expand
:
bool
=
False
,
fill
:
Union
[
datapoints
.
FillType
,
Dict
[
Type
,
datapoints
.
FillType
]]
=
0
,
center
:
Optional
[
List
[
float
]]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
degrees
=
_setup_angle
(
degrees
,
name
=
"degrees"
,
req_sizes
=
(
2
,))
self
.
interpolation
=
interpolation
self
.
interpolation
=
_check_
interpolation
(
interpolation
)
self
.
expand
=
expand
self
.
fill
=
_setup_fill_arg
(
fill
)
...
...
@@ -359,7 +360,7 @@ class RandomAffine(Transform):
translate
:
Optional
[
Sequence
[
float
]]
=
None
,
scale
:
Optional
[
Sequence
[
float
]]
=
None
,
shear
:
Optional
[
Union
[
int
,
float
,
Sequence
[
float
]]]
=
None
,
interpolation
:
InterpolationMode
=
InterpolationMode
.
NEAREST
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
fill
:
Union
[
datapoints
.
FillType
,
Dict
[
Type
,
datapoints
.
FillType
]]
=
0
,
center
:
Optional
[
List
[
float
]]
=
None
,
)
->
None
:
...
...
@@ -383,7 +384,7 @@ class RandomAffine(Transform):
else
:
self
.
shear
=
shear
self
.
interpolation
=
interpolation
self
.
interpolation
=
_check_
interpolation
(
interpolation
)
self
.
fill
=
_setup_fill_arg
(
fill
)
if
center
is
not
None
:
...
...
@@ -546,7 +547,7 @@ class RandomPerspective(_RandomApplyTransform):
self
,
distortion_scale
:
float
=
0.5
,
fill
:
Union
[
datapoints
.
FillType
,
Dict
[
Type
,
datapoints
.
FillType
]]
=
0
,
interpolation
:
InterpolationMode
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
p
:
float
=
0.5
,
)
->
None
:
super
().
__init__
(
p
=
p
)
...
...
@@ -555,7 +556,7 @@ class RandomPerspective(_RandomApplyTransform):
raise
ValueError
(
"Argument distortion_scale value should be between 0 and 1"
)
self
.
distortion_scale
=
distortion_scale
self
.
interpolation
=
interpolation
self
.
interpolation
=
_check_
interpolation
(
interpolation
)
self
.
fill
=
_setup_fill_arg
(
fill
)
def
_get_params
(
self
,
flat_inputs
:
List
[
Any
])
->
Dict
[
str
,
Any
]:
...
...
@@ -608,13 +609,13 @@ class ElasticTransform(Transform):
alpha
:
Union
[
float
,
Sequence
[
float
]]
=
50.0
,
sigma
:
Union
[
float
,
Sequence
[
float
]]
=
5.0
,
fill
:
Union
[
datapoints
.
FillType
,
Dict
[
Type
,
datapoints
.
FillType
]]
=
0
,
interpolation
:
InterpolationMode
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
)
->
None
:
super
().
__init__
()
self
.
alpha
=
_setup_float_or_seq
(
alpha
,
"alpha"
,
2
)
self
.
sigma
=
_setup_float_or_seq
(
sigma
,
"sigma"
,
2
)
self
.
interpolation
=
interpolation
self
.
interpolation
=
_check_
interpolation
(
interpolation
)
self
.
fill
=
_setup_fill_arg
(
fill
)
def
_get_params
(
self
,
flat_inputs
:
List
[
Any
])
->
Dict
[
str
,
Any
]:
...
...
@@ -760,13 +761,13 @@ class ScaleJitter(Transform):
self
,
target_size
:
Tuple
[
int
,
int
],
scale_range
:
Tuple
[
float
,
float
]
=
(
0.1
,
2.0
),
interpolation
:
InterpolationMode
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
antialias
:
Optional
[
Union
[
str
,
bool
]]
=
"warn"
,
):
super
().
__init__
()
self
.
target_size
=
target_size
self
.
scale_range
=
scale_range
self
.
interpolation
=
interpolation
self
.
interpolation
=
_check_
interpolation
(
interpolation
)
self
.
antialias
=
antialias
def
_get_params
(
self
,
flat_inputs
:
List
[
Any
])
->
Dict
[
str
,
Any
]:
...
...
@@ -788,13 +789,13 @@ class RandomShortestSize(Transform):
self
,
min_size
:
Union
[
List
[
int
],
Tuple
[
int
],
int
],
max_size
:
Optional
[
int
]
=
None
,
interpolation
:
InterpolationMode
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
antialias
:
Optional
[
Union
[
str
,
bool
]]
=
"warn"
,
):
super
().
__init__
()
self
.
min_size
=
[
min_size
]
if
isinstance
(
min_size
,
int
)
else
list
(
min_size
)
self
.
max_size
=
max_size
self
.
interpolation
=
interpolation
self
.
interpolation
=
_check_
interpolation
(
interpolation
)
self
.
antialias
=
antialias
def
_get_params
(
self
,
flat_inputs
:
List
[
Any
])
->
Dict
[
str
,
Any
]:
...
...
@@ -935,13 +936,13 @@ class RandomResize(Transform):
self
,
min_size
:
int
,
max_size
:
int
,
interpolation
:
InterpolationMode
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
antialias
:
Optional
[
Union
[
str
,
bool
]]
=
"warn"
,
)
->
None
:
super
().
__init__
()
self
.
min_size
=
min_size
self
.
max_size
=
max_size
self
.
interpolation
=
interpolation
self
.
interpolation
=
_check_
interpolation
(
interpolation
)
self
.
antialias
=
antialias
def
_get_params
(
self
,
flat_inputs
:
List
[
Any
])
->
Dict
[
str
,
Any
]:
...
...
torchvision/prototype/transforms/_presets.py
View file @
0e0a5dc7
...
...
@@ -9,6 +9,8 @@ import PIL.Image
import
torch
from
torch
import
Tensor
from
torchvision.prototype.transforms.functional._geometry
import
_check_interpolation
from
.
import
functional
as
F
,
InterpolationMode
__all__
=
[
"StereoMatching"
]
...
...
@@ -22,7 +24,7 @@ class StereoMatching(torch.nn.Module):
resize_size
:
Optional
[
Tuple
[
int
,
...]],
mean
:
Tuple
[
float
,
...]
=
(
0.5
,
0.5
,
0.5
),
std
:
Tuple
[
float
,
...]
=
(
0.5
,
0.5
,
0.5
),
interpolation
:
InterpolationMode
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
)
->
None
:
super
().
__init__
()
...
...
@@ -36,7 +38,7 @@ class StereoMatching(torch.nn.Module):
self
.
mean
=
list
(
mean
)
self
.
std
=
list
(
std
)
self
.
interpolation
=
interpolation
self
.
interpolation
=
_check_
interpolation
(
interpolation
)
self
.
use_gray_scale
=
use_gray_scale
def
forward
(
self
,
left_image
:
Tensor
,
right_image
:
Tensor
)
->
Tuple
[
Tensor
,
Tensor
]:
...
...
torchvision/prototype/transforms/functional/_geometry.py
View file @
0e0a5dc7
...
...
@@ -13,6 +13,7 @@ from torchvision.transforms.functional import (
_check_antialias
,
_compute_resized_output_size
as
__compute_resized_output_size
,
_get_perspective_coeffs
,
_interpolation_modes_from_int
,
InterpolationMode
,
pil_modes_mapping
,
pil_to_tensor
,
...
...
@@ -27,6 +28,17 @@ from ._meta import clamp_bounding_box, convert_format_bounding_box, get_spatial_
from
._utils
import
is_simple_tensor
def
_check_interpolation
(
interpolation
:
Union
[
InterpolationMode
,
int
])
->
InterpolationMode
:
if
isinstance
(
interpolation
,
int
):
interpolation
=
_interpolation_modes_from_int
(
interpolation
)
elif
not
isinstance
(
interpolation
,
InterpolationMode
):
raise
ValueError
(
f
"Argument interpolation should be an `InterpolationMode` or a corresponding Pillow integer constant, "
f
"but got
{
interpolation
}
."
)
return
interpolation
def
horizontal_flip_image_tensor
(
image
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
image
.
flip
(
-
1
)
...
...
@@ -142,10 +154,11 @@ def _compute_resized_output_size(
def
resize_image_tensor
(
image
:
torch
.
Tensor
,
size
:
List
[
int
],
interpolation
:
InterpolationMode
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
max_size
:
Optional
[
int
]
=
None
,
antialias
:
Optional
[
Union
[
str
,
bool
]]
=
"warn"
,
)
->
torch
.
Tensor
:
interpolation
=
_check_interpolation
(
interpolation
)
antialias
=
_check_antialias
(
img
=
image
,
antialias
=
antialias
,
interpolation
=
interpolation
)
assert
not
isinstance
(
antialias
,
str
)
antialias
=
False
if
antialias
is
None
else
antialias
...
...
@@ -189,9 +202,10 @@ def resize_image_tensor(
def
resize_image_pil
(
image
:
PIL
.
Image
.
Image
,
size
:
Union
[
Sequence
[
int
],
int
],
interpolation
:
InterpolationMode
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
max_size
:
Optional
[
int
]
=
None
,
)
->
PIL
.
Image
.
Image
:
interpolation
=
_check_interpolation
(
interpolation
)
size
=
_compute_resized_output_size
(
image
.
size
[::
-
1
],
size
=
size
,
max_size
=
max_size
)
# type: ignore[arg-type]
return
_FP
.
resize
(
image
,
size
,
interpolation
=
pil_modes_mapping
[
interpolation
])
...
...
@@ -228,7 +242,7 @@ def resize_bounding_box(
def
resize_video
(
video
:
torch
.
Tensor
,
size
:
List
[
int
],
interpolation
:
InterpolationMode
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
max_size
:
Optional
[
int
]
=
None
,
antialias
:
Optional
[
Union
[
str
,
bool
]]
=
"warn"
,
)
->
torch
.
Tensor
:
...
...
@@ -238,7 +252,7 @@ def resize_video(
def
resize
(
inpt
:
datapoints
.
InputTypeJIT
,
size
:
List
[
int
],
interpolation
:
InterpolationMode
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
max_size
:
Optional
[
int
]
=
None
,
antialias
:
Optional
[
Union
[
str
,
bool
]]
=
"warn"
,
)
->
datapoints
.
InputTypeJIT
:
...
...
@@ -513,10 +527,12 @@ def affine_image_tensor(
translate
:
List
[
float
],
scale
:
float
,
shear
:
List
[
float
],
interpolation
:
InterpolationMode
=
InterpolationMode
.
NEAREST
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
fill
:
datapoints
.
FillTypeJIT
=
None
,
center
:
Optional
[
List
[
float
]]
=
None
,
)
->
torch
.
Tensor
:
interpolation
=
_check_interpolation
(
interpolation
)
if
image
.
numel
()
==
0
:
return
image
...
...
@@ -563,10 +579,11 @@ def affine_image_pil(
translate
:
List
[
float
],
scale
:
float
,
shear
:
List
[
float
],
interpolation
:
InterpolationMode
=
InterpolationMode
.
NEAREST
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
fill
:
datapoints
.
FillTypeJIT
=
None
,
center
:
Optional
[
List
[
float
]]
=
None
,
)
->
PIL
.
Image
.
Image
:
interpolation
=
_check_interpolation
(
interpolation
)
angle
,
translate
,
shear
,
center
=
_affine_parse_args
(
angle
,
translate
,
scale
,
shear
,
interpolation
,
center
)
# center = (img_size[0] * 0.5 + 0.5, img_size[1] * 0.5 + 0.5)
...
...
@@ -731,7 +748,7 @@ def affine_video(
translate
:
List
[
float
],
scale
:
float
,
shear
:
List
[
float
],
interpolation
:
InterpolationMode
=
InterpolationMode
.
NEAREST
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
fill
:
datapoints
.
FillTypeJIT
=
None
,
center
:
Optional
[
List
[
float
]]
=
None
,
)
->
torch
.
Tensor
:
...
...
@@ -753,7 +770,7 @@ def affine(
translate
:
List
[
float
],
scale
:
float
,
shear
:
List
[
float
],
interpolation
:
InterpolationMode
=
InterpolationMode
.
NEAREST
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
fill
:
datapoints
.
FillTypeJIT
=
None
,
center
:
Optional
[
List
[
float
]]
=
None
,
)
->
datapoints
.
InputTypeJIT
:
...
...
@@ -797,11 +814,13 @@ def affine(
def
rotate_image_tensor
(
image
:
torch
.
Tensor
,
angle
:
float
,
interpolation
:
InterpolationMode
=
InterpolationMode
.
NEAREST
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
expand
:
bool
=
False
,
center
:
Optional
[
List
[
float
]]
=
None
,
fill
:
datapoints
.
FillTypeJIT
=
None
,
)
->
torch
.
Tensor
:
interpolation
=
_check_interpolation
(
interpolation
)
shape
=
image
.
shape
num_channels
,
height
,
width
=
shape
[
-
3
:]
...
...
@@ -840,11 +859,13 @@ def rotate_image_tensor(
def
rotate_image_pil
(
image
:
PIL
.
Image
.
Image
,
angle
:
float
,
interpolation
:
InterpolationMode
=
InterpolationMode
.
NEAREST
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
expand
:
bool
=
False
,
center
:
Optional
[
List
[
float
]]
=
None
,
fill
:
datapoints
.
FillTypeJIT
=
None
,
)
->
PIL
.
Image
.
Image
:
interpolation
=
_check_interpolation
(
interpolation
)
if
center
is
not
None
and
expand
:
warnings
.
warn
(
"The provided center argument has no effect on the result if expand is True"
)
center
=
None
...
...
@@ -910,7 +931,7 @@ def rotate_mask(
def
rotate_video
(
video
:
torch
.
Tensor
,
angle
:
float
,
interpolation
:
InterpolationMode
=
InterpolationMode
.
NEAREST
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
expand
:
bool
=
False
,
center
:
Optional
[
List
[
float
]]
=
None
,
fill
:
datapoints
.
FillTypeJIT
=
None
,
...
...
@@ -921,7 +942,7 @@ def rotate_video(
def
rotate
(
inpt
:
datapoints
.
InputTypeJIT
,
angle
:
float
,
interpolation
:
InterpolationMode
=
InterpolationMode
.
NEAREST
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
expand
:
bool
=
False
,
center
:
Optional
[
List
[
float
]]
=
None
,
fill
:
datapoints
.
FillTypeJIT
=
None
,
...
...
@@ -1281,11 +1302,13 @@ def perspective_image_tensor(
image
:
torch
.
Tensor
,
startpoints
:
Optional
[
List
[
List
[
int
]]],
endpoints
:
Optional
[
List
[
List
[
int
]]],
interpolation
:
InterpolationMode
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
fill
:
datapoints
.
FillTypeJIT
=
None
,
coefficients
:
Optional
[
List
[
float
]]
=
None
,
)
->
torch
.
Tensor
:
perspective_coeffs
=
_perspective_coefficients
(
startpoints
,
endpoints
,
coefficients
)
interpolation
=
_check_interpolation
(
interpolation
)
if
image
.
numel
()
==
0
:
return
image
...
...
@@ -1326,11 +1349,12 @@ def perspective_image_pil(
image
:
PIL
.
Image
.
Image
,
startpoints
:
Optional
[
List
[
List
[
int
]]],
endpoints
:
Optional
[
List
[
List
[
int
]]],
interpolation
:
InterpolationMode
=
InterpolationMode
.
BICUBIC
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BICUBIC
,
fill
:
datapoints
.
FillTypeJIT
=
None
,
coefficients
:
Optional
[
List
[
float
]]
=
None
,
)
->
PIL
.
Image
.
Image
:
perspective_coeffs
=
_perspective_coefficients
(
startpoints
,
endpoints
,
coefficients
)
interpolation
=
_check_interpolation
(
interpolation
)
return
_FP
.
perspective
(
image
,
perspective_coeffs
,
interpolation
=
pil_modes_mapping
[
interpolation
],
fill
=
fill
)
...
...
@@ -1455,7 +1479,7 @@ def perspective_video(
video
:
torch
.
Tensor
,
startpoints
:
Optional
[
List
[
List
[
int
]]],
endpoints
:
Optional
[
List
[
List
[
int
]]],
interpolation
:
InterpolationMode
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
fill
:
datapoints
.
FillTypeJIT
=
None
,
coefficients
:
Optional
[
List
[
float
]]
=
None
,
)
->
torch
.
Tensor
:
...
...
@@ -1468,7 +1492,7 @@ def perspective(
inpt
:
datapoints
.
InputTypeJIT
,
startpoints
:
Optional
[
List
[
List
[
int
]]],
endpoints
:
Optional
[
List
[
List
[
int
]]],
interpolation
:
InterpolationMode
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
fill
:
datapoints
.
FillTypeJIT
=
None
,
coefficients
:
Optional
[
List
[
float
]]
=
None
,
)
->
datapoints
.
InputTypeJIT
:
...
...
@@ -1496,9 +1520,11 @@ def perspective(
def
elastic_image_tensor
(
image
:
torch
.
Tensor
,
displacement
:
torch
.
Tensor
,
interpolation
:
InterpolationMode
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
fill
:
datapoints
.
FillTypeJIT
=
None
,
)
->
torch
.
Tensor
:
interpolation
=
_check_interpolation
(
interpolation
)
if
image
.
numel
()
==
0
:
return
image
...
...
@@ -1537,7 +1563,7 @@ def elastic_image_tensor(
def
elastic_image_pil
(
image
:
PIL
.
Image
.
Image
,
displacement
:
torch
.
Tensor
,
interpolation
:
InterpolationMode
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
fill
:
datapoints
.
FillTypeJIT
=
None
,
)
->
PIL
.
Image
.
Image
:
t_img
=
pil_to_tensor
(
image
)
...
...
@@ -1630,7 +1656,7 @@ def elastic_mask(
def
elastic_video
(
video
:
torch
.
Tensor
,
displacement
:
torch
.
Tensor
,
interpolation
:
InterpolationMode
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
fill
:
datapoints
.
FillTypeJIT
=
None
,
)
->
torch
.
Tensor
:
return
elastic_image_tensor
(
video
,
displacement
,
interpolation
=
interpolation
,
fill
=
fill
)
...
...
@@ -1639,7 +1665,7 @@ def elastic_video(
def
elastic
(
inpt
:
datapoints
.
InputTypeJIT
,
displacement
:
torch
.
Tensor
,
interpolation
:
InterpolationMode
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
fill
:
datapoints
.
FillTypeJIT
=
None
,
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
...
...
@@ -1778,7 +1804,7 @@ def resized_crop_image_tensor(
height
:
int
,
width
:
int
,
size
:
List
[
int
],
interpolation
:
InterpolationMode
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
antialias
:
Optional
[
Union
[
str
,
bool
]]
=
"warn"
,
)
->
torch
.
Tensor
:
image
=
crop_image_tensor
(
image
,
top
,
left
,
height
,
width
)
...
...
@@ -1793,7 +1819,7 @@ def resized_crop_image_pil(
height
:
int
,
width
:
int
,
size
:
List
[
int
],
interpolation
:
InterpolationMode
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
)
->
PIL
.
Image
.
Image
:
image
=
crop_image_pil
(
image
,
top
,
left
,
height
,
width
)
return
resize_image_pil
(
image
,
size
,
interpolation
=
interpolation
)
...
...
@@ -1831,7 +1857,7 @@ def resized_crop_video(
height
:
int
,
width
:
int
,
size
:
List
[
int
],
interpolation
:
InterpolationMode
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
antialias
:
Optional
[
Union
[
str
,
bool
]]
=
"warn"
,
)
->
torch
.
Tensor
:
return
resized_crop_image_tensor
(
...
...
@@ -1846,7 +1872,7 @@ def resized_crop(
height
:
int
,
width
:
int
,
size
:
List
[
int
],
interpolation
:
InterpolationMode
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
antialias
:
Optional
[
Union
[
str
,
bool
]]
=
"warn"
,
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment