Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
641fdd9f
Unverified
Commit
641fdd9f
authored
Aug 09, 2023
by
Philip Meier
Committed by
GitHub
Aug 09, 2023
Browse files
remove custom types defintions from datapoints module (#7814)
parent
6b020798
Changes
21
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
137 additions
and
177 deletions
+137
-177
torchvision/datapoints/__init__.py
torchvision/datapoints/__init__.py
+3
-3
torchvision/datapoints/_datapoint.py
torchvision/datapoints/_datapoint.py
+1
-8
torchvision/datapoints/_image.py
torchvision/datapoints/_image.py
+0
-6
torchvision/datapoints/_video.py
torchvision/datapoints/_video.py
+0
-6
torchvision/prototype/transforms/_augment.py
torchvision/prototype/transforms/_augment.py
+5
-5
torchvision/prototype/transforms/_geometry.py
torchvision/prototype/transforms/_geometry.py
+2
-2
torchvision/prototype/transforms/_misc.py
torchvision/prototype/transforms/_misc.py
+2
-6
torchvision/transforms/v2/_auto_augment.py
torchvision/transforms/v2/_auto_augment.py
+14
-10
torchvision/transforms/v2/_color.py
torchvision/transforms/v2/_color.py
+1
-3
torchvision/transforms/v2/_geometry.py
torchvision/transforms/v2/_geometry.py
+8
-10
torchvision/transforms/v2/_misc.py
torchvision/transforms/v2/_misc.py
+1
-3
torchvision/transforms/v2/_temporal.py
torchvision/transforms/v2/_temporal.py
+1
-2
torchvision/transforms/v2/_utils.py
torchvision/transforms/v2/_utils.py
+2
-3
torchvision/transforms/v2/functional/_augment.py
torchvision/transforms/v2/functional/_augment.py
+2
-4
torchvision/transforms/v2/functional/_color.py
torchvision/transforms/v2/functional/_color.py
+14
-16
torchvision/transforms/v2/functional/_deprecated.py
torchvision/transforms/v2/functional/_deprecated.py
+2
-3
torchvision/transforms/v2/functional/_geometry.py
torchvision/transforms/v2/functional/_geometry.py
+64
-66
torchvision/transforms/v2/functional/_meta.py
torchvision/transforms/v2/functional/_meta.py
+9
-9
torchvision/transforms/v2/functional/_misc.py
torchvision/transforms/v2/functional/_misc.py
+5
-11
torchvision/transforms/v2/functional/_temporal.py
torchvision/transforms/v2/functional/_temporal.py
+1
-1
No files found.
torchvision/datapoints/__init__.py
View file @
641fdd9f
from
torchvision
import
_BETA_TRANSFORMS_WARNING
,
_WARN_ABOUT_BETA_TRANSFORMS
from
torchvision
import
_BETA_TRANSFORMS_WARNING
,
_WARN_ABOUT_BETA_TRANSFORMS
from
._bounding_box
import
BoundingBoxes
,
BoundingBoxFormat
from
._bounding_box
import
BoundingBoxes
,
BoundingBoxFormat
from
._datapoint
import
_FillType
,
_FillTypeJIT
,
_InputType
,
_InputTypeJIT
,
Datapoint
from
._datapoint
import
Datapoint
from
._image
import
_ImageType
,
_ImageTypeJIT
,
_TensorImageType
,
_TensorImageTypeJIT
,
Image
from
._image
import
Image
from
._mask
import
Mask
from
._mask
import
Mask
from
._video
import
_TensorVideoType
,
_TensorVideoTypeJIT
,
_VideoType
,
_VideoTypeJIT
,
Video
from
._video
import
Video
if
_WARN_ABOUT_BETA_TRANSFORMS
:
if
_WARN_ABOUT_BETA_TRANSFORMS
:
import
warnings
import
warnings
...
...
torchvision/datapoints/_datapoint.py
View file @
641fdd9f
from
__future__
import
annotations
from
__future__
import
annotations
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Mapping
,
Optional
,
Sequence
,
Tuple
,
Type
,
TypeVar
,
Union
from
typing
import
Any
,
Callable
,
Dict
,
Mapping
,
Optional
,
Sequence
,
Tuple
,
Type
,
TypeVar
,
Union
import
PIL.Image
import
torch
import
torch
from
torch._C
import
DisableTorchFunctionSubclass
from
torch._C
import
DisableTorchFunctionSubclass
from
torch.types
import
_device
,
_dtype
,
_size
from
torch.types
import
_device
,
_dtype
,
_size
D
=
TypeVar
(
"D"
,
bound
=
"Datapoint"
)
D
=
TypeVar
(
"D"
,
bound
=
"Datapoint"
)
_FillType
=
Union
[
int
,
float
,
Sequence
[
int
],
Sequence
[
float
],
None
]
_FillTypeJIT
=
Optional
[
List
[
float
]]
class
Datapoint
(
torch
.
Tensor
):
class
Datapoint
(
torch
.
Tensor
):
...
@@ -132,7 +129,3 @@ class Datapoint(torch.Tensor):
...
@@ -132,7 +129,3 @@ class Datapoint(torch.Tensor):
# `BoundingBoxes.format` and `BoundingBoxes.canvas_size`, which are immutable and thus implicitly deep-copied by
# `BoundingBoxes.format` and `BoundingBoxes.canvas_size`, which are immutable and thus implicitly deep-copied by
# `BoundingBoxes.clone()`.
# `BoundingBoxes.clone()`.
return
self
.
detach
().
clone
().
requires_grad_
(
self
.
requires_grad
)
# type: ignore[return-value]
return
self
.
detach
().
clone
().
requires_grad_
(
self
.
requires_grad
)
# type: ignore[return-value]
_InputType
=
Union
[
torch
.
Tensor
,
PIL
.
Image
.
Image
,
Datapoint
]
_InputTypeJIT
=
torch
.
Tensor
torchvision/datapoints/_image.py
View file @
641fdd9f
...
@@ -45,9 +45,3 @@ class Image(Datapoint):
...
@@ -45,9 +45,3 @@ class Image(Datapoint):
def
__repr__
(
self
,
*
,
tensor_contents
:
Any
=
None
)
->
str
:
# type: ignore[override]
def
__repr__
(
self
,
*
,
tensor_contents
:
Any
=
None
)
->
str
:
# type: ignore[override]
return
self
.
_make_repr
()
return
self
.
_make_repr
()
_ImageType
=
Union
[
torch
.
Tensor
,
PIL
.
Image
.
Image
,
Image
]
_ImageTypeJIT
=
torch
.
Tensor
_TensorImageType
=
Union
[
torch
.
Tensor
,
Image
]
_TensorImageTypeJIT
=
torch
.
Tensor
torchvision/datapoints/_video.py
View file @
641fdd9f
...
@@ -35,9 +35,3 @@ class Video(Datapoint):
...
@@ -35,9 +35,3 @@ class Video(Datapoint):
def
__repr__
(
self
,
*
,
tensor_contents
:
Any
=
None
)
->
str
:
# type: ignore[override]
def
__repr__
(
self
,
*
,
tensor_contents
:
Any
=
None
)
->
str
:
# type: ignore[override]
return
self
.
_make_repr
()
return
self
.
_make_repr
()
_VideoType
=
Union
[
torch
.
Tensor
,
Video
]
_VideoTypeJIT
=
torch
.
Tensor
_TensorVideoType
=
Union
[
torch
.
Tensor
,
Video
]
_TensorVideoTypeJIT
=
torch
.
Tensor
torchvision/prototype/transforms/_augment.py
View file @
641fdd9f
...
@@ -26,15 +26,15 @@ class SimpleCopyPaste(Transform):
...
@@ -26,15 +26,15 @@ class SimpleCopyPaste(Transform):
def
_copy_paste
(
def
_copy_paste
(
self
,
self
,
image
:
datapoints
.
_Tensor
Image
Type
,
image
:
Union
[
torch
.
Tensor
,
datapoints
.
Image
]
,
target
:
Dict
[
str
,
Any
],
target
:
Dict
[
str
,
Any
],
paste_image
:
datapoints
.
_Tensor
Image
Type
,
paste_image
:
Union
[
torch
.
Tensor
,
datapoints
.
Image
]
,
paste_target
:
Dict
[
str
,
Any
],
paste_target
:
Dict
[
str
,
Any
],
random_selection
:
torch
.
Tensor
,
random_selection
:
torch
.
Tensor
,
blending
:
bool
,
blending
:
bool
,
resize_interpolation
:
F
.
InterpolationMode
,
resize_interpolation
:
F
.
InterpolationMode
,
antialias
:
Optional
[
bool
],
antialias
:
Optional
[
bool
],
)
->
Tuple
[
datapoints
.
_TensorImageType
,
Dict
[
str
,
Any
]]:
)
->
Tuple
[
torch
.
Tensor
,
Dict
[
str
,
Any
]]:
paste_masks
=
paste_target
[
"masks"
].
wrap_like
(
paste_target
[
"masks"
],
paste_target
[
"masks"
][
random_selection
])
paste_masks
=
paste_target
[
"masks"
].
wrap_like
(
paste_target
[
"masks"
],
paste_target
[
"masks"
][
random_selection
])
paste_boxes
=
paste_target
[
"boxes"
].
wrap_like
(
paste_target
[
"boxes"
],
paste_target
[
"boxes"
][
random_selection
])
paste_boxes
=
paste_target
[
"boxes"
].
wrap_like
(
paste_target
[
"boxes"
],
paste_target
[
"boxes"
][
random_selection
])
...
@@ -106,7 +106,7 @@ class SimpleCopyPaste(Transform):
...
@@ -106,7 +106,7 @@ class SimpleCopyPaste(Transform):
def
_extract_image_targets
(
def
_extract_image_targets
(
self
,
flat_sample
:
List
[
Any
]
self
,
flat_sample
:
List
[
Any
]
)
->
Tuple
[
List
[
datapoints
.
_Tensor
Image
Type
],
List
[
Dict
[
str
,
Any
]]]:
)
->
Tuple
[
List
[
Union
[
torch
.
Tensor
,
datapoints
.
Image
]
],
List
[
Dict
[
str
,
Any
]]]:
# fetch all images, bboxes, masks and labels from unstructured input
# fetch all images, bboxes, masks and labels from unstructured input
# with List[image], List[BoundingBoxes], List[Mask], List[Label]
# with List[image], List[BoundingBoxes], List[Mask], List[Label]
images
,
bboxes
,
masks
,
labels
=
[],
[],
[],
[]
images
,
bboxes
,
masks
,
labels
=
[],
[],
[],
[]
...
@@ -137,7 +137,7 @@ class SimpleCopyPaste(Transform):
...
@@ -137,7 +137,7 @@ class SimpleCopyPaste(Transform):
def
_insert_outputs
(
def
_insert_outputs
(
self
,
self
,
flat_sample
:
List
[
Any
],
flat_sample
:
List
[
Any
],
output_images
:
List
[
datapoints
.
_TensorImageType
],
output_images
:
List
[
torch
.
Tensor
],
output_targets
:
List
[
Dict
[
str
,
Any
]],
output_targets
:
List
[
Dict
[
str
,
Any
]],
)
->
None
:
)
->
None
:
c0
,
c1
,
c2
,
c3
=
0
,
0
,
0
,
0
c0
,
c1
,
c2
,
c3
=
0
,
0
,
0
,
0
...
...
torchvision/prototype/transforms/_geometry.py
View file @
641fdd9f
...
@@ -6,7 +6,7 @@ import torch
...
@@ -6,7 +6,7 @@ import torch
from
torchvision
import
datapoints
from
torchvision
import
datapoints
from
torchvision.prototype.datapoints
import
Label
,
OneHotLabel
from
torchvision.prototype.datapoints
import
Label
,
OneHotLabel
from
torchvision.transforms.v2
import
functional
as
F
,
Transform
from
torchvision.transforms.v2
import
functional
as
F
,
Transform
from
torchvision.transforms.v2._utils
import
_get_fill
,
_setup_fill_arg
,
_setup_size
from
torchvision.transforms.v2._utils
import
_FillType
,
_get_fill
,
_setup_fill_arg
,
_setup_size
from
torchvision.transforms.v2.utils
import
get_bounding_boxes
,
has_any
,
is_simple_tensor
,
query_size
from
torchvision.transforms.v2.utils
import
get_bounding_boxes
,
has_any
,
is_simple_tensor
,
query_size
...
@@ -14,7 +14,7 @@ class FixedSizeCrop(Transform):
...
@@ -14,7 +14,7 @@ class FixedSizeCrop(Transform):
def
__init__
(
def
__init__
(
self
,
self
,
size
:
Union
[
int
,
Sequence
[
int
]],
size
:
Union
[
int
,
Sequence
[
int
]],
fill
:
Union
[
datapoints
.
_FillType
,
Dict
[
Union
[
Type
,
str
],
datapoints
.
_FillType
]]
=
0
,
fill
:
Union
[
_FillType
,
Dict
[
Union
[
Type
,
str
],
_FillType
]]
=
0
,
padding_mode
:
str
=
"constant"
,
padding_mode
:
str
=
"constant"
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
...
...
torchvision/prototype/transforms/_misc.py
View file @
641fdd9f
...
@@ -39,9 +39,7 @@ class PermuteDimensions(Transform):
...
@@ -39,9 +39,7 @@ class PermuteDimensions(Transform):
)
)
self
.
dims
=
dims
self
.
dims
=
dims
def
_transform
(
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
torch
.
Tensor
:
self
,
inpt
:
Union
[
datapoints
.
_TensorImageType
,
datapoints
.
_TensorVideoType
],
params
:
Dict
[
str
,
Any
]
)
->
torch
.
Tensor
:
dims
=
self
.
dims
[
type
(
inpt
)]
dims
=
self
.
dims
[
type
(
inpt
)]
if
dims
is
None
:
if
dims
is
None
:
return
inpt
.
as_subclass
(
torch
.
Tensor
)
return
inpt
.
as_subclass
(
torch
.
Tensor
)
...
@@ -63,9 +61,7 @@ class TransposeDimensions(Transform):
...
@@ -63,9 +61,7 @@ class TransposeDimensions(Transform):
)
)
self
.
dims
=
dims
self
.
dims
=
dims
def
_transform
(
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
torch
.
Tensor
:
self
,
inpt
:
Union
[
datapoints
.
_TensorImageType
,
datapoints
.
_TensorVideoType
],
params
:
Dict
[
str
,
Any
]
)
->
torch
.
Tensor
:
dims
=
self
.
dims
[
type
(
inpt
)]
dims
=
self
.
dims
[
type
(
inpt
)]
if
dims
is
None
:
if
dims
is
None
:
return
inpt
.
as_subclass
(
torch
.
Tensor
)
return
inpt
.
as_subclass
(
torch
.
Tensor
)
...
...
torchvision/transforms/v2/_auto_augment.py
View file @
641fdd9f
...
@@ -10,17 +10,21 @@ from torchvision.transforms import _functional_tensor as _FT
...
@@ -10,17 +10,21 @@ from torchvision.transforms import _functional_tensor as _FT
from
torchvision.transforms.v2
import
AutoAugmentPolicy
,
functional
as
F
,
InterpolationMode
,
Transform
from
torchvision.transforms.v2
import
AutoAugmentPolicy
,
functional
as
F
,
InterpolationMode
,
Transform
from
torchvision.transforms.v2.functional._geometry
import
_check_interpolation
from
torchvision.transforms.v2.functional._geometry
import
_check_interpolation
from
torchvision.transforms.v2.functional._meta
import
get_size
from
torchvision.transforms.v2.functional._meta
import
get_size
from
torchvision.transforms.v2.functional._utils
import
_FillType
,
_FillTypeJIT
from
._utils
import
_get_fill
,
_setup_fill_arg
from
._utils
import
_get_fill
,
_setup_fill_arg
from
.utils
import
check_type
,
is_simple_tensor
from
.utils
import
check_type
,
is_simple_tensor
ImageOrVideo
=
Union
[
torch
.
Tensor
,
PIL
.
Image
.
Image
,
datapoints
.
Image
,
datapoints
.
Video
]
class
_AutoAugmentBase
(
Transform
):
class
_AutoAugmentBase
(
Transform
):
def
__init__
(
def
__init__
(
self
,
self
,
*
,
*
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
fill
:
Union
[
datapoints
.
_FillType
,
Dict
[
Union
[
Type
,
str
],
datapoints
.
_FillType
]]
=
None
,
fill
:
Union
[
_FillType
,
Dict
[
Union
[
Type
,
str
],
_FillType
]]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
interpolation
=
_check_interpolation
(
interpolation
)
self
.
interpolation
=
_check_interpolation
(
interpolation
)
...
@@ -35,7 +39,7 @@ class _AutoAugmentBase(Transform):
...
@@ -35,7 +39,7 @@ class _AutoAugmentBase(Transform):
self
,
self
,
inputs
:
Any
,
inputs
:
Any
,
unsupported_types
:
Tuple
[
Type
,
...]
=
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
),
unsupported_types
:
Tuple
[
Type
,
...]
=
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
),
)
->
Tuple
[
Tuple
[
List
[
Any
],
TreeSpec
,
int
],
Union
[
datapoints
.
_ImageType
,
datapoints
.
_VideoType
]
]:
)
->
Tuple
[
Tuple
[
List
[
Any
],
TreeSpec
,
int
],
ImageOrVideo
]:
flat_inputs
,
spec
=
tree_flatten
(
inputs
if
len
(
inputs
)
>
1
else
inputs
[
0
])
flat_inputs
,
spec
=
tree_flatten
(
inputs
if
len
(
inputs
)
>
1
else
inputs
[
0
])
needs_transform_list
=
self
.
_needs_transform_list
(
flat_inputs
)
needs_transform_list
=
self
.
_needs_transform_list
(
flat_inputs
)
...
@@ -68,7 +72,7 @@ class _AutoAugmentBase(Transform):
...
@@ -68,7 +72,7 @@ class _AutoAugmentBase(Transform):
def
_unflatten_and_insert_image_or_video
(
def
_unflatten_and_insert_image_or_video
(
self
,
self
,
flat_inputs_with_spec
:
Tuple
[
List
[
Any
],
TreeSpec
,
int
],
flat_inputs_with_spec
:
Tuple
[
List
[
Any
],
TreeSpec
,
int
],
image_or_video
:
Union
[
datapoints
.
_ImageType
,
datapoints
.
_VideoType
]
,
image_or_video
:
ImageOrVideo
,
)
->
Any
:
)
->
Any
:
flat_inputs
,
spec
,
idx
=
flat_inputs_with_spec
flat_inputs
,
spec
,
idx
=
flat_inputs_with_spec
flat_inputs
[
idx
]
=
image_or_video
flat_inputs
[
idx
]
=
image_or_video
...
@@ -76,12 +80,12 @@ class _AutoAugmentBase(Transform):
...
@@ -76,12 +80,12 @@ class _AutoAugmentBase(Transform):
def
_apply_image_or_video_transform
(
def
_apply_image_or_video_transform
(
self
,
self
,
image
:
Union
[
datapoints
.
_ImageType
,
datapoints
.
_VideoType
]
,
image
:
ImageOrVideo
,
transform_id
:
str
,
transform_id
:
str
,
magnitude
:
float
,
magnitude
:
float
,
interpolation
:
Union
[
InterpolationMode
,
int
],
interpolation
:
Union
[
InterpolationMode
,
int
],
fill
:
Dict
[
Union
[
Type
,
str
],
datapoints
.
_FillTypeJIT
],
fill
:
Dict
[
Union
[
Type
,
str
],
_FillTypeJIT
],
)
->
Union
[
datapoints
.
_ImageType
,
datapoints
.
_VideoType
]
:
)
->
ImageOrVideo
:
fill_
=
_get_fill
(
fill
,
type
(
image
))
fill_
=
_get_fill
(
fill
,
type
(
image
))
if
transform_id
==
"Identity"
:
if
transform_id
==
"Identity"
:
...
@@ -214,7 +218,7 @@ class AutoAugment(_AutoAugmentBase):
...
@@ -214,7 +218,7 @@ class AutoAugment(_AutoAugmentBase):
self
,
self
,
policy
:
AutoAugmentPolicy
=
AutoAugmentPolicy
.
IMAGENET
,
policy
:
AutoAugmentPolicy
=
AutoAugmentPolicy
.
IMAGENET
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
fill
:
Union
[
datapoints
.
_FillType
,
Dict
[
Union
[
Type
,
str
],
datapoints
.
_FillType
]]
=
None
,
fill
:
Union
[
_FillType
,
Dict
[
Union
[
Type
,
str
],
_FillType
]]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
(
interpolation
=
interpolation
,
fill
=
fill
)
super
().
__init__
(
interpolation
=
interpolation
,
fill
=
fill
)
self
.
policy
=
policy
self
.
policy
=
policy
...
@@ -394,7 +398,7 @@ class RandAugment(_AutoAugmentBase):
...
@@ -394,7 +398,7 @@ class RandAugment(_AutoAugmentBase):
magnitude
:
int
=
9
,
magnitude
:
int
=
9
,
num_magnitude_bins
:
int
=
31
,
num_magnitude_bins
:
int
=
31
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
fill
:
Union
[
datapoints
.
_FillType
,
Dict
[
Union
[
Type
,
str
],
datapoints
.
_FillType
]]
=
None
,
fill
:
Union
[
_FillType
,
Dict
[
Union
[
Type
,
str
],
_FillType
]]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
(
interpolation
=
interpolation
,
fill
=
fill
)
super
().
__init__
(
interpolation
=
interpolation
,
fill
=
fill
)
self
.
num_ops
=
num_ops
self
.
num_ops
=
num_ops
...
@@ -467,7 +471,7 @@ class TrivialAugmentWide(_AutoAugmentBase):
...
@@ -467,7 +471,7 @@ class TrivialAugmentWide(_AutoAugmentBase):
self
,
self
,
num_magnitude_bins
:
int
=
31
,
num_magnitude_bins
:
int
=
31
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
fill
:
Union
[
datapoints
.
_FillType
,
Dict
[
Union
[
Type
,
str
],
datapoints
.
_FillType
]]
=
None
,
fill
:
Union
[
_FillType
,
Dict
[
Union
[
Type
,
str
],
_FillType
]]
=
None
,
):
):
super
().
__init__
(
interpolation
=
interpolation
,
fill
=
fill
)
super
().
__init__
(
interpolation
=
interpolation
,
fill
=
fill
)
self
.
num_magnitude_bins
=
num_magnitude_bins
self
.
num_magnitude_bins
=
num_magnitude_bins
...
@@ -550,7 +554,7 @@ class AugMix(_AutoAugmentBase):
...
@@ -550,7 +554,7 @@ class AugMix(_AutoAugmentBase):
alpha
:
float
=
1.0
,
alpha
:
float
=
1.0
,
all_ops
:
bool
=
True
,
all_ops
:
bool
=
True
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
fill
:
Union
[
datapoints
.
_FillType
,
Dict
[
Union
[
Type
,
str
],
datapoints
.
_FillType
]]
=
None
,
fill
:
Union
[
_FillType
,
Dict
[
Union
[
Type
,
str
],
_FillType
]]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
(
interpolation
=
interpolation
,
fill
=
fill
)
super
().
__init__
(
interpolation
=
interpolation
,
fill
=
fill
)
self
.
_PARAMETER_MAX
=
10
self
.
_PARAMETER_MAX
=
10
...
...
torchvision/transforms/v2/_color.py
View file @
641fdd9f
...
@@ -261,9 +261,7 @@ class RandomPhotometricDistort(Transform):
...
@@ -261,9 +261,7 @@ class RandomPhotometricDistort(Transform):
params
[
"channel_permutation"
]
=
torch
.
randperm
(
num_channels
)
if
torch
.
rand
(
1
)
<
self
.
p
else
None
params
[
"channel_permutation"
]
=
torch
.
randperm
(
num_channels
)
if
torch
.
rand
(
1
)
<
self
.
p
else
None
return
params
return
params
def
_transform
(
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
self
,
inpt
:
Union
[
datapoints
.
_ImageType
,
datapoints
.
_VideoType
],
params
:
Dict
[
str
,
Any
]
)
->
Union
[
datapoints
.
_ImageType
,
datapoints
.
_VideoType
]:
if
params
[
"brightness_factor"
]
is
not
None
:
if
params
[
"brightness_factor"
]
is
not
None
:
inpt
=
F
.
adjust_brightness
(
inpt
,
brightness_factor
=
params
[
"brightness_factor"
])
inpt
=
F
.
adjust_brightness
(
inpt
,
brightness_factor
=
params
[
"brightness_factor"
])
if
params
[
"contrast_factor"
]
is
not
None
and
params
[
"contrast_before"
]:
if
params
[
"contrast_factor"
]
is
not
None
and
params
[
"contrast_before"
]:
...
...
torchvision/transforms/v2/_geometry.py
View file @
641fdd9f
...
@@ -11,6 +11,7 @@ from torchvision.ops.boxes import box_iou
...
@@ -11,6 +11,7 @@ from torchvision.ops.boxes import box_iou
from
torchvision.transforms.functional
import
_get_perspective_coeffs
from
torchvision.transforms.functional
import
_get_perspective_coeffs
from
torchvision.transforms.v2
import
functional
as
F
,
InterpolationMode
,
Transform
from
torchvision.transforms.v2
import
functional
as
F
,
InterpolationMode
,
Transform
from
torchvision.transforms.v2.functional._geometry
import
_check_interpolation
from
torchvision.transforms.v2.functional._geometry
import
_check_interpolation
from
torchvision.transforms.v2.functional._utils
import
_FillType
from
._transform
import
_RandomApplyTransform
from
._transform
import
_RandomApplyTransform
from
._utils
import
(
from
._utils
import
(
...
@@ -311,9 +312,6 @@ class RandomResizedCrop(Transform):
...
@@ -311,9 +312,6 @@ class RandomResizedCrop(Transform):
)
)
ImageOrVideoTypeJIT
=
Union
[
datapoints
.
_ImageTypeJIT
,
datapoints
.
_VideoTypeJIT
]
class
FiveCrop
(
Transform
):
class
FiveCrop
(
Transform
):
"""[BETA] Crop the image or video into four corners and the central crop.
"""[BETA] Crop the image or video into four corners and the central crop.
...
@@ -459,7 +457,7 @@ class Pad(Transform):
...
@@ -459,7 +457,7 @@ class Pad(Transform):
def
__init__
(
def
__init__
(
self
,
self
,
padding
:
Union
[
int
,
Sequence
[
int
]],
padding
:
Union
[
int
,
Sequence
[
int
]],
fill
:
Union
[
datapoints
.
_FillType
,
Dict
[
Union
[
Type
,
str
],
datapoints
.
_FillType
]]
=
0
,
fill
:
Union
[
_FillType
,
Dict
[
Union
[
Type
,
str
],
_FillType
]]
=
0
,
padding_mode
:
Literal
[
"constant"
,
"edge"
,
"reflect"
,
"symmetric"
]
=
"constant"
,
padding_mode
:
Literal
[
"constant"
,
"edge"
,
"reflect"
,
"symmetric"
]
=
"constant"
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
...
@@ -514,7 +512,7 @@ class RandomZoomOut(_RandomApplyTransform):
...
@@ -514,7 +512,7 @@ class RandomZoomOut(_RandomApplyTransform):
def
__init__
(
def
__init__
(
self
,
self
,
fill
:
Union
[
datapoints
.
_FillType
,
Dict
[
Union
[
Type
,
str
],
datapoints
.
_FillType
]]
=
0
,
fill
:
Union
[
_FillType
,
Dict
[
Union
[
Type
,
str
],
_FillType
]]
=
0
,
side_range
:
Sequence
[
float
]
=
(
1.0
,
4.0
),
side_range
:
Sequence
[
float
]
=
(
1.0
,
4.0
),
p
:
float
=
0.5
,
p
:
float
=
0.5
,
)
->
None
:
)
->
None
:
...
@@ -592,7 +590,7 @@ class RandomRotation(Transform):
...
@@ -592,7 +590,7 @@ class RandomRotation(Transform):
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
expand
:
bool
=
False
,
expand
:
bool
=
False
,
center
:
Optional
[
List
[
float
]]
=
None
,
center
:
Optional
[
List
[
float
]]
=
None
,
fill
:
Union
[
datapoints
.
_FillType
,
Dict
[
Union
[
Type
,
str
],
datapoints
.
_FillType
]]
=
0
,
fill
:
Union
[
_FillType
,
Dict
[
Union
[
Type
,
str
],
_FillType
]]
=
0
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
degrees
=
_setup_angle
(
degrees
,
name
=
"degrees"
,
req_sizes
=
(
2
,))
self
.
degrees
=
_setup_angle
(
degrees
,
name
=
"degrees"
,
req_sizes
=
(
2
,))
...
@@ -674,7 +672,7 @@ class RandomAffine(Transform):
...
@@ -674,7 +672,7 @@ class RandomAffine(Transform):
scale
:
Optional
[
Sequence
[
float
]]
=
None
,
scale
:
Optional
[
Sequence
[
float
]]
=
None
,
shear
:
Optional
[
Union
[
int
,
float
,
Sequence
[
float
]]]
=
None
,
shear
:
Optional
[
Union
[
int
,
float
,
Sequence
[
float
]]]
=
None
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
fill
:
Union
[
datapoints
.
_FillType
,
Dict
[
Union
[
Type
,
str
],
datapoints
.
_FillType
]]
=
0
,
fill
:
Union
[
_FillType
,
Dict
[
Union
[
Type
,
str
],
_FillType
]]
=
0
,
center
:
Optional
[
List
[
float
]]
=
None
,
center
:
Optional
[
List
[
float
]]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
...
@@ -812,7 +810,7 @@ class RandomCrop(Transform):
...
@@ -812,7 +810,7 @@ class RandomCrop(Transform):
size
:
Union
[
int
,
Sequence
[
int
]],
size
:
Union
[
int
,
Sequence
[
int
]],
padding
:
Optional
[
Union
[
int
,
Sequence
[
int
]]]
=
None
,
padding
:
Optional
[
Union
[
int
,
Sequence
[
int
]]]
=
None
,
pad_if_needed
:
bool
=
False
,
pad_if_needed
:
bool
=
False
,
fill
:
Union
[
datapoints
.
_FillType
,
Dict
[
Union
[
Type
,
str
],
datapoints
.
_FillType
]]
=
0
,
fill
:
Union
[
_FillType
,
Dict
[
Union
[
Type
,
str
],
_FillType
]]
=
0
,
padding_mode
:
Literal
[
"constant"
,
"edge"
,
"reflect"
,
"symmetric"
]
=
"constant"
,
padding_mode
:
Literal
[
"constant"
,
"edge"
,
"reflect"
,
"symmetric"
]
=
"constant"
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
...
@@ -931,7 +929,7 @@ class RandomPerspective(_RandomApplyTransform):
...
@@ -931,7 +929,7 @@ class RandomPerspective(_RandomApplyTransform):
distortion_scale
:
float
=
0.5
,
distortion_scale
:
float
=
0.5
,
p
:
float
=
0.5
,
p
:
float
=
0.5
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
fill
:
Union
[
datapoints
.
_FillType
,
Dict
[
Union
[
Type
,
str
],
datapoints
.
_FillType
]]
=
0
,
fill
:
Union
[
_FillType
,
Dict
[
Union
[
Type
,
str
],
_FillType
]]
=
0
,
)
->
None
:
)
->
None
:
super
().
__init__
(
p
=
p
)
super
().
__init__
(
p
=
p
)
...
@@ -1033,7 +1031,7 @@ class ElasticTransform(Transform):
...
@@ -1033,7 +1031,7 @@ class ElasticTransform(Transform):
alpha
:
Union
[
float
,
Sequence
[
float
]]
=
50.0
,
alpha
:
Union
[
float
,
Sequence
[
float
]]
=
50.0
,
sigma
:
Union
[
float
,
Sequence
[
float
]]
=
5.0
,
sigma
:
Union
[
float
,
Sequence
[
float
]]
=
5.0
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
fill
:
Union
[
datapoints
.
_FillType
,
Dict
[
Union
[
Type
,
str
],
datapoints
.
_FillType
]]
=
0
,
fill
:
Union
[
_FillType
,
Dict
[
Union
[
Type
,
str
],
_FillType
]]
=
0
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
alpha
=
_setup_float_or_seq
(
alpha
,
"alpha"
,
2
)
self
.
alpha
=
_setup_float_or_seq
(
alpha
,
"alpha"
,
2
)
...
...
torchvision/transforms/v2/_misc.py
View file @
641fdd9f
...
@@ -169,9 +169,7 @@ class Normalize(Transform):
...
@@ -169,9 +169,7 @@ class Normalize(Transform):
if
has_any
(
sample
,
PIL
.
Image
.
Image
):
if
has_any
(
sample
,
PIL
.
Image
.
Image
):
raise
TypeError
(
f
"
{
type
(
self
).
__name__
}
() does not support PIL images."
)
raise
TypeError
(
f
"
{
type
(
self
).
__name__
}
() does not support PIL images."
)
def
_transform
(
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
self
,
inpt
:
Union
[
datapoints
.
_TensorImageType
,
datapoints
.
_TensorVideoType
],
params
:
Dict
[
str
,
Any
]
)
->
Any
:
return
F
.
normalize
(
inpt
,
mean
=
self
.
mean
,
std
=
self
.
std
,
inplace
=
self
.
inplace
)
return
F
.
normalize
(
inpt
,
mean
=
self
.
mean
,
std
=
self
.
std
,
inplace
=
self
.
inplace
)
...
...
torchvision/transforms/v2/_temporal.py
View file @
641fdd9f
from
typing
import
Any
,
Dict
from
typing
import
Any
,
Dict
import
torch
import
torch
from
torchvision
import
datapoints
from
torchvision.transforms.v2
import
functional
as
F
,
Transform
from
torchvision.transforms.v2
import
functional
as
F
,
Transform
...
@@ -25,5 +24,5 @@ class UniformTemporalSubsample(Transform):
...
@@ -25,5 +24,5 @@ class UniformTemporalSubsample(Transform):
super
().
__init__
()
super
().
__init__
()
self
.
num_samples
=
num_samples
self
.
num_samples
=
num_samples
def
_transform
(
self
,
inpt
:
datapoints
.
_VideoType
,
params
:
Dict
[
str
,
Any
])
->
datapoints
.
_VideoType
:
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
return
F
.
uniform_temporal_subsample
(
inpt
,
self
.
num_samples
)
return
F
.
uniform_temporal_subsample
(
inpt
,
self
.
num_samples
)
torchvision/transforms/v2/_utils.py
View file @
641fdd9f
...
@@ -5,9 +5,8 @@ from typing import Any, Callable, Dict, Literal, Optional, Sequence, Type, Union
...
@@ -5,9 +5,8 @@ from typing import Any, Callable, Dict, Literal, Optional, Sequence, Type, Union
import
torch
import
torch
from
torchvision
import
datapoints
from
torchvision.datapoints._datapoint
import
_FillType
,
_FillTypeJIT
from
torchvision.transforms.transforms
import
_check_sequence_input
,
_setup_angle
,
_setup_size
# noqa: F401
from
torchvision.transforms.transforms
import
_check_sequence_input
,
_setup_angle
,
_setup_size
# noqa: F401
from
torchvision.transforms.v2.functional._utils
import
_FillType
,
_FillTypeJIT
def
_setup_float_or_seq
(
arg
:
Union
[
float
,
Sequence
[
float
]],
name
:
str
,
req_size
:
int
=
2
)
->
Sequence
[
float
]:
def
_setup_float_or_seq
(
arg
:
Union
[
float
,
Sequence
[
float
]],
name
:
str
,
req_size
:
int
=
2
)
->
Sequence
[
float
]:
...
@@ -36,7 +35,7 @@ def _check_fill_arg(fill: Union[_FillType, Dict[Union[Type, str], _FillType]]) -
...
@@ -36,7 +35,7 @@ def _check_fill_arg(fill: Union[_FillType, Dict[Union[Type, str], _FillType]]) -
raise
TypeError
(
"Got inappropriate fill arg, only Numbers, tuples, lists and dicts are allowed."
)
raise
TypeError
(
"Got inappropriate fill arg, only Numbers, tuples, lists and dicts are allowed."
)
def
_convert_fill_arg
(
fill
:
datapoints
.
_FillType
)
->
datapoints
.
_FillTypeJIT
:
def
_convert_fill_arg
(
fill
:
_FillType
)
->
_FillTypeJIT
:
# Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517
# Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517
# So, we can't reassign fill to 0
# So, we can't reassign fill to 0
# if fill is None:
# if fill is None:
...
...
torchvision/transforms/v2/functional/_augment.py
View file @
641fdd9f
from
typing
import
Union
import
PIL.Image
import
PIL.Image
import
torch
import
torch
...
@@ -12,14 +10,14 @@ from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_inter
...
@@ -12,14 +10,14 @@ from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_inter
@
_register_explicit_noop
(
datapoints
.
Mask
,
datapoints
.
BoundingBoxes
,
warn_passthrough
=
True
)
@
_register_explicit_noop
(
datapoints
.
Mask
,
datapoints
.
BoundingBoxes
,
warn_passthrough
=
True
)
def
erase
(
def
erase
(
inpt
:
Union
[
datapoints
.
_ImageTypeJIT
,
datapoints
.
_VideoTypeJIT
]
,
inpt
:
torch
.
Tensor
,
i
:
int
,
i
:
int
,
j
:
int
,
j
:
int
,
h
:
int
,
h
:
int
,
w
:
int
,
w
:
int
,
v
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
inplace
:
bool
=
False
,
inplace
:
bool
=
False
,
)
->
Union
[
datapoints
.
_ImageTypeJIT
,
datapoints
.
_VideoTypeJIT
]
:
)
->
torch
.
Tensor
:
if
torch
.
jit
.
is_scripting
():
if
torch
.
jit
.
is_scripting
():
return
erase_image_tensor
(
inpt
,
i
=
i
,
j
=
j
,
h
=
h
,
w
=
w
,
v
=
v
,
inplace
=
inplace
)
return
erase_image_tensor
(
inpt
,
i
=
i
,
j
=
j
,
h
=
h
,
w
=
w
,
v
=
v
,
inplace
=
inplace
)
...
...
torchvision/transforms/v2/functional/_color.py
View file @
641fdd9f
from
typing
import
List
,
Union
from
typing
import
List
import
PIL.Image
import
PIL.Image
import
torch
import
torch
...
@@ -16,9 +16,7 @@ from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_inter
...
@@ -16,9 +16,7 @@ from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_inter
@
_register_explicit_noop
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
,
datapoints
.
Video
)
@
_register_explicit_noop
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
,
datapoints
.
Video
)
def
rgb_to_grayscale
(
def
rgb_to_grayscale
(
inpt
:
torch
.
Tensor
,
num_output_channels
:
int
=
1
)
->
torch
.
Tensor
:
inpt
:
Union
[
datapoints
.
_ImageTypeJIT
,
datapoints
.
_VideoTypeJIT
],
num_output_channels
:
int
=
1
)
->
Union
[
datapoints
.
_ImageTypeJIT
,
datapoints
.
_VideoTypeJIT
]:
if
torch
.
jit
.
is_scripting
():
if
torch
.
jit
.
is_scripting
():
return
rgb_to_grayscale_image_tensor
(
inpt
,
num_output_channels
=
num_output_channels
)
return
rgb_to_grayscale_image_tensor
(
inpt
,
num_output_channels
=
num_output_channels
)
...
@@ -73,7 +71,7 @@ def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Te
...
@@ -73,7 +71,7 @@ def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Te
@
_register_explicit_noop
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)
@
_register_explicit_noop
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)
def
adjust_brightness
(
inpt
:
datapoints
.
_InputTypeJIT
,
brightness_factor
:
float
)
->
datapoints
.
_InputTypeJIT
:
def
adjust_brightness
(
inpt
:
torch
.
Tensor
,
brightness_factor
:
float
)
->
torch
.
Tensor
:
if
torch
.
jit
.
is_scripting
():
if
torch
.
jit
.
is_scripting
():
return
adjust_brightness_image_tensor
(
inpt
,
brightness_factor
=
brightness_factor
)
return
adjust_brightness_image_tensor
(
inpt
,
brightness_factor
=
brightness_factor
)
...
@@ -110,7 +108,7 @@ def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> to
...
@@ -110,7 +108,7 @@ def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> to
@
_register_explicit_noop
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)
@
_register_explicit_noop
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)
def
adjust_saturation
(
inpt
:
datapoints
.
_InputTypeJIT
,
saturation_factor
:
float
)
->
datapoints
.
_InputTypeJIT
:
def
adjust_saturation
(
inpt
:
torch
.
Tensor
,
saturation_factor
:
float
)
->
torch
.
Tensor
:
if
torch
.
jit
.
is_scripting
():
if
torch
.
jit
.
is_scripting
():
return
adjust_saturation_image_tensor
(
inpt
,
saturation_factor
=
saturation_factor
)
return
adjust_saturation_image_tensor
(
inpt
,
saturation_factor
=
saturation_factor
)
...
@@ -149,7 +147,7 @@ def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> to
...
@@ -149,7 +147,7 @@ def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> to
@
_register_explicit_noop
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)
@
_register_explicit_noop
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)
def
adjust_contrast
(
inpt
:
datapoints
.
_InputTypeJIT
,
contrast_factor
:
float
)
->
datapoints
.
_InputTypeJIT
:
def
adjust_contrast
(
inpt
:
torch
.
Tensor
,
contrast_factor
:
float
)
->
torch
.
Tensor
:
if
torch
.
jit
.
is_scripting
():
if
torch
.
jit
.
is_scripting
():
return
adjust_contrast_image_tensor
(
inpt
,
contrast_factor
=
contrast_factor
)
return
adjust_contrast_image_tensor
(
inpt
,
contrast_factor
=
contrast_factor
)
...
@@ -188,7 +186,7 @@ def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch.
...
@@ -188,7 +186,7 @@ def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch.
@
_register_explicit_noop
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)
@
_register_explicit_noop
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)
def
adjust_sharpness
(
inpt
:
datapoints
.
_InputTypeJIT
,
sharpness_factor
:
float
)
->
datapoints
.
_InputTypeJIT
:
def
adjust_sharpness
(
inpt
:
torch
.
Tensor
,
sharpness_factor
:
float
)
->
torch
.
Tensor
:
if
torch
.
jit
.
is_scripting
():
if
torch
.
jit
.
is_scripting
():
return
adjust_sharpness_image_tensor
(
inpt
,
sharpness_factor
=
sharpness_factor
)
return
adjust_sharpness_image_tensor
(
inpt
,
sharpness_factor
=
sharpness_factor
)
...
@@ -261,7 +259,7 @@ def adjust_sharpness_video(video: torch.Tensor, sharpness_factor: float) -> torc
...
@@ -261,7 +259,7 @@ def adjust_sharpness_video(video: torch.Tensor, sharpness_factor: float) -> torc
@
_register_explicit_noop
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)
@
_register_explicit_noop
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)
def
adjust_hue
(
inpt
:
datapoints
.
_InputTypeJIT
,
hue_factor
:
float
)
->
datapoints
.
_InputTypeJIT
:
def
adjust_hue
(
inpt
:
torch
.
Tensor
,
hue_factor
:
float
)
->
torch
.
Tensor
:
if
torch
.
jit
.
is_scripting
():
if
torch
.
jit
.
is_scripting
():
return
adjust_hue_image_tensor
(
inpt
,
hue_factor
=
hue_factor
)
return
adjust_hue_image_tensor
(
inpt
,
hue_factor
=
hue_factor
)
...
@@ -373,7 +371,7 @@ def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor:
...
@@ -373,7 +371,7 @@ def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor:
@
_register_explicit_noop
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)
@
_register_explicit_noop
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)
def
adjust_gamma
(
inpt
:
datapoints
.
_InputTypeJIT
,
gamma
:
float
,
gain
:
float
=
1
)
->
datapoints
.
_InputTypeJIT
:
def
adjust_gamma
(
inpt
:
torch
.
Tensor
,
gamma
:
float
,
gain
:
float
=
1
)
->
torch
.
Tensor
:
if
torch
.
jit
.
is_scripting
():
if
torch
.
jit
.
is_scripting
():
return
adjust_gamma_image_tensor
(
inpt
,
gamma
=
gamma
,
gain
=
gain
)
return
adjust_gamma_image_tensor
(
inpt
,
gamma
=
gamma
,
gain
=
gain
)
...
@@ -413,7 +411,7 @@ def adjust_gamma_video(video: torch.Tensor, gamma: float, gain: float = 1) -> to
...
@@ -413,7 +411,7 @@ def adjust_gamma_video(video: torch.Tensor, gamma: float, gain: float = 1) -> to
@
_register_explicit_noop
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)
@
_register_explicit_noop
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)
def
posterize
(
inpt
:
datapoints
.
_InputTypeJIT
,
bits
:
int
)
->
datapoints
.
_InputTypeJIT
:
def
posterize
(
inpt
:
torch
.
Tensor
,
bits
:
int
)
->
torch
.
Tensor
:
if
torch
.
jit
.
is_scripting
():
if
torch
.
jit
.
is_scripting
():
return
posterize_image_tensor
(
inpt
,
bits
=
bits
)
return
posterize_image_tensor
(
inpt
,
bits
=
bits
)
...
@@ -447,7 +445,7 @@ def posterize_video(video: torch.Tensor, bits: int) -> torch.Tensor:
...
@@ -447,7 +445,7 @@ def posterize_video(video: torch.Tensor, bits: int) -> torch.Tensor:
@
_register_explicit_noop
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)
@
_register_explicit_noop
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)
def
solarize
(
inpt
:
datapoints
.
_InputTypeJIT
,
threshold
:
float
)
->
datapoints
.
_InputTypeJIT
:
def
solarize
(
inpt
:
torch
.
Tensor
,
threshold
:
float
)
->
torch
.
Tensor
:
if
torch
.
jit
.
is_scripting
():
if
torch
.
jit
.
is_scripting
():
return
solarize_image_tensor
(
inpt
,
threshold
=
threshold
)
return
solarize_image_tensor
(
inpt
,
threshold
=
threshold
)
...
@@ -475,7 +473,7 @@ def solarize_video(video: torch.Tensor, threshold: float) -> torch.Tensor:
...
@@ -475,7 +473,7 @@ def solarize_video(video: torch.Tensor, threshold: float) -> torch.Tensor:
@
_register_explicit_noop
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)
@
_register_explicit_noop
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)
def
autocontrast
(
inpt
:
datapoints
.
_InputTypeJIT
)
->
datapoints
.
_InputTypeJIT
:
def
autocontrast
(
inpt
:
torch
.
Tensor
)
->
torch
.
Tensor
:
if
torch
.
jit
.
is_scripting
():
if
torch
.
jit
.
is_scripting
():
return
autocontrast_image_tensor
(
inpt
)
return
autocontrast_image_tensor
(
inpt
)
...
@@ -525,7 +523,7 @@ def autocontrast_video(video: torch.Tensor) -> torch.Tensor:
...
@@ -525,7 +523,7 @@ def autocontrast_video(video: torch.Tensor) -> torch.Tensor:
@
_register_explicit_noop
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)
@
_register_explicit_noop
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)
def
equalize
(
inpt
:
datapoints
.
_InputTypeJIT
)
->
datapoints
.
_InputTypeJIT
:
def
equalize
(
inpt
:
torch
.
Tensor
)
->
torch
.
Tensor
:
if
torch
.
jit
.
is_scripting
():
if
torch
.
jit
.
is_scripting
():
return
equalize_image_tensor
(
inpt
)
return
equalize_image_tensor
(
inpt
)
...
@@ -615,7 +613,7 @@ def equalize_video(video: torch.Tensor) -> torch.Tensor:
...
@@ -615,7 +613,7 @@ def equalize_video(video: torch.Tensor) -> torch.Tensor:
@
_register_explicit_noop
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)
@
_register_explicit_noop
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)
def
invert
(
inpt
:
datapoints
.
_InputTypeJIT
)
->
datapoints
.
_InputTypeJIT
:
def
invert
(
inpt
:
torch
.
Tensor
)
->
torch
.
Tensor
:
if
torch
.
jit
.
is_scripting
():
if
torch
.
jit
.
is_scripting
():
return
invert_image_tensor
(
inpt
)
return
invert_image_tensor
(
inpt
)
...
@@ -646,7 +644,7 @@ def invert_video(video: torch.Tensor) -> torch.Tensor:
...
@@ -646,7 +644,7 @@ def invert_video(video: torch.Tensor) -> torch.Tensor:
@
_register_explicit_noop
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)
@
_register_explicit_noop
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)
def
permute_channels
(
inpt
:
datapoints
.
_InputTypeJIT
,
permutation
:
List
[
int
])
->
datapoints
.
_InputTypeJIT
:
def
permute_channels
(
inpt
:
torch
.
Tensor
,
permutation
:
List
[
int
])
->
torch
.
Tensor
:
"""Permute the channels of the input according to the given permutation.
"""Permute the channels of the input according to the given permutation.
This function supports plain :class:`~torch.Tensor`'s, :class:`PIL.Image.Image`'s, and
This function supports plain :class:`~torch.Tensor`'s, :class:`PIL.Image.Image`'s, and
...
...
torchvision/transforms/v2/functional/_deprecated.py
View file @
641fdd9f
import
warnings
import
warnings
from
typing
import
Any
,
List
,
Union
from
typing
import
Any
,
List
import
torch
import
torch
from
torchvision
import
datapoints
from
torchvision.transforms
import
functional
as
_F
from
torchvision.transforms
import
functional
as
_F
...
@@ -16,7 +15,7 @@ def to_tensor(inpt: Any) -> torch.Tensor:
...
@@ -16,7 +15,7 @@ def to_tensor(inpt: Any) -> torch.Tensor:
return
_F
.
to_tensor
(
inpt
)
return
_F
.
to_tensor
(
inpt
)
def
get_image_size
(
inpt
:
Union
[
datapoints
.
_ImageTypeJIT
,
datapoints
.
_VideoTypeJIT
]
)
->
List
[
int
]:
def
get_image_size
(
inpt
:
torch
.
Tensor
)
->
List
[
int
]:
warnings
.
warn
(
warnings
.
warn
(
"The function `get_image_size(...)` is deprecated and will be removed in a future release. "
"The function `get_image_size(...)` is deprecated and will be removed in a future release. "
"Instead, please use `get_size(...)` which returns `[h, w]` instead of `[w, h]`."
"Instead, please use `get_size(...)` which returns `[h, w]` instead of `[w, h]`."
...
...
torchvision/transforms/v2/functional/_geometry.py
View file @
641fdd9f
...
@@ -25,7 +25,13 @@ from torchvision.utils import _log_api_usage_once
...
@@ -25,7 +25,13 @@ from torchvision.utils import _log_api_usage_once
from
._meta
import
clamp_bounding_boxes
,
convert_format_bounding_boxes
,
get_size_image_pil
from
._meta
import
clamp_bounding_boxes
,
convert_format_bounding_boxes
,
get_size_image_pil
from
._utils
import
_get_kernel
,
_register_explicit_noop
,
_register_five_ten_crop_kernel
,
_register_kernel_internal
from
._utils
import
(
_FillTypeJIT
,
_get_kernel
,
_register_explicit_noop
,
_register_five_ten_crop_kernel
,
_register_kernel_internal
,
)
def
_check_interpolation
(
interpolation
:
Union
[
InterpolationMode
,
int
])
->
InterpolationMode
:
def
_check_interpolation
(
interpolation
:
Union
[
InterpolationMode
,
int
])
->
InterpolationMode
:
...
@@ -39,7 +45,7 @@ def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> Interp
...
@@ -39,7 +45,7 @@ def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> Interp
return
interpolation
return
interpolation
def
horizontal_flip
(
inpt
:
datapoints
.
_InputTypeJIT
)
->
datapoints
.
_InputTypeJIT
:
def
horizontal_flip
(
inpt
:
torch
.
Tensor
)
->
torch
.
Tensor
:
if
torch
.
jit
.
is_scripting
():
if
torch
.
jit
.
is_scripting
():
return
horizontal_flip_image_tensor
(
inpt
)
return
horizontal_flip_image_tensor
(
inpt
)
...
@@ -95,7 +101,7 @@ def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor:
...
@@ -95,7 +101,7 @@ def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor:
return
horizontal_flip_image_tensor
(
video
)
return
horizontal_flip_image_tensor
(
video
)
def
vertical_flip
(
inpt
:
datapoints
.
_InputTypeJIT
)
->
datapoints
.
_InputTypeJIT
:
def
vertical_flip
(
inpt
:
torch
.
Tensor
)
->
torch
.
Tensor
:
if
torch
.
jit
.
is_scripting
():
if
torch
.
jit
.
is_scripting
():
return
vertical_flip_image_tensor
(
inpt
)
return
vertical_flip_image_tensor
(
inpt
)
...
@@ -171,12 +177,12 @@ def _compute_resized_output_size(
...
@@ -171,12 +177,12 @@ def _compute_resized_output_size(
def
resize
(
def
resize
(
inpt
:
datapoints
.
_InputTypeJIT
,
inpt
:
torch
.
Tensor
,
size
:
List
[
int
],
size
:
List
[
int
],
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
max_size
:
Optional
[
int
]
=
None
,
max_size
:
Optional
[
int
]
=
None
,
antialias
:
Optional
[
Union
[
str
,
bool
]]
=
"warn"
,
antialias
:
Optional
[
Union
[
str
,
bool
]]
=
"warn"
,
)
->
datapoints
.
_InputTypeJIT
:
)
->
torch
.
Tensor
:
if
torch
.
jit
.
is_scripting
():
if
torch
.
jit
.
is_scripting
():
return
resize_image_tensor
(
inpt
,
size
=
size
,
interpolation
=
interpolation
,
max_size
=
max_size
,
antialias
=
antialias
)
return
resize_image_tensor
(
inpt
,
size
=
size
,
interpolation
=
interpolation
,
max_size
=
max_size
,
antialias
=
antialias
)
...
@@ -364,15 +370,15 @@ def resize_video(
...
@@ -364,15 +370,15 @@ def resize_video(
def
affine
(
def
affine
(
inpt
:
datapoints
.
_InputTypeJIT
,
inpt
:
torch
.
Tensor
,
angle
:
Union
[
int
,
float
],
angle
:
Union
[
int
,
float
],
translate
:
List
[
float
],
translate
:
List
[
float
],
scale
:
float
,
scale
:
float
,
shear
:
List
[
float
],
shear
:
List
[
float
],
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
fill
:
datapoints
.
_FillTypeJIT
=
None
,
fill
:
_FillTypeJIT
=
None
,
center
:
Optional
[
List
[
float
]]
=
None
,
center
:
Optional
[
List
[
float
]]
=
None
,
)
->
datapoints
.
_InputTypeJIT
:
)
->
torch
.
Tensor
:
if
torch
.
jit
.
is_scripting
():
if
torch
.
jit
.
is_scripting
():
return
affine_image_tensor
(
return
affine_image_tensor
(
inpt
,
inpt
,
...
@@ -549,9 +555,7 @@ def _compute_affine_output_size(matrix: List[float], w: int, h: int) -> Tuple[in
...
@@ -549,9 +555,7 @@ def _compute_affine_output_size(matrix: List[float], w: int, h: int) -> Tuple[in
return
int
(
size
[
0
]),
int
(
size
[
1
])
# w, h
return
int
(
size
[
0
]),
int
(
size
[
1
])
# w, h
def
_apply_grid_transform
(
def
_apply_grid_transform
(
img
:
torch
.
Tensor
,
grid
:
torch
.
Tensor
,
mode
:
str
,
fill
:
_FillTypeJIT
)
->
torch
.
Tensor
:
img
:
torch
.
Tensor
,
grid
:
torch
.
Tensor
,
mode
:
str
,
fill
:
datapoints
.
_FillTypeJIT
)
->
torch
.
Tensor
:
# We are using context knowledge that grid should have float dtype
# We are using context knowledge that grid should have float dtype
fp
=
img
.
dtype
==
grid
.
dtype
fp
=
img
.
dtype
==
grid
.
dtype
...
@@ -592,7 +596,7 @@ def _assert_grid_transform_inputs(
...
@@ -592,7 +596,7 @@ def _assert_grid_transform_inputs(
image
:
torch
.
Tensor
,
image
:
torch
.
Tensor
,
matrix
:
Optional
[
List
[
float
]],
matrix
:
Optional
[
List
[
float
]],
interpolation
:
str
,
interpolation
:
str
,
fill
:
datapoints
.
_FillTypeJIT
,
fill
:
_FillTypeJIT
,
supported_interpolation_modes
:
List
[
str
],
supported_interpolation_modes
:
List
[
str
],
coeffs
:
Optional
[
List
[
float
]]
=
None
,
coeffs
:
Optional
[
List
[
float
]]
=
None
,
)
->
None
:
)
->
None
:
...
@@ -657,7 +661,7 @@ def affine_image_tensor(
...
@@ -657,7 +661,7 @@ def affine_image_tensor(
scale
:
float
,
scale
:
float
,
shear
:
List
[
float
],
shear
:
List
[
float
],
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
fill
:
datapoints
.
_FillTypeJIT
=
None
,
fill
:
_FillTypeJIT
=
None
,
center
:
Optional
[
List
[
float
]]
=
None
,
center
:
Optional
[
List
[
float
]]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
interpolation
=
_check_interpolation
(
interpolation
)
interpolation
=
_check_interpolation
(
interpolation
)
...
@@ -709,7 +713,7 @@ def affine_image_pil(
...
@@ -709,7 +713,7 @@ def affine_image_pil(
scale
:
float
,
scale
:
float
,
shear
:
List
[
float
],
shear
:
List
[
float
],
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
fill
:
datapoints
.
_FillTypeJIT
=
None
,
fill
:
_FillTypeJIT
=
None
,
center
:
Optional
[
List
[
float
]]
=
None
,
center
:
Optional
[
List
[
float
]]
=
None
,
)
->
PIL
.
Image
.
Image
:
)
->
PIL
.
Image
.
Image
:
interpolation
=
_check_interpolation
(
interpolation
)
interpolation
=
_check_interpolation
(
interpolation
)
...
@@ -868,7 +872,7 @@ def affine_mask(
...
@@ -868,7 +872,7 @@ def affine_mask(
translate
:
List
[
float
],
translate
:
List
[
float
],
scale
:
float
,
scale
:
float
,
shear
:
List
[
float
],
shear
:
List
[
float
],
fill
:
datapoints
.
_FillTypeJIT
=
None
,
fill
:
_FillTypeJIT
=
None
,
center
:
Optional
[
List
[
float
]]
=
None
,
center
:
Optional
[
List
[
float
]]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
mask
.
ndim
<
3
:
if
mask
.
ndim
<
3
:
...
@@ -901,7 +905,7 @@ def _affine_mask_dispatch(
...
@@ -901,7 +905,7 @@ def _affine_mask_dispatch(
translate
:
List
[
float
],
translate
:
List
[
float
],
scale
:
float
,
scale
:
float
,
shear
:
List
[
float
],
shear
:
List
[
float
],
fill
:
datapoints
.
_FillTypeJIT
=
None
,
fill
:
_FillTypeJIT
=
None
,
center
:
Optional
[
List
[
float
]]
=
None
,
center
:
Optional
[
List
[
float
]]
=
None
,
**
kwargs
,
**
kwargs
,
)
->
datapoints
.
Mask
:
)
->
datapoints
.
Mask
:
...
@@ -925,7 +929,7 @@ def affine_video(
...
@@ -925,7 +929,7 @@ def affine_video(
scale
:
float
,
scale
:
float
,
shear
:
List
[
float
],
shear
:
List
[
float
],
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
fill
:
datapoints
.
_FillTypeJIT
=
None
,
fill
:
_FillTypeJIT
=
None
,
center
:
Optional
[
List
[
float
]]
=
None
,
center
:
Optional
[
List
[
float
]]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
return
affine_image_tensor
(
return
affine_image_tensor
(
...
@@ -941,13 +945,13 @@ def affine_video(
...
@@ -941,13 +945,13 @@ def affine_video(
def
rotate
(
def
rotate
(
inpt
:
datapoints
.
_InputTypeJIT
,
inpt
:
torch
.
Tensor
,
angle
:
float
,
angle
:
float
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
expand
:
bool
=
False
,
expand
:
bool
=
False
,
center
:
Optional
[
List
[
float
]]
=
None
,
center
:
Optional
[
List
[
float
]]
=
None
,
fill
:
datapoints
.
_FillTypeJIT
=
None
,
fill
:
_FillTypeJIT
=
None
,
)
->
datapoints
.
_InputTypeJIT
:
)
->
torch
.
Tensor
:
if
torch
.
jit
.
is_scripting
():
if
torch
.
jit
.
is_scripting
():
return
rotate_image_tensor
(
return
rotate_image_tensor
(
inpt
,
angle
=
angle
,
interpolation
=
interpolation
,
expand
=
expand
,
fill
=
fill
,
center
=
center
inpt
,
angle
=
angle
,
interpolation
=
interpolation
,
expand
=
expand
,
fill
=
fill
,
center
=
center
...
@@ -967,7 +971,7 @@ def rotate_image_tensor(
...
@@ -967,7 +971,7 @@ def rotate_image_tensor(
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
expand
:
bool
=
False
,
expand
:
bool
=
False
,
center
:
Optional
[
List
[
float
]]
=
None
,
center
:
Optional
[
List
[
float
]]
=
None
,
fill
:
datapoints
.
_FillTypeJIT
=
None
,
fill
:
_FillTypeJIT
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
interpolation
=
_check_interpolation
(
interpolation
)
interpolation
=
_check_interpolation
(
interpolation
)
...
@@ -1012,7 +1016,7 @@ def rotate_image_pil(
...
@@ -1012,7 +1016,7 @@ def rotate_image_pil(
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
expand
:
bool
=
False
,
expand
:
bool
=
False
,
center
:
Optional
[
List
[
float
]]
=
None
,
center
:
Optional
[
List
[
float
]]
=
None
,
fill
:
datapoints
.
_FillTypeJIT
=
None
,
fill
:
_FillTypeJIT
=
None
,
)
->
PIL
.
Image
.
Image
:
)
->
PIL
.
Image
.
Image
:
interpolation
=
_check_interpolation
(
interpolation
)
interpolation
=
_check_interpolation
(
interpolation
)
...
@@ -1068,7 +1072,7 @@ def rotate_mask(
...
@@ -1068,7 +1072,7 @@ def rotate_mask(
angle
:
float
,
angle
:
float
,
expand
:
bool
=
False
,
expand
:
bool
=
False
,
center
:
Optional
[
List
[
float
]]
=
None
,
center
:
Optional
[
List
[
float
]]
=
None
,
fill
:
datapoints
.
_FillTypeJIT
=
None
,
fill
:
_FillTypeJIT
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
mask
.
ndim
<
3
:
if
mask
.
ndim
<
3
:
mask
=
mask
.
unsqueeze
(
0
)
mask
=
mask
.
unsqueeze
(
0
)
...
@@ -1097,7 +1101,7 @@ def _rotate_mask_dispatch(
...
@@ -1097,7 +1101,7 @@ def _rotate_mask_dispatch(
angle
:
float
,
angle
:
float
,
expand
:
bool
=
False
,
expand
:
bool
=
False
,
center
:
Optional
[
List
[
float
]]
=
None
,
center
:
Optional
[
List
[
float
]]
=
None
,
fill
:
datapoints
.
_FillTypeJIT
=
None
,
fill
:
_FillTypeJIT
=
None
,
**
kwargs
,
**
kwargs
,
)
->
datapoints
.
Mask
:
)
->
datapoints
.
Mask
:
output
=
rotate_mask
(
inpt
.
as_subclass
(
torch
.
Tensor
),
angle
=
angle
,
expand
=
expand
,
fill
=
fill
,
center
=
center
)
output
=
rotate_mask
(
inpt
.
as_subclass
(
torch
.
Tensor
),
angle
=
angle
,
expand
=
expand
,
fill
=
fill
,
center
=
center
)
...
@@ -1111,17 +1115,17 @@ def rotate_video(
...
@@ -1111,17 +1115,17 @@ def rotate_video(
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
expand
:
bool
=
False
,
expand
:
bool
=
False
,
center
:
Optional
[
List
[
float
]]
=
None
,
center
:
Optional
[
List
[
float
]]
=
None
,
fill
:
datapoints
.
_FillTypeJIT
=
None
,
fill
:
_FillTypeJIT
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
return
rotate_image_tensor
(
video
,
angle
,
interpolation
=
interpolation
,
expand
=
expand
,
fill
=
fill
,
center
=
center
)
return
rotate_image_tensor
(
video
,
angle
,
interpolation
=
interpolation
,
expand
=
expand
,
fill
=
fill
,
center
=
center
)
def
pad
(
def
pad
(
inpt
:
datapoints
.
_InputTypeJIT
,
inpt
:
torch
.
Tensor
,
padding
:
List
[
int
],
padding
:
List
[
int
],
fill
:
Optional
[
Union
[
int
,
float
,
List
[
float
]]]
=
None
,
fill
:
Optional
[
Union
[
int
,
float
,
List
[
float
]]]
=
None
,
padding_mode
:
str
=
"constant"
,
padding_mode
:
str
=
"constant"
,
)
->
datapoints
.
_InputTypeJIT
:
)
->
torch
.
Tensor
:
if
torch
.
jit
.
is_scripting
():
if
torch
.
jit
.
is_scripting
():
return
pad_image_tensor
(
inpt
,
padding
=
padding
,
fill
=
fill
,
padding_mode
=
padding_mode
)
return
pad_image_tensor
(
inpt
,
padding
=
padding
,
fill
=
fill
,
padding_mode
=
padding_mode
)
...
@@ -1336,7 +1340,7 @@ def pad_video(
...
@@ -1336,7 +1340,7 @@ def pad_video(
return
pad_image_tensor
(
video
,
padding
,
fill
=
fill
,
padding_mode
=
padding_mode
)
return
pad_image_tensor
(
video
,
padding
,
fill
=
fill
,
padding_mode
=
padding_mode
)
def
crop
(
inpt
:
datapoints
.
_InputTypeJIT
,
top
:
int
,
left
:
int
,
height
:
int
,
width
:
int
)
->
datapoints
.
_InputTypeJIT
:
def
crop
(
inpt
:
torch
.
Tensor
,
top
:
int
,
left
:
int
,
height
:
int
,
width
:
int
)
->
torch
.
Tensor
:
if
torch
.
jit
.
is_scripting
():
if
torch
.
jit
.
is_scripting
():
return
crop_image_tensor
(
inpt
,
top
=
top
,
left
=
left
,
height
=
height
,
width
=
width
)
return
crop_image_tensor
(
inpt
,
top
=
top
,
left
=
left
,
height
=
height
,
width
=
width
)
...
@@ -1423,13 +1427,13 @@ def crop_video(video: torch.Tensor, top: int, left: int, height: int, width: int
...
@@ -1423,13 +1427,13 @@ def crop_video(video: torch.Tensor, top: int, left: int, height: int, width: int
def
perspective
(
def
perspective
(
inpt
:
datapoints
.
_InputTypeJIT
,
inpt
:
torch
.
Tensor
,
startpoints
:
Optional
[
List
[
List
[
int
]]],
startpoints
:
Optional
[
List
[
List
[
int
]]],
endpoints
:
Optional
[
List
[
List
[
int
]]],
endpoints
:
Optional
[
List
[
List
[
int
]]],
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
fill
:
datapoints
.
_FillTypeJIT
=
None
,
fill
:
_FillTypeJIT
=
None
,
coefficients
:
Optional
[
List
[
float
]]
=
None
,
coefficients
:
Optional
[
List
[
float
]]
=
None
,
)
->
datapoints
.
_InputTypeJIT
:
)
->
torch
.
Tensor
:
if
torch
.
jit
.
is_scripting
():
if
torch
.
jit
.
is_scripting
():
return
perspective_image_tensor
(
return
perspective_image_tensor
(
inpt
,
inpt
,
...
@@ -1507,7 +1511,7 @@ def perspective_image_tensor(
...
@@ -1507,7 +1511,7 @@ def perspective_image_tensor(
startpoints
:
Optional
[
List
[
List
[
int
]]],
startpoints
:
Optional
[
List
[
List
[
int
]]],
endpoints
:
Optional
[
List
[
List
[
int
]]],
endpoints
:
Optional
[
List
[
List
[
int
]]],
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
fill
:
datapoints
.
_FillTypeJIT
=
None
,
fill
:
_FillTypeJIT
=
None
,
coefficients
:
Optional
[
List
[
float
]]
=
None
,
coefficients
:
Optional
[
List
[
float
]]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
perspective_coeffs
=
_perspective_coefficients
(
startpoints
,
endpoints
,
coefficients
)
perspective_coeffs
=
_perspective_coefficients
(
startpoints
,
endpoints
,
coefficients
)
...
@@ -1554,7 +1558,7 @@ def perspective_image_pil(
...
@@ -1554,7 +1558,7 @@ def perspective_image_pil(
startpoints
:
Optional
[
List
[
List
[
int
]]],
startpoints
:
Optional
[
List
[
List
[
int
]]],
endpoints
:
Optional
[
List
[
List
[
int
]]],
endpoints
:
Optional
[
List
[
List
[
int
]]],
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BICUBIC
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BICUBIC
,
fill
:
datapoints
.
_FillTypeJIT
=
None
,
fill
:
_FillTypeJIT
=
None
,
coefficients
:
Optional
[
List
[
float
]]
=
None
,
coefficients
:
Optional
[
List
[
float
]]
=
None
,
)
->
PIL
.
Image
.
Image
:
)
->
PIL
.
Image
.
Image
:
perspective_coeffs
=
_perspective_coefficients
(
startpoints
,
endpoints
,
coefficients
)
perspective_coeffs
=
_perspective_coefficients
(
startpoints
,
endpoints
,
coefficients
)
...
@@ -1679,7 +1683,7 @@ def perspective_mask(
...
@@ -1679,7 +1683,7 @@ def perspective_mask(
mask
:
torch
.
Tensor
,
mask
:
torch
.
Tensor
,
startpoints
:
Optional
[
List
[
List
[
int
]]],
startpoints
:
Optional
[
List
[
List
[
int
]]],
endpoints
:
Optional
[
List
[
List
[
int
]]],
endpoints
:
Optional
[
List
[
List
[
int
]]],
fill
:
datapoints
.
_FillTypeJIT
=
None
,
fill
:
_FillTypeJIT
=
None
,
coefficients
:
Optional
[
List
[
float
]]
=
None
,
coefficients
:
Optional
[
List
[
float
]]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
mask
.
ndim
<
3
:
if
mask
.
ndim
<
3
:
...
@@ -1703,7 +1707,7 @@ def _perspective_mask_dispatch(
...
@@ -1703,7 +1707,7 @@ def _perspective_mask_dispatch(
inpt
:
datapoints
.
Mask
,
inpt
:
datapoints
.
Mask
,
startpoints
:
Optional
[
List
[
List
[
int
]]],
startpoints
:
Optional
[
List
[
List
[
int
]]],
endpoints
:
Optional
[
List
[
List
[
int
]]],
endpoints
:
Optional
[
List
[
List
[
int
]]],
fill
:
datapoints
.
_FillTypeJIT
=
None
,
fill
:
_FillTypeJIT
=
None
,
coefficients
:
Optional
[
List
[
float
]]
=
None
,
coefficients
:
Optional
[
List
[
float
]]
=
None
,
**
kwargs
,
**
kwargs
,
)
->
datapoints
.
Mask
:
)
->
datapoints
.
Mask
:
...
@@ -1723,7 +1727,7 @@ def perspective_video(
...
@@ -1723,7 +1727,7 @@ def perspective_video(
startpoints
:
Optional
[
List
[
List
[
int
]]],
startpoints
:
Optional
[
List
[
List
[
int
]]],
endpoints
:
Optional
[
List
[
List
[
int
]]],
endpoints
:
Optional
[
List
[
List
[
int
]]],
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
fill
:
datapoints
.
_FillTypeJIT
=
None
,
fill
:
_FillTypeJIT
=
None
,
coefficients
:
Optional
[
List
[
float
]]
=
None
,
coefficients
:
Optional
[
List
[
float
]]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
return
perspective_image_tensor
(
return
perspective_image_tensor
(
...
@@ -1732,11 +1736,11 @@ def perspective_video(
...
@@ -1732,11 +1736,11 @@ def perspective_video(
def
elastic
(
def
elastic
(
inpt
:
datapoints
.
_InputTypeJIT
,
inpt
:
torch
.
Tensor
,
displacement
:
torch
.
Tensor
,
displacement
:
torch
.
Tensor
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
fill
:
datapoints
.
_FillTypeJIT
=
None
,
fill
:
_FillTypeJIT
=
None
,
)
->
datapoints
.
_InputTypeJIT
:
)
->
torch
.
Tensor
:
if
torch
.
jit
.
is_scripting
():
if
torch
.
jit
.
is_scripting
():
return
elastic_image_tensor
(
inpt
,
displacement
=
displacement
,
interpolation
=
interpolation
,
fill
=
fill
)
return
elastic_image_tensor
(
inpt
,
displacement
=
displacement
,
interpolation
=
interpolation
,
fill
=
fill
)
...
@@ -1755,7 +1759,7 @@ def elastic_image_tensor(
...
@@ -1755,7 +1759,7 @@ def elastic_image_tensor(
image
:
torch
.
Tensor
,
image
:
torch
.
Tensor
,
displacement
:
torch
.
Tensor
,
displacement
:
torch
.
Tensor
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
fill
:
datapoints
.
_FillTypeJIT
=
None
,
fill
:
_FillTypeJIT
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
interpolation
=
_check_interpolation
(
interpolation
)
interpolation
=
_check_interpolation
(
interpolation
)
...
@@ -1812,7 +1816,7 @@ def elastic_image_pil(
...
@@ -1812,7 +1816,7 @@ def elastic_image_pil(
image
:
PIL
.
Image
.
Image
,
image
:
PIL
.
Image
.
Image
,
displacement
:
torch
.
Tensor
,
displacement
:
torch
.
Tensor
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
fill
:
datapoints
.
_FillTypeJIT
=
None
,
fill
:
_FillTypeJIT
=
None
,
)
->
PIL
.
Image
.
Image
:
)
->
PIL
.
Image
.
Image
:
t_img
=
pil_to_tensor
(
image
)
t_img
=
pil_to_tensor
(
image
)
output
=
elastic_image_tensor
(
t_img
,
displacement
,
interpolation
=
interpolation
,
fill
=
fill
)
output
=
elastic_image_tensor
(
t_img
,
displacement
,
interpolation
=
interpolation
,
fill
=
fill
)
...
@@ -1895,7 +1899,7 @@ def _elastic_bounding_boxes_dispatch(
...
@@ -1895,7 +1899,7 @@ def _elastic_bounding_boxes_dispatch(
def
elastic_mask
(
def
elastic_mask
(
mask
:
torch
.
Tensor
,
mask
:
torch
.
Tensor
,
displacement
:
torch
.
Tensor
,
displacement
:
torch
.
Tensor
,
fill
:
datapoints
.
_FillTypeJIT
=
None
,
fill
:
_FillTypeJIT
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
mask
.
ndim
<
3
:
if
mask
.
ndim
<
3
:
mask
=
mask
.
unsqueeze
(
0
)
mask
=
mask
.
unsqueeze
(
0
)
...
@@ -1913,7 +1917,7 @@ def elastic_mask(
...
@@ -1913,7 +1917,7 @@ def elastic_mask(
@
_register_kernel_internal
(
elastic
,
datapoints
.
Mask
,
datapoint_wrapper
=
False
)
@
_register_kernel_internal
(
elastic
,
datapoints
.
Mask
,
datapoint_wrapper
=
False
)
def
_elastic_mask_dispatch
(
def
_elastic_mask_dispatch
(
inpt
:
datapoints
.
Mask
,
displacement
:
torch
.
Tensor
,
fill
:
datapoints
.
_FillTypeJIT
=
None
,
**
kwargs
inpt
:
datapoints
.
Mask
,
displacement
:
torch
.
Tensor
,
fill
:
_FillTypeJIT
=
None
,
**
kwargs
)
->
datapoints
.
Mask
:
)
->
datapoints
.
Mask
:
output
=
elastic_mask
(
inpt
.
as_subclass
(
torch
.
Tensor
),
displacement
=
displacement
,
fill
=
fill
)
output
=
elastic_mask
(
inpt
.
as_subclass
(
torch
.
Tensor
),
displacement
=
displacement
,
fill
=
fill
)
return
datapoints
.
Mask
.
wrap_like
(
inpt
,
output
)
return
datapoints
.
Mask
.
wrap_like
(
inpt
,
output
)
...
@@ -1924,12 +1928,12 @@ def elastic_video(
...
@@ -1924,12 +1928,12 @@ def elastic_video(
def elastic_video(
    video: torch.Tensor,
    displacement: torch.Tensor,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    """Apply an elastic transform to a video tensor.

    A video is just extra leading (batch/frame) dimensions on top of an image,
    so the work is delegated unchanged to the image kernel.
    """
    warped = elastic_image_tensor(video, displacement, interpolation=interpolation, fill=fill)
    return warped
def
center_crop
(
inpt
:
datapoints
.
_InputTypeJIT
,
output_size
:
List
[
int
])
->
datapoints
.
_InputTypeJIT
:
def
center_crop
(
inpt
:
torch
.
Tensor
,
output_size
:
List
[
int
])
->
torch
.
Tensor
:
if
torch
.
jit
.
is_scripting
():
if
torch
.
jit
.
is_scripting
():
return
center_crop_image_tensor
(
inpt
,
output_size
=
output_size
)
return
center_crop_image_tensor
(
inpt
,
output_size
=
output_size
)
...
@@ -2049,7 +2053,7 @@ def center_crop_video(video: torch.Tensor, output_size: List[int]) -> torch.Tens
...
@@ -2049,7 +2053,7 @@ def center_crop_video(video: torch.Tensor, output_size: List[int]) -> torch.Tens
def
resized_crop
(
def
resized_crop
(
inpt
:
datapoints
.
_InputTypeJIT
,
inpt
:
torch
.
Tensor
,
top
:
int
,
top
:
int
,
left
:
int
,
left
:
int
,
height
:
int
,
height
:
int
,
...
@@ -2057,7 +2061,7 @@ def resized_crop(
...
@@ -2057,7 +2061,7 @@ def resized_crop(
size
:
List
[
int
],
size
:
List
[
int
],
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
antialias
:
Optional
[
Union
[
str
,
bool
]]
=
"warn"
,
antialias
:
Optional
[
Union
[
str
,
bool
]]
=
"warn"
,
)
->
datapoints
.
_InputTypeJIT
:
)
->
torch
.
Tensor
:
if
torch
.
jit
.
is_scripting
():
if
torch
.
jit
.
is_scripting
():
return
resized_crop_image_tensor
(
return
resized_crop_image_tensor
(
inpt
,
inpt
,
...
@@ -2201,14 +2205,8 @@ def resized_crop_video(
...
@@ -2201,14 +2205,8 @@ def resized_crop_video(
@
_register_explicit_noop
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
,
warn_passthrough
=
True
)
@
_register_explicit_noop
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
,
warn_passthrough
=
True
)
def
five_crop
(
def
five_crop
(
inpt
:
datapoints
.
_InputTypeJIT
,
size
:
List
[
int
]
inpt
:
torch
.
Tensor
,
size
:
List
[
int
]
)
->
Tuple
[
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
datapoints
.
_InputTypeJIT
,
datapoints
.
_InputTypeJIT
,
datapoints
.
_InputTypeJIT
,
datapoints
.
_InputTypeJIT
,
datapoints
.
_InputTypeJIT
,
]:
if
torch
.
jit
.
is_scripting
():
if
torch
.
jit
.
is_scripting
():
return
five_crop_image_tensor
(
inpt
,
size
=
size
)
return
five_crop_image_tensor
(
inpt
,
size
=
size
)
...
@@ -2280,18 +2278,18 @@ def five_crop_video(
...
@@ -2280,18 +2278,18 @@ def five_crop_video(
@
_register_explicit_noop
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
,
warn_passthrough
=
True
)
@
_register_explicit_noop
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
,
warn_passthrough
=
True
)
def
ten_crop
(
def
ten_crop
(
inpt
:
Union
[
datapoints
.
_ImageTypeJIT
,
datapoints
.
_VideoTypeJIT
]
,
size
:
List
[
int
],
vertical_flip
:
bool
=
False
inpt
:
torch
.
Tensor
,
size
:
List
[
int
],
vertical_flip
:
bool
=
False
)
->
Tuple
[
)
->
Tuple
[
datapoints
.
_InputTypeJIT
,
torch
.
Tensor
,
datapoints
.
_InputTypeJIT
,
torch
.
Tensor
,
datapoints
.
_InputTypeJIT
,
torch
.
Tensor
,
datapoints
.
_InputTypeJIT
,
torch
.
Tensor
,
datapoints
.
_InputTypeJIT
,
torch
.
Tensor
,
datapoints
.
_InputTypeJIT
,
torch
.
Tensor
,
datapoints
.
_InputTypeJIT
,
torch
.
Tensor
,
datapoints
.
_InputTypeJIT
,
torch
.
Tensor
,
datapoints
.
_InputTypeJIT
,
torch
.
Tensor
,
datapoints
.
_InputTypeJIT
,
torch
.
Tensor
,
]:
]:
if
torch
.
jit
.
is_scripting
():
if
torch
.
jit
.
is_scripting
():
return
ten_crop_image_tensor
(
inpt
,
size
=
size
,
vertical_flip
=
vertical_flip
)
return
ten_crop_image_tensor
(
inpt
,
size
=
size
,
vertical_flip
=
vertical_flip
)
...
...
torchvision/transforms/v2/functional/_meta.py
View file @
641fdd9f
from
typing
import
List
,
Optional
,
Tuple
,
Union
from
typing
import
List
,
Optional
,
Tuple
import
PIL.Image
import
PIL.Image
import
torch
import
torch
...
@@ -12,7 +12,7 @@ from ._utils import _get_kernel, _register_kernel_internal, _register_unsupporte
...
@@ -12,7 +12,7 @@ from ._utils import _get_kernel, _register_kernel_internal, _register_unsupporte
@
_register_unsupported_type
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)
@
_register_unsupported_type
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)
def
get_dimensions
(
inpt
:
Union
[
datapoints
.
_ImageTypeJIT
,
datapoints
.
_VideoTypeJIT
]
)
->
List
[
int
]:
def
get_dimensions
(
inpt
:
torch
.
Tensor
)
->
List
[
int
]:
if
torch
.
jit
.
is_scripting
():
if
torch
.
jit
.
is_scripting
():
return
get_dimensions_image_tensor
(
inpt
)
return
get_dimensions_image_tensor
(
inpt
)
...
@@ -45,7 +45,7 @@ def get_dimensions_video(video: torch.Tensor) -> List[int]:
...
@@ -45,7 +45,7 @@ def get_dimensions_video(video: torch.Tensor) -> List[int]:
@
_register_unsupported_type
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)
@
_register_unsupported_type
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)
def
get_num_channels
(
inpt
:
Union
[
datapoints
.
_ImageTypeJIT
,
datapoints
.
_VideoTypeJIT
]
)
->
int
:
def
get_num_channels
(
inpt
:
torch
.
Tensor
)
->
int
:
if
torch
.
jit
.
is_scripting
():
if
torch
.
jit
.
is_scripting
():
return
get_num_channels_image_tensor
(
inpt
)
return
get_num_channels_image_tensor
(
inpt
)
...
@@ -81,7 +81,7 @@ def get_num_channels_video(video: torch.Tensor) -> int:
...
@@ -81,7 +81,7 @@ def get_num_channels_video(video: torch.Tensor) -> int:
get_image_num_channels
=
get_num_channels
get_image_num_channels
=
get_num_channels
def
get_size
(
inpt
:
datapoints
.
_InputTypeJIT
)
->
List
[
int
]:
def
get_size
(
inpt
:
torch
.
Tensor
)
->
List
[
int
]:
if
torch
.
jit
.
is_scripting
():
if
torch
.
jit
.
is_scripting
():
return
get_size_image_tensor
(
inpt
)
return
get_size_image_tensor
(
inpt
)
...
@@ -124,7 +124,7 @@ def get_size_bounding_boxes(bounding_box: datapoints.BoundingBoxes) -> List[int]
...
@@ -124,7 +124,7 @@ def get_size_bounding_boxes(bounding_box: datapoints.BoundingBoxes) -> List[int]
@
_register_unsupported_type
(
PIL
.
Image
.
Image
,
datapoints
.
Image
,
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)
@
_register_unsupported_type
(
PIL
.
Image
.
Image
,
datapoints
.
Image
,
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)
def
get_num_frames
(
inpt
:
datapoints
.
_VideoTypeJIT
)
->
int
:
def
get_num_frames
(
inpt
:
torch
.
Tensor
)
->
int
:
if
torch
.
jit
.
is_scripting
():
if
torch
.
jit
.
is_scripting
():
return
get_num_frames_video
(
inpt
)
return
get_num_frames_video
(
inpt
)
...
@@ -201,11 +201,11 @@ def _convert_format_bounding_boxes(
...
@@ -201,11 +201,11 @@ def _convert_format_bounding_boxes(
def
convert_format_bounding_boxes
(
def
convert_format_bounding_boxes
(
inpt
:
datapoints
.
_InputTypeJIT
,
inpt
:
torch
.
Tensor
,
old_format
:
Optional
[
BoundingBoxFormat
]
=
None
,
old_format
:
Optional
[
BoundingBoxFormat
]
=
None
,
new_format
:
Optional
[
BoundingBoxFormat
]
=
None
,
new_format
:
Optional
[
BoundingBoxFormat
]
=
None
,
inplace
:
bool
=
False
,
inplace
:
bool
=
False
,
)
->
datapoints
.
_InputTypeJIT
:
)
->
torch
.
Tensor
:
# This being a kernel / dispatcher hybrid, we need an option to pass `old_format` explicitly for simple tensor
# This being a kernel / dispatcher hybrid, we need an option to pass `old_format` explicitly for simple tensor
# inputs as well as extract it from `datapoints.BoundingBoxes` inputs. However, putting a default value on
# inputs as well as extract it from `datapoints.BoundingBoxes` inputs. However, putting a default value on
# `old_format` means we also need to put one on `new_format` to have syntactically correct Python. Here we mimic the
# `old_format` means we also need to put one on `new_format` to have syntactically correct Python. Here we mimic the
...
@@ -252,10 +252,10 @@ def _clamp_bounding_boxes(
...
@@ -252,10 +252,10 @@ def _clamp_bounding_boxes(
def
clamp_bounding_boxes
(
def
clamp_bounding_boxes
(
inpt
:
datapoints
.
_InputTypeJIT
,
inpt
:
torch
.
Tensor
,
format
:
Optional
[
BoundingBoxFormat
]
=
None
,
format
:
Optional
[
BoundingBoxFormat
]
=
None
,
canvas_size
:
Optional
[
Tuple
[
int
,
int
]]
=
None
,
canvas_size
:
Optional
[
Tuple
[
int
,
int
]]
=
None
,
)
->
datapoints
.
_InputTypeJIT
:
)
->
torch
.
Tensor
:
if
not
torch
.
jit
.
is_scripting
():
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
clamp_bounding_boxes
)
_log_api_usage_once
(
clamp_bounding_boxes
)
...
...
torchvision/transforms/v2/functional/_misc.py
View file @
641fdd9f
import
math
import
math
from
typing
import
List
,
Optional
,
Union
from
typing
import
List
,
Optional
import
PIL.Image
import
PIL.Image
import
torch
import
torch
...
@@ -17,7 +17,7 @@ from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_inter
...
@@ -17,7 +17,7 @@ from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_inter
@
_register_explicit_noop
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)
@
_register_explicit_noop
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)
@
_register_unsupported_type
(
PIL
.
Image
.
Image
)
@
_register_unsupported_type
(
PIL
.
Image
.
Image
)
def
normalize
(
def
normalize
(
inpt
:
Union
[
datapoints
.
_TensorImageTypeJIT
,
datapoints
.
_TensorVideoTypeJIT
]
,
inpt
:
torch
.
Tensor
,
mean
:
List
[
float
],
mean
:
List
[
float
],
std
:
List
[
float
],
std
:
List
[
float
],
inplace
:
bool
=
False
,
inplace
:
bool
=
False
,
...
@@ -74,9 +74,7 @@ def normalize_video(video: torch.Tensor, mean: List[float], std: List[float], in
...
@@ -74,9 +74,7 @@ def normalize_video(video: torch.Tensor, mean: List[float], std: List[float], in
@
_register_explicit_noop
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)
@
_register_explicit_noop
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)
def
gaussian_blur
(
def
gaussian_blur
(
inpt
:
torch
.
Tensor
,
kernel_size
:
List
[
int
],
sigma
:
Optional
[
List
[
float
]]
=
None
)
->
torch
.
Tensor
:
inpt
:
datapoints
.
_InputTypeJIT
,
kernel_size
:
List
[
int
],
sigma
:
Optional
[
List
[
float
]]
=
None
)
->
datapoints
.
_InputTypeJIT
:
if
torch
.
jit
.
is_scripting
():
if
torch
.
jit
.
is_scripting
():
return
gaussian_blur_image_tensor
(
inpt
,
kernel_size
=
kernel_size
,
sigma
=
sigma
)
return
gaussian_blur_image_tensor
(
inpt
,
kernel_size
=
kernel_size
,
sigma
=
sigma
)
...
@@ -185,9 +183,7 @@ def gaussian_blur_video(
...
@@ -185,9 +183,7 @@ def gaussian_blur_video(
@
_register_unsupported_type
(
PIL
.
Image
.
Image
)
@
_register_unsupported_type
(
PIL
.
Image
.
Image
)
def
to_dtype
(
def
to_dtype
(
inpt
:
torch
.
Tensor
,
dtype
:
torch
.
dtype
=
torch
.
float
,
scale
:
bool
=
False
)
->
torch
.
Tensor
:
inpt
:
datapoints
.
_InputTypeJIT
,
dtype
:
torch
.
dtype
=
torch
.
float
,
scale
:
bool
=
False
)
->
datapoints
.
_InputTypeJIT
:
if
torch
.
jit
.
is_scripting
():
if
torch
.
jit
.
is_scripting
():
return
to_dtype_image_tensor
(
inpt
,
dtype
=
dtype
,
scale
=
scale
)
return
to_dtype_image_tensor
(
inpt
,
dtype
=
dtype
,
scale
=
scale
)
...
@@ -278,8 +274,6 @@ def to_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float, scale:
...
@@ -278,8 +274,6 @@ def to_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float, scale:
@_register_kernel_internal(to_dtype, datapoints.BoundingBoxes, datapoint_wrapper=False)
@_register_kernel_internal(to_dtype, datapoints.Mask, datapoint_wrapper=False)
def _to_dtype_tensor_dispatch(inpt: torch.Tensor, dtype: torch.dtype, scale: bool = False) -> torch.Tensor:
    """Dispatch ``to_dtype`` for ``BoundingBoxes`` and ``Mask`` inputs.

    ``scale`` is accepted for signature compatibility but intentionally ignored:
    value scaling only makes sense for image-like data. No unwrap/re-wrap is
    needed because ``Datapoint.to()`` preserves the subclass of ``inpt``.
    """
    converted = inpt.to(dtype)
    return converted
torchvision/transforms/v2/functional/_temporal.py
View file @
641fdd9f
...
@@ -11,7 +11,7 @@ from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_inter
...
@@ -11,7 +11,7 @@ from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_inter
@
_register_explicit_noop
(
@
_register_explicit_noop
(
PIL
.
Image
.
Image
,
datapoints
.
Image
,
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
,
warn_passthrough
=
True
PIL
.
Image
.
Image
,
datapoints
.
Image
,
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
,
warn_passthrough
=
True
)
)
def
uniform_temporal_subsample
(
inpt
:
datapoints
.
_VideoTypeJIT
,
num_samples
:
int
)
->
datapoints
.
_VideoTypeJIT
:
def
uniform_temporal_subsample
(
inpt
:
torch
.
Tensor
,
num_samples
:
int
)
->
torch
.
Tensor
:
if
torch
.
jit
.
is_scripting
():
if
torch
.
jit
.
is_scripting
():
return
uniform_temporal_subsample_video
(
inpt
,
num_samples
=
num_samples
)
return
uniform_temporal_subsample_video
(
inpt
,
num_samples
=
num_samples
)
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment