OpenDAS / vision · Commits

Commit 641fdd9f (unverified), authored Aug 09, 2023 by Philip Meier, committed by GitHub on Aug 09, 2023
Parent: 6b020798

remove custom types definitions from datapoints module (#7814)

Showing 20 changed files with 137 additions and 177 deletions (+137 −177)
torchvision/datapoints/__init__.py                    +3  −3
torchvision/datapoints/_datapoint.py                  +1  −8
torchvision/datapoints/_image.py                      +0  −6
torchvision/datapoints/_video.py                      +0  −6
torchvision/prototype/transforms/_augment.py          +5  −5
torchvision/prototype/transforms/_geometry.py         +2  −2
torchvision/prototype/transforms/_misc.py             +2  −6
torchvision/transforms/v2/_auto_augment.py            +14 −10
torchvision/transforms/v2/_color.py                   +1  −3
torchvision/transforms/v2/_geometry.py                +8  −10
torchvision/transforms/v2/_misc.py                    +1  −3
torchvision/transforms/v2/_temporal.py                +1  −2
torchvision/transforms/v2/_utils.py                   +2  −3
torchvision/transforms/v2/functional/_augment.py      +2  −4
torchvision/transforms/v2/functional/_color.py        +14 −16
torchvision/transforms/v2/functional/_deprecated.py   +2  −3
torchvision/transforms/v2/functional/_geometry.py     +64 −66
torchvision/transforms/v2/functional/_meta.py         +9  −9
torchvision/transforms/v2/functional/_misc.py         +5  −11
torchvision/transforms/v2/functional/_temporal.py     +1  −1
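Taken together, the listing shows one mechanical change applied everywhere: annotations built on the private aliases (`datapoints._ImageType`, `datapoints._InputTypeJIT`, and friends) give way to plain `torch.Tensor`, explicit unions, or `Any`. A minimal before/after sketch of what this means for downstream annotations; `brighten` is a hypothetical helper, not part of this diff:

```python
from typing import Union

import torch
from torchvision import datapoints

# Before: def brighten(inpt: datapoints._TensorImageType) -> datapoints._TensorImageType
# After, the union is spelled out (or collapsed to torch.Tensor, since
# datapoints.Image subclasses torch.Tensor and is accepted either way):
def brighten(inpt: Union[torch.Tensor, datapoints.Image]) -> torch.Tensor:
    return (inpt.as_subclass(torch.Tensor) * 1.5).clamp(0.0, 1.0)
```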
torchvision/datapoints/__init__.py

```diff
 from torchvision import _BETA_TRANSFORMS_WARNING, _WARN_ABOUT_BETA_TRANSFORMS
 from ._bounding_box import BoundingBoxes, BoundingBoxFormat
-from ._datapoint import _FillType, _FillTypeJIT, _InputType, _InputTypeJIT, Datapoint
-from ._image import _ImageType, _ImageTypeJIT, _TensorImageType, _TensorImageTypeJIT, Image
+from ._datapoint import Datapoint
+from ._image import Image
 from ._mask import Mask
-from ._video import _TensorVideoType, _TensorVideoTypeJIT, _VideoType, _VideoTypeJIT, Video
+from ._video import Video

 if _WARN_ABOUT_BETA_TRANSFORMS:
     import warnings
     ...
```
torchvision/datapoints/_datapoint.py

```diff
 from __future__ import annotations

-from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union
+from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union

-import PIL.Image
 import torch
 from torch._C import DisableTorchFunctionSubclass
 from torch.types import _device, _dtype, _size

 D = TypeVar("D", bound="Datapoint")
-_FillType = Union[int, float, Sequence[int], Sequence[float], None]
-_FillTypeJIT = Optional[List[float]]


 class Datapoint(torch.Tensor):
     ...

@@ -132,7 +129,3 @@ class Datapoint(torch.Tensor):
         # `BoundingBoxes.format` and `BoundingBoxes.canvas_size`, which are immutable and thus implicitly deep-copied by
         # `BoundingBoxes.clone()`.
         return self.detach().clone().requires_grad_(self.requires_grad)  # type: ignore[return-value]
-
-
-_InputType = Union[torch.Tensor, PIL.Image.Image, Datapoint]
-_InputTypeJIT = torch.Tensor
```
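The deleted aliases are not gone from the codebase: later hunks in this commit import `_FillType` and `_FillTypeJIT` from `torchvision.transforms.v2.functional._utils` instead. For reference, their definitions as removed above, restated verbatim:

```python
from typing import List, Optional, Sequence, Union

# Restated from the lines deleted above; _FillTypeJIT is the only shape
# TorchScript can digest, which is why the kernels keep it.
_FillType = Union[int, float, Sequence[int], Sequence[float], None]
_FillTypeJIT = Optional[List[float]]
```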
torchvision/datapoints/_image.py

```diff
@@ -45,9 +45,3 @@ class Image(Datapoint):
     def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
         return self._make_repr()
-
-
-_ImageType = Union[torch.Tensor, PIL.Image.Image, Image]
-_ImageTypeJIT = torch.Tensor
-_TensorImageType = Union[torch.Tensor, Image]
-_TensorImageTypeJIT = torch.Tensor
```
torchvision/datapoints/_video.py

```diff
@@ -35,9 +35,3 @@ class Video(Datapoint):
     def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
         return self._make_repr()
-
-
-_VideoType = Union[torch.Tensor, Video]
-_VideoTypeJIT = torch.Tensor
-_TensorVideoType = Union[torch.Tensor, Video]
-_TensorVideoTypeJIT = torch.Tensor
```
torchvision/prototype/transforms/_augment.py

```diff
@@ -26,15 +26,15 @@ class SimpleCopyPaste(Transform):
     def _copy_paste(
         self,
-        image: datapoints._TensorImageType,
+        image: Union[torch.Tensor, datapoints.Image],
         target: Dict[str, Any],
-        paste_image: datapoints._TensorImageType,
+        paste_image: Union[torch.Tensor, datapoints.Image],
         paste_target: Dict[str, Any],
         random_selection: torch.Tensor,
         blending: bool,
         resize_interpolation: F.InterpolationMode,
         antialias: Optional[bool],
-    ) -> Tuple[datapoints._TensorImageType, Dict[str, Any]]:
+    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
         paste_masks = paste_target["masks"].wrap_like(paste_target["masks"], paste_target["masks"][random_selection])
         paste_boxes = paste_target["boxes"].wrap_like(paste_target["boxes"], paste_target["boxes"][random_selection])
     ...

@@ -106,7 +106,7 @@ class SimpleCopyPaste(Transform):
     def _extract_image_targets(
         self, flat_sample: List[Any]
-    ) -> Tuple[List[datapoints._TensorImageType], List[Dict[str, Any]]]:
+    ) -> Tuple[List[Union[torch.Tensor, datapoints.Image]], List[Dict[str, Any]]]:
         # fetch all images, bboxes, masks and labels from unstructured input
         # with List[image], List[BoundingBoxes], List[Mask], List[Label]
         images, bboxes, masks, labels = [], [], [], []
     ...

@@ -137,7 +137,7 @@ class SimpleCopyPaste(Transform):
     def _insert_outputs(
         self,
         flat_sample: List[Any],
-        output_images: List[datapoints._TensorImageType],
+        output_images: List[torch.Tensor],
         output_targets: List[Dict[str, Any]],
     ) -> None:
         c0, c1, c2, c3 = 0, 0, 0, 0
     ...
```
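The `wrap_like` calls in `_copy_paste` are the idiom that makes plain-tensor signatures workable: an operation on a datapoint may return a plain tensor, and `wrap_like` re-wraps the result with the metadata of a reference datapoint. A hedged sketch with made-up shapes:

```python
import torch
from torchvision import datapoints

masks = datapoints.Mask(torch.rand(4, 32, 32) > 0.5)
random_selection = torch.tensor([0, 2])

# Indexing may demote the result to a plain tensor; wrap_like restores
# the Mask subclass from the reference, mirroring the _copy_paste hunk.
selected = masks.wrap_like(masks, masks[random_selection])
assert isinstance(selected, datapoints.Mask) and selected.shape[0] == 2
```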
torchvision/prototype/transforms/_geometry.py

```diff
@@ -6,7 +6,7 @@ import torch
 from torchvision import datapoints
 from torchvision.prototype.datapoints import Label, OneHotLabel
 from torchvision.transforms.v2 import functional as F, Transform
-from torchvision.transforms.v2._utils import _get_fill, _setup_fill_arg, _setup_size
+from torchvision.transforms.v2._utils import _FillType, _get_fill, _setup_fill_arg, _setup_size
 from torchvision.transforms.v2.utils import get_bounding_boxes, has_any, is_simple_tensor, query_size
 ...

@@ -14,7 +14,7 @@ class FixedSizeCrop(Transform):
     def __init__(
         self,
         size: Union[int, Sequence[int]],
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
         padding_mode: str = "constant",
     ) -> None:
         super().__init__()
     ...
```
torchvision/prototype/transforms/_misc.py

```diff
@@ -39,9 +39,7 @@ class PermuteDimensions(Transform):
             )
         self.dims = dims

-    def _transform(
-        self, inpt: Union[datapoints._TensorImageType, datapoints._TensorVideoType], params: Dict[str, Any]
-    ) -> torch.Tensor:
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor:
         dims = self.dims[type(inpt)]
         if dims is None:
             return inpt.as_subclass(torch.Tensor)
     ...

@@ -63,9 +61,7 @@ class TransposeDimensions(Transform):
             )
         self.dims = dims

-    def _transform(
-        self, inpt: Union[datapoints._TensorImageType, datapoints._TensorVideoType], params: Dict[str, Any]
-    ) -> torch.Tensor:
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor:
         dims = self.dims[type(inpt)]
         if dims is None:
             return inpt.as_subclass(torch.Tensor)
     ...
```
torchvision/transforms/v2/_auto_augment.py

```diff
@@ -10,17 +10,21 @@ from torchvision.transforms import _functional_tensor as _FT
 from torchvision.transforms.v2 import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
 from torchvision.transforms.v2.functional._geometry import _check_interpolation
 from torchvision.transforms.v2.functional._meta import get_size
+from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT
 from ._utils import _get_fill, _setup_fill_arg
 from .utils import check_type, is_simple_tensor

+ImageOrVideo = Union[torch.Tensor, PIL.Image.Image, datapoints.Image, datapoints.Video]
+

 class _AutoAugmentBase(Transform):
     def __init__(
         self,
         *,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
     ) -> None:
         super().__init__()
         self.interpolation = _check_interpolation(interpolation)
     ...

@@ -35,7 +39,7 @@ class _AutoAugmentBase(Transform):
         self,
         inputs: Any,
         unsupported_types: Tuple[Type, ...] = (datapoints.BoundingBoxes, datapoints.Mask),
-    ) -> Tuple[Tuple[List[Any], TreeSpec, int], Union[datapoints._ImageType, datapoints._VideoType]]:
+    ) -> Tuple[Tuple[List[Any], TreeSpec, int], ImageOrVideo]:
         flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])
         needs_transform_list = self._needs_transform_list(flat_inputs)
     ...

@@ -68,7 +72,7 @@ class _AutoAugmentBase(Transform):
     def _unflatten_and_insert_image_or_video(
         self,
         flat_inputs_with_spec: Tuple[List[Any], TreeSpec, int],
-        image_or_video: Union[datapoints._ImageType, datapoints._VideoType],
+        image_or_video: ImageOrVideo,
     ) -> Any:
         flat_inputs, spec, idx = flat_inputs_with_spec
         flat_inputs[idx] = image_or_video
     ...

@@ -76,12 +80,12 @@ class _AutoAugmentBase(Transform):
     def _apply_image_or_video_transform(
         self,
-        image: Union[datapoints._ImageType, datapoints._VideoType],
+        image: ImageOrVideo,
         transform_id: str,
         magnitude: float,
         interpolation: Union[InterpolationMode, int],
-        fill: Dict[Union[Type, str], datapoints._FillTypeJIT],
-    ) -> Union[datapoints._ImageType, datapoints._VideoType]:
+        fill: Dict[Union[Type, str], _FillTypeJIT],
+    ) -> ImageOrVideo:
         fill_ = _get_fill(fill, type(image))
         if transform_id == "Identity":
     ...

@@ -214,7 +218,7 @@ class AutoAugment(_AutoAugmentBase):
         self,
         policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
     ) -> None:
         super().__init__(interpolation=interpolation, fill=fill)
         self.policy = policy
     ...

@@ -394,7 +398,7 @@ class RandAugment(_AutoAugmentBase):
         magnitude: int = 9,
         num_magnitude_bins: int = 31,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
     ) -> None:
         super().__init__(interpolation=interpolation, fill=fill)
         self.num_ops = num_ops
     ...

@@ -467,7 +471,7 @@ class TrivialAugmentWide(_AutoAugmentBase):
         self,
         num_magnitude_bins: int = 31,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
     ):
         super().__init__(interpolation=interpolation, fill=fill)
         self.num_magnitude_bins = num_magnitude_bins
     ...

@@ -550,7 +554,7 @@ class AugMix(_AutoAugmentBase):
         alpha: float = 1.0,
         all_ops: bool = True,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
     ) -> None:
         super().__init__(interpolation=interpolation, fill=fill)
         self._PARAMETER_MAX = 10
     ...
```
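The new module-level `ImageOrVideo` alias only matters for eager-mode type checking, presumably because these `Transform` methods are never scripted, so a plain `Union` over tensor subclasses is safe here. Restated, with a hypothetical use:

```python
from typing import Union

import PIL.Image
import torch
from torchvision import datapoints

# Verbatim from the hunk above:
ImageOrVideo = Union[torch.Tensor, PIL.Image.Image, datapoints.Image, datapoints.Video]

def describe(image_or_video: ImageOrVideo) -> str:
    # Hypothetical helper: the auto-augment code dispatches on the concrete
    # type in the same way, e.g. _get_fill(fill, type(image)).
    return type(image_or_video).__name__
```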
torchvision/transforms/v2/_color.py

```diff
@@ -261,9 +261,7 @@ class RandomPhotometricDistort(Transform):
         params["channel_permutation"] = torch.randperm(num_channels) if torch.rand(1) < self.p else None
         return params

-    def _transform(
-        self, inpt: Union[datapoints._ImageType, datapoints._VideoType], params: Dict[str, Any]
-    ) -> Union[datapoints._ImageType, datapoints._VideoType]:
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         if params["brightness_factor"] is not None:
             inpt = F.adjust_brightness(inpt, brightness_factor=params["brightness_factor"])
         if params["contrast_factor"] is not None and params["contrast_before"]:
     ...
```
torchvision/transforms/v2/_geometry.py

```diff
@@ -11,6 +11,7 @@ from torchvision.ops.boxes import box_iou
 from torchvision.transforms.functional import _get_perspective_coeffs
 from torchvision.transforms.v2 import functional as F, InterpolationMode, Transform
 from torchvision.transforms.v2.functional._geometry import _check_interpolation
+from torchvision.transforms.v2.functional._utils import _FillType
 from ._transform import _RandomApplyTransform
 from ._utils import (
 ...

@@ -311,9 +312,6 @@ class RandomResizedCrop(Transform):
         )

-
-ImageOrVideoTypeJIT = Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]
-

 class FiveCrop(Transform):
     """[BETA] Crop the image or video into four corners and the central crop.
 ...

@@ -459,7 +457,7 @@ class Pad(Transform):
     def __init__(
         self,
         padding: Union[int, Sequence[int]],
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
         padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
     ) -> None:
         super().__init__()
 ...

@@ -514,7 +512,7 @@ class RandomZoomOut(_RandomApplyTransform):
     def __init__(
         self,
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
         side_range: Sequence[float] = (1.0, 4.0),
         p: float = 0.5,
     ) -> None:
 ...

@@ -592,7 +590,7 @@ class RandomRotation(Transform):
         interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
         expand: bool = False,
         center: Optional[List[float]] = None,
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
     ) -> None:
         super().__init__()
         self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))
 ...

@@ -674,7 +672,7 @@ class RandomAffine(Transform):
         scale: Optional[Sequence[float]] = None,
         shear: Optional[Union[int, float, Sequence[float]]] = None,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
         center: Optional[List[float]] = None,
     ) -> None:
         super().__init__()
 ...

@@ -812,7 +810,7 @@ class RandomCrop(Transform):
         size: Union[int, Sequence[int]],
         padding: Optional[Union[int, Sequence[int]]] = None,
         pad_if_needed: bool = False,
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
         padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
     ) -> None:
         super().__init__()
 ...

@@ -931,7 +929,7 @@ class RandomPerspective(_RandomApplyTransform):
         distortion_scale: float = 0.5,
         p: float = 0.5,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
     ) -> None:
         super().__init__(p=p)
 ...

@@ -1033,7 +1031,7 @@ class ElasticTransform(Transform):
         alpha: Union[float, Sequence[float]] = 50.0,
         sigma: Union[float, Sequence[float]] = 5.0,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
     ) -> None:
         super().__init__()
         self.alpha = _setup_float_or_seq(alpha, "alpha", 2)
 ...
```
torchvision/transforms/v2/_misc.py

```diff
@@ -169,9 +169,7 @@ class Normalize(Transform):
         if has_any(sample, PIL.Image.Image):
             raise TypeError(f"{type(self).__name__}() does not support PIL images.")

-    def _transform(
-        self, inpt: Union[datapoints._TensorImageType, datapoints._TensorVideoType], params: Dict[str, Any]
-    ) -> Any:
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return F.normalize(inpt, mean=self.mean, std=self.std, inplace=self.inplace)
 ...
```
torchvision/transforms/v2/_temporal.py

```diff
 from typing import Any, Dict

 import torch
-from torchvision import datapoints
 from torchvision.transforms.v2 import functional as F, Transform
 ...

@@ -25,5 +24,5 @@ class UniformTemporalSubsample(Transform):
         super().__init__()
         self.num_samples = num_samples

-    def _transform(self, inpt: datapoints._VideoType, params: Dict[str, Any]) -> datapoints._VideoType:
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return F.uniform_temporal_subsample(inpt, self.num_samples)
```
torchvision/transforms/v2/_utils.py

```diff
@@ -5,9 +5,8 @@ from typing import Any, Callable, Dict, Literal, Optional, Sequence, Type, Union
 import torch

-from torchvision import datapoints
-from torchvision.datapoints._datapoint import _FillType, _FillTypeJIT
 from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size  # noqa: F401
+from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT


 def _setup_float_or_seq(arg: Union[float, Sequence[float]], name: str, req_size: int = 2) -> Sequence[float]:
 ...

@@ -36,7 +35,7 @@ def _check_fill_arg(fill: Union[_FillType, Dict[Union[Type, str], _FillType]])
     raise TypeError("Got inappropriate fill arg, only Numbers, tuples, lists and dicts are allowed.")

-def _convert_fill_arg(fill: datapoints._FillType) -> datapoints._FillTypeJIT:
+def _convert_fill_arg(fill: _FillType) -> _FillTypeJIT:
     # Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517
     # So, we can't reassign fill to 0
     # if fill is None:
 ...
```
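`_convert_fill_arg` keeps its None-vs-0 contract (pytorch/vision#6517) while only the annotations change. A sketch of the conversion it plausibly performs, under the assumption that the body is unchanged by this commit; `convert_fill` is a local stand-in name, not the real function:

```python
from typing import List, Optional, Sequence, Union

_FillType = Union[int, float, Sequence[int], Sequence[float], None]
_FillTypeJIT = Optional[List[float]]

def convert_fill(fill: _FillType) -> _FillTypeJIT:
    # None must stay None -- fill=0 and fill=None behave differently,
    # see https://github.com/pytorch/vision/issues/6517
    if fill is None:
        return None
    if isinstance(fill, (int, float)):
        return [float(fill)]
    return [float(v) for v in fill]
```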
torchvision/transforms/v2/functional/_augment.py

```diff
-from typing import Union
-
 import PIL.Image
 import torch
 ...

@@ -12,14 +10,14 @@ from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal
 @_register_explicit_noop(datapoints.Mask, datapoints.BoundingBoxes, warn_passthrough=True)
 def erase(
-    inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT],
+    inpt: torch.Tensor,
     i: int,
     j: int,
     h: int,
     w: int,
     v: torch.Tensor,
     inplace: bool = False,
-) -> Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]:
+) -> torch.Tensor:
     if torch.jit.is_scripting():
         return erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
 ...
```
torchvision/transforms/v2/functional/_color.py

```diff
-from typing import List, Union
+from typing import List

 import PIL.Image
 import torch
 ...

@@ -16,9 +16,7 @@ from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal
 @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask, datapoints.Video)
-def rgb_to_grayscale(
-    inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT], num_output_channels: int = 1
-) -> Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]:
+def rgb_to_grayscale(inpt: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor:
     if torch.jit.is_scripting():
         return rgb_to_grayscale_image_tensor(inpt, num_output_channels=num_output_channels)
 ...

@@ -73,7 +71,7 @@ def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:
 @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
-def adjust_brightness(inpt: datapoints._InputTypeJIT, brightness_factor: float) -> datapoints._InputTypeJIT:
+def adjust_brightness(inpt: torch.Tensor, brightness_factor: float) -> torch.Tensor:
     if torch.jit.is_scripting():
         return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor)
 ...

@@ -110,7 +108,7 @@ def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> torch.Tensor:
 @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
-def adjust_saturation(inpt: datapoints._InputTypeJIT, saturation_factor: float) -> datapoints._InputTypeJIT:
+def adjust_saturation(inpt: torch.Tensor, saturation_factor: float) -> torch.Tensor:
     if torch.jit.is_scripting():
         return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor)
 ...

@@ -149,7 +147,7 @@ def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> torch.Tensor:
 @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
-def adjust_contrast(inpt: datapoints._InputTypeJIT, contrast_factor: float) -> datapoints._InputTypeJIT:
+def adjust_contrast(inpt: torch.Tensor, contrast_factor: float) -> torch.Tensor:
     if torch.jit.is_scripting():
         return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor)
 ...

@@ -188,7 +186,7 @@ def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch.Tensor:
 @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
-def adjust_sharpness(inpt: datapoints._InputTypeJIT, sharpness_factor: float) -> datapoints._InputTypeJIT:
+def adjust_sharpness(inpt: torch.Tensor, sharpness_factor: float) -> torch.Tensor:
     if torch.jit.is_scripting():
         return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor)
 ...

@@ -261,7 +259,7 @@ def adjust_sharpness_video(video: torch.Tensor, sharpness_factor: float) -> torch.Tensor:
 @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
-def adjust_hue(inpt: datapoints._InputTypeJIT, hue_factor: float) -> datapoints._InputTypeJIT:
+def adjust_hue(inpt: torch.Tensor, hue_factor: float) -> torch.Tensor:
     if torch.jit.is_scripting():
         return adjust_hue_image_tensor(inpt, hue_factor=hue_factor)
 ...

@@ -373,7 +371,7 @@ def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor:
 @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
-def adjust_gamma(inpt: datapoints._InputTypeJIT, gamma: float, gain: float = 1) -> datapoints._InputTypeJIT:
+def adjust_gamma(inpt: torch.Tensor, gamma: float, gain: float = 1) -> torch.Tensor:
     if torch.jit.is_scripting():
         return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain)
 ...

@@ -413,7 +411,7 @@ def adjust_gamma_video(video: torch.Tensor, gamma: float, gain: float = 1) -> torch.Tensor:
 @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
-def posterize(inpt: datapoints._InputTypeJIT, bits: int) -> datapoints._InputTypeJIT:
+def posterize(inpt: torch.Tensor, bits: int) -> torch.Tensor:
     if torch.jit.is_scripting():
         return posterize_image_tensor(inpt, bits=bits)
 ...

@@ -447,7 +445,7 @@ def posterize_video(video: torch.Tensor, bits: int) -> torch.Tensor:
 @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
-def solarize(inpt: datapoints._InputTypeJIT, threshold: float) -> datapoints._InputTypeJIT:
+def solarize(inpt: torch.Tensor, threshold: float) -> torch.Tensor:
     if torch.jit.is_scripting():
         return solarize_image_tensor(inpt, threshold=threshold)
 ...

@@ -475,7 +473,7 @@ def solarize_video(video: torch.Tensor, threshold: float) -> torch.Tensor:
 @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
-def autocontrast(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
+def autocontrast(inpt: torch.Tensor) -> torch.Tensor:
     if torch.jit.is_scripting():
         return autocontrast_image_tensor(inpt)
 ...

@@ -525,7 +523,7 @@ def autocontrast_video(video: torch.Tensor) -> torch.Tensor:
 @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
-def equalize(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
+def equalize(inpt: torch.Tensor) -> torch.Tensor:
     if torch.jit.is_scripting():
         return equalize_image_tensor(inpt)
 ...

@@ -615,7 +613,7 @@ def equalize_video(video: torch.Tensor) -> torch.Tensor:
 @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
-def invert(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
+def invert(inpt: torch.Tensor) -> torch.Tensor:
     if torch.jit.is_scripting():
         return invert_image_tensor(inpt)
 ...

@@ -646,7 +644,7 @@ def invert_video(video: torch.Tensor) -> torch.Tensor:
 @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
-def permute_channels(inpt: datapoints._InputTypeJIT, permutation: List[int]) -> datapoints._InputTypeJIT:
+def permute_channels(inpt: torch.Tensor, permutation: List[int]) -> torch.Tensor:
     """Permute the channels of the input according to the given permutation.

     This function supports plain :class:`~torch.Tensor`'s, :class:`PIL.Image.Image`'s, and
 ...
```
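Every dispatcher in this file now pairs a plain `torch.Tensor` annotation with an early `torch.jit.is_scripting()` branch that short-circuits to the pure-tensor kernel; the old private unions over tensor subclasses were never TorchScript-friendly. A sketch, assuming these dispatchers remain scriptable as the branches suggest:

```python
import torch
from torchvision.transforms.v2 import functional as F

@torch.jit.script
def darken(img: torch.Tensor) -> torch.Tensor:
    # Under scripting, adjust_brightness reduces to
    # adjust_brightness_image_tensor via the is_scripting() branch.
    return F.adjust_brightness(img, brightness_factor=0.5)

out = darken(torch.rand(3, 8, 8))
```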
torchvision/transforms/v2/functional/_deprecated.py

```diff
 import warnings
-from typing import Any, List, Union
+from typing import Any, List

 import torch

-from torchvision import datapoints
 from torchvision.transforms import functional as _F
 ...

@@ -16,7 +15,7 @@ def to_tensor(inpt: Any) -> torch.Tensor:
     return _F.to_tensor(inpt)

-def get_image_size(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]) -> List[int]:
+def get_image_size(inpt: torch.Tensor) -> List[int]:
     warnings.warn(
         "The function `get_image_size(...)` is deprecated and will be removed in a future release. "
         "Instead, please use `get_size(...)` which returns `[h, w]` instead of `[w, h]`."
 ...
```
torchvision/transforms/v2/functional/_geometry.py

```diff
@@ -25,7 +25,13 @@ from torchvision.utils import _log_api_usage_once
 from ._meta import clamp_bounding_boxes, convert_format_bounding_boxes, get_size_image_pil
-from ._utils import _get_kernel, _register_explicit_noop, _register_five_ten_crop_kernel, _register_kernel_internal
+from ._utils import (
+    _FillTypeJIT,
+    _get_kernel,
+    _register_explicit_noop,
+    _register_five_ten_crop_kernel,
+    _register_kernel_internal,
+)


 def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode:
 ...

@@ -39,7 +45,7 @@ def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode:
     return interpolation

-def horizontal_flip(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
+def horizontal_flip(inpt: torch.Tensor) -> torch.Tensor:
     if torch.jit.is_scripting():
         return horizontal_flip_image_tensor(inpt)
 ...

@@ -95,7 +101,7 @@ def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor:
     return horizontal_flip_image_tensor(video)

-def vertical_flip(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
+def vertical_flip(inpt: torch.Tensor) -> torch.Tensor:
     if torch.jit.is_scripting():
         return vertical_flip_image_tensor(inpt)
 ...

@@ -171,12 +177,12 @@ def _compute_resized_output_size(
 def resize(
-    inpt: datapoints._InputTypeJIT,
+    inpt: torch.Tensor,
     size: List[int],
     interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
     max_size: Optional[int] = None,
     antialias: Optional[Union[str, bool]] = "warn",
-) -> datapoints._InputTypeJIT:
+) -> torch.Tensor:
     if torch.jit.is_scripting():
         return resize_image_tensor(inpt, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias)
 ...

@@ -364,15 +370,15 @@ def resize_video(
 def affine(
-    inpt: datapoints._InputTypeJIT,
+    inpt: torch.Tensor,
     angle: Union[int, float],
     translate: List[float],
     scale: float,
     shear: List[float],
     interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
     center: Optional[List[float]] = None,
-) -> datapoints._InputTypeJIT:
+) -> torch.Tensor:
     if torch.jit.is_scripting():
         return affine_image_tensor(
             inpt,
 ...

@@ -549,9 +555,7 @@ def _compute_affine_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]:
     return int(size[0]), int(size[1])  # w, h

-def _apply_grid_transform(
-    img: torch.Tensor, grid: torch.Tensor, mode: str, fill: datapoints._FillTypeJIT
-) -> torch.Tensor:
+def _apply_grid_transform(img: torch.Tensor, grid: torch.Tensor, mode: str, fill: _FillTypeJIT) -> torch.Tensor:
     # We are using context knowledge that grid should have float dtype
     fp = img.dtype == grid.dtype
 ...

@@ -592,7 +596,7 @@ def _assert_grid_transform_inputs(
     image: torch.Tensor,
     matrix: Optional[List[float]],
     interpolation: str,
-    fill: datapoints._FillTypeJIT,
+    fill: _FillTypeJIT,
     supported_interpolation_modes: List[str],
     coeffs: Optional[List[float]] = None,
 ) -> None:
 ...

@@ -657,7 +661,7 @@ def affine_image_tensor(
     scale: float,
     shear: List[float],
     interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
     center: Optional[List[float]] = None,
 ) -> torch.Tensor:
     interpolation = _check_interpolation(interpolation)
 ...

@@ -709,7 +713,7 @@ def affine_image_pil(
     scale: float,
     shear: List[float],
     interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
     center: Optional[List[float]] = None,
 ) -> PIL.Image.Image:
     interpolation = _check_interpolation(interpolation)
 ...

@@ -868,7 +872,7 @@ def affine_mask(
     translate: List[float],
     scale: float,
     shear: List[float],
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
     center: Optional[List[float]] = None,
 ) -> torch.Tensor:
     if mask.ndim < 3:
 ...

@@ -901,7 +905,7 @@ def _affine_mask_dispatch(
     translate: List[float],
     scale: float,
     shear: List[float],
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
     center: Optional[List[float]] = None,
     **kwargs,
 ) -> datapoints.Mask:
 ...

@@ -925,7 +929,7 @@ def affine_video(
     scale: float,
     shear: List[float],
     interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
     center: Optional[List[float]] = None,
 ) -> torch.Tensor:
     return affine_image_tensor(
 ...

@@ -941,13 +945,13 @@ def affine_video(
 def rotate(
-    inpt: datapoints._InputTypeJIT,
+    inpt: torch.Tensor,
     angle: float,
     interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
     expand: bool = False,
     center: Optional[List[float]] = None,
-    fill: datapoints._FillTypeJIT = None,
-) -> datapoints._InputTypeJIT:
+    fill: _FillTypeJIT = None,
+) -> torch.Tensor:
     if torch.jit.is_scripting():
         return rotate_image_tensor(
             inpt, angle=angle, interpolation=interpolation, expand=expand, fill=fill, center=center
 ...

@@ -967,7 +971,7 @@ def rotate_image_tensor(
     interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
     expand: bool = False,
     center: Optional[List[float]] = None,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
 ) -> torch.Tensor:
     interpolation = _check_interpolation(interpolation)
 ...

@@ -1012,7 +1016,7 @@ def rotate_image_pil(
     interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
     expand: bool = False,
     center: Optional[List[float]] = None,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
 ) -> PIL.Image.Image:
     interpolation = _check_interpolation(interpolation)
 ...

@@ -1068,7 +1072,7 @@ def rotate_mask(
     angle: float,
     expand: bool = False,
     center: Optional[List[float]] = None,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
 ) -> torch.Tensor:
     if mask.ndim < 3:
         mask = mask.unsqueeze(0)
 ...

@@ -1097,7 +1101,7 @@ def _rotate_mask_dispatch(
     angle: float,
     expand: bool = False,
     center: Optional[List[float]] = None,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
     **kwargs,
 ) -> datapoints.Mask:
     output = rotate_mask(inpt.as_subclass(torch.Tensor), angle=angle, expand=expand, fill=fill, center=center)
 ...

@@ -1111,17 +1115,17 @@ def rotate_video(
     interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
     expand: bool = False,
     center: Optional[List[float]] = None,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
 ) -> torch.Tensor:
     return rotate_image_tensor(video, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)


 def pad(
-    inpt: datapoints._InputTypeJIT,
+    inpt: torch.Tensor,
     padding: List[int],
     fill: Optional[Union[int, float, List[float]]] = None,
     padding_mode: str = "constant",
-) -> datapoints._InputTypeJIT:
+) -> torch.Tensor:
     if torch.jit.is_scripting():
         return pad_image_tensor(inpt, padding=padding, fill=fill, padding_mode=padding_mode)
 ...

@@ -1336,7 +1340,7 @@ def pad_video(
     return pad_image_tensor(video, padding, fill=fill, padding_mode=padding_mode)

-def crop(inpt: datapoints._InputTypeJIT, top: int, left: int, height: int, width: int) -> datapoints._InputTypeJIT:
+def crop(inpt: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
     if torch.jit.is_scripting():
         return crop_image_tensor(inpt, top=top, left=left, height=height, width=width)
 ...

@@ -1423,13 +1427,13 @@ def crop_video(video: torch.Tensor, top: int, left: int, height: int, width: int
 def perspective(
-    inpt: datapoints._InputTypeJIT,
+    inpt: torch.Tensor,
     startpoints: Optional[List[List[int]]],
     endpoints: Optional[List[List[int]]],
     interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
     coefficients: Optional[List[float]] = None,
-) -> datapoints._InputTypeJIT:
+) -> torch.Tensor:
     if torch.jit.is_scripting():
         return perspective_image_tensor(
             inpt,
 ...

@@ -1507,7 +1511,7 @@ def perspective_image_tensor(
     startpoints: Optional[List[List[int]]],
     endpoints: Optional[List[List[int]]],
     interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
     coefficients: Optional[List[float]] = None,
 ) -> torch.Tensor:
     perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients)
 ...

@@ -1554,7 +1558,7 @@ def perspective_image_pil(
     startpoints: Optional[List[List[int]]],
     endpoints: Optional[List[List[int]]],
     interpolation: Union[InterpolationMode, int] = InterpolationMode.BICUBIC,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
     coefficients: Optional[List[float]] = None,
 ) -> PIL.Image.Image:
     perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients)
 ...

@@ -1679,7 +1683,7 @@ def perspective_mask(
     mask: torch.Tensor,
     startpoints: Optional[List[List[int]]],
     endpoints: Optional[List[List[int]]],
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
     coefficients: Optional[List[float]] = None,
 ) -> torch.Tensor:
     if mask.ndim < 3:
 ...

@@ -1703,7 +1707,7 @@ def _perspective_mask_dispatch(
     inpt: datapoints.Mask,
     startpoints: Optional[List[List[int]]],
     endpoints: Optional[List[List[int]]],
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
     coefficients: Optional[List[float]] = None,
     **kwargs,
 ) -> datapoints.Mask:
 ...

@@ -1723,7 +1727,7 @@ def perspective_video(
     startpoints: Optional[List[List[int]]],
     endpoints: Optional[List[List[int]]],
     interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
     coefficients: Optional[List[float]] = None,
 ) -> torch.Tensor:
     return perspective_image_tensor(
 ...

@@ -1732,11 +1736,11 @@ def perspective_video(
 def elastic(
-    inpt: datapoints._InputTypeJIT,
+    inpt: torch.Tensor,
     displacement: torch.Tensor,
     interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-    fill: datapoints._FillTypeJIT = None,
-) -> datapoints._InputTypeJIT:
+    fill: _FillTypeJIT = None,
+) -> torch.Tensor:
     if torch.jit.is_scripting():
         return elastic_image_tensor(inpt, displacement=displacement, interpolation=interpolation, fill=fill)
 ...

@@ -1755,7 +1759,7 @@ def elastic_image_tensor(
     image: torch.Tensor,
     displacement: torch.Tensor,
     interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
 ) -> torch.Tensor:
     interpolation = _check_interpolation(interpolation)
 ...

@@ -1812,7 +1816,7 @@ def elastic_image_pil(
     image: PIL.Image.Image,
     displacement: torch.Tensor,
     interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
 ) -> PIL.Image.Image:
     t_img = pil_to_tensor(image)
     output = elastic_image_tensor(t_img, displacement, interpolation=interpolation, fill=fill)
 ...

@@ -1895,7 +1899,7 @@ def _elastic_bounding_boxes_dispatch(
 def elastic_mask(
     mask: torch.Tensor,
     displacement: torch.Tensor,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
 ) -> torch.Tensor:
     if mask.ndim < 3:
         mask = mask.unsqueeze(0)
 ...

@@ -1913,7 +1917,7 @@ def elastic_mask(
 @_register_kernel_internal(elastic, datapoints.Mask, datapoint_wrapper=False)
 def _elastic_mask_dispatch(
-    inpt: datapoints.Mask, displacement: torch.Tensor, fill: datapoints._FillTypeJIT = None, **kwargs
+    inpt: datapoints.Mask, displacement: torch.Tensor, fill: _FillTypeJIT = None, **kwargs
 ) -> datapoints.Mask:
     output = elastic_mask(inpt.as_subclass(torch.Tensor), displacement=displacement, fill=fill)
     return datapoints.Mask.wrap_like(inpt, output)
 ...

@@ -1924,12 +1928,12 @@ def elastic_video(
     video: torch.Tensor,
     displacement: torch.Tensor,
     interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
 ) -> torch.Tensor:
     return elastic_image_tensor(video, displacement, interpolation=interpolation, fill=fill)


-def center_crop(inpt: datapoints._InputTypeJIT, output_size: List[int]) -> datapoints._InputTypeJIT:
+def center_crop(inpt: torch.Tensor, output_size: List[int]) -> torch.Tensor:
     if torch.jit.is_scripting():
         return center_crop_image_tensor(inpt, output_size=output_size)
 ...

@@ -2049,7 +2053,7 @@ def center_crop_video(video: torch.Tensor, output_size: List[int]) -> torch.Tensor:
 def resized_crop(
-    inpt: datapoints._InputTypeJIT,
+    inpt: torch.Tensor,
     top: int,
     left: int,
     height: int,
 ...

@@ -2057,7 +2061,7 @@ def resized_crop(
     size: List[int],
     interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
     antialias: Optional[Union[str, bool]] = "warn",
-) -> datapoints._InputTypeJIT:
+) -> torch.Tensor:
     if torch.jit.is_scripting():
         return resized_crop_image_tensor(
             inpt,
 ...

@@ -2201,14 +2205,8 @@ def resized_crop_video(
 @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask, warn_passthrough=True)
 def five_crop(
-    inpt: datapoints._InputTypeJIT, size: List[int]
-) -> Tuple[
-    datapoints._InputTypeJIT,
-    datapoints._InputTypeJIT,
-    datapoints._InputTypeJIT,
-    datapoints._InputTypeJIT,
-    datapoints._InputTypeJIT,
-]:
+    inpt: torch.Tensor, size: List[int]
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     if torch.jit.is_scripting():
         return five_crop_image_tensor(inpt, size=size)
 ...

@@ -2280,18 +2278,18 @@ def five_crop_video(
 @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask, warn_passthrough=True)
 def ten_crop(
-    inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT], size: List[int], vertical_flip: bool = False
+    inpt: torch.Tensor, size: List[int], vertical_flip: bool = False
 ) -> Tuple[
-    datapoints._InputTypeJIT,
-    datapoints._InputTypeJIT,
-    datapoints._InputTypeJIT,
-    datapoints._InputTypeJIT,
-    datapoints._InputTypeJIT,
-    datapoints._InputTypeJIT,
-    datapoints._InputTypeJIT,
-    datapoints._InputTypeJIT,
-    datapoints._InputTypeJIT,
-    datapoints._InputTypeJIT,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
 ]:
     if torch.jit.is_scripting():
         return ten_crop_image_tensor(inpt, size=size, vertical_flip=vertical_flip)
 ...
```
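With `fill` now typed as the JIT-friendly `_FillTypeJIT` (`Optional[List[float]]`) throughout, the canonical way to pass a fill to these kernels is a per-channel float list. A small usage sketch against the `rotate` dispatcher above:

```python
import torch
from torchvision.transforms.v2 import functional as F

img = torch.randint(0, 256, (3, 64, 64), dtype=torch.uint8)

# fill is _FillTypeJIT = Optional[List[float]]: one float per channel,
# or None for the default behavior.
rotated = F.rotate(img, angle=30.0, expand=True, fill=[127.0, 127.0, 127.0])
print(rotated.shape)
```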
torchvision/transforms/v2/functional/_meta.py

```diff
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple

 import PIL.Image
 import torch
 ...

@@ -12,7 +12,7 @@ from ._utils import _get_kernel, _register_kernel_internal, _register_unsupported_type
 @_register_unsupported_type(datapoints.BoundingBoxes, datapoints.Mask)
-def get_dimensions(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]) -> List[int]:
+def get_dimensions(inpt: torch.Tensor) -> List[int]:
     if torch.jit.is_scripting():
         return get_dimensions_image_tensor(inpt)
 ...

@@ -45,7 +45,7 @@ def get_dimensions_video(video: torch.Tensor) -> List[int]:
 @_register_unsupported_type(datapoints.BoundingBoxes, datapoints.Mask)
-def get_num_channels(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]) -> int:
+def get_num_channels(inpt: torch.Tensor) -> int:
     if torch.jit.is_scripting():
         return get_num_channels_image_tensor(inpt)
 ...

@@ -81,7 +81,7 @@ def get_num_channels_video(video: torch.Tensor) -> int:
 get_image_num_channels = get_num_channels

-def get_size(inpt: datapoints._InputTypeJIT) -> List[int]:
+def get_size(inpt: torch.Tensor) -> List[int]:
     if torch.jit.is_scripting():
         return get_size_image_tensor(inpt)
 ...

@@ -124,7 +124,7 @@ def get_size_bounding_boxes(bounding_box: datapoints.BoundingBoxes) -> List[int]:
 @_register_unsupported_type(PIL.Image.Image, datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask)
-def get_num_frames(inpt: datapoints._VideoTypeJIT) -> int:
+def get_num_frames(inpt: torch.Tensor) -> int:
     if torch.jit.is_scripting():
         return get_num_frames_video(inpt)
 ...

@@ -201,11 +201,11 @@ def _convert_format_bounding_boxes(
 def convert_format_bounding_boxes(
-    inpt: datapoints._InputTypeJIT,
+    inpt: torch.Tensor,
     old_format: Optional[BoundingBoxFormat] = None,
     new_format: Optional[BoundingBoxFormat] = None,
     inplace: bool = False,
-) -> datapoints._InputTypeJIT:
+) -> torch.Tensor:
     # This being a kernel / dispatcher hybrid, we need an option to pass `old_format` explicitly for simple tensor
     # inputs as well as extract it from `datapoints.BoundingBoxes` inputs. However, putting a default value on
     # `old_format` means we also need to put one on `new_format` to have syntactically correct Python. Here we mimic the
 ...

@@ -252,10 +252,10 @@ def _clamp_bounding_boxes(
 def clamp_bounding_boxes(
-    inpt: datapoints._InputTypeJIT,
+    inpt: torch.Tensor,
     format: Optional[BoundingBoxFormat] = None,
     canvas_size: Optional[Tuple[int, int]] = None,
-) -> datapoints._InputTypeJIT:
+) -> torch.Tensor:
     if not torch.jit.is_scripting():
         _log_api_usage_once(clamp_bounding_boxes)
 ...
```
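The meta queries are the simplest illustration of the new plain-tensor signatures. Note that `get_size` returns `[h, w]`, unlike the deprecated `get_image_size` earlier in this diff, which returns `[w, h]`. A quick check:

```python
import torch
from torchvision.transforms.v2 import functional as F

frame = torch.rand(3, 480, 640)  # C, H, W

print(F.get_size(frame))          # [480, 640], i.e. [h, w]
print(F.get_dimensions(frame))    # [3, 480, 640]
print(F.get_num_channels(frame))  # 3
```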
torchvision/transforms/v2/functional/_misc.py

```diff
 import math
-from typing import List, Optional, Union
+from typing import List, Optional

 import PIL.Image
 import torch
 ...

@@ -17,7 +17,7 @@ from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal
 @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
 @_register_unsupported_type(PIL.Image.Image)
 def normalize(
-    inpt: Union[datapoints._TensorImageTypeJIT, datapoints._TensorVideoTypeJIT],
+    inpt: torch.Tensor,
     mean: List[float],
     std: List[float],
     inplace: bool = False,
 ...

@@ -74,9 +74,7 @@ def normalize_video(video: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False) -> torch.Tensor:
 @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
-def gaussian_blur(
-    inpt: datapoints._InputTypeJIT, kernel_size: List[int], sigma: Optional[List[float]] = None
-) -> datapoints._InputTypeJIT:
+def gaussian_blur(inpt: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None) -> torch.Tensor:
     if torch.jit.is_scripting():
         return gaussian_blur_image_tensor(inpt, kernel_size=kernel_size, sigma=sigma)
 ...

@@ -185,9 +183,7 @@ def gaussian_blur_video(
 @_register_unsupported_type(PIL.Image.Image)
-def to_dtype(
-    inpt: datapoints._InputTypeJIT, dtype: torch.dtype = torch.float, scale: bool = False
-) -> datapoints._InputTypeJIT:
+def to_dtype(inpt: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
     if torch.jit.is_scripting():
         return to_dtype_image_tensor(inpt, dtype=dtype, scale=scale)
 ...

@@ -278,8 +274,6 @@ def to_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
 @_register_kernel_internal(to_dtype, datapoints.BoundingBoxes, datapoint_wrapper=False)
 @_register_kernel_internal(to_dtype, datapoints.Mask, datapoint_wrapper=False)
-def _to_dtype_tensor_dispatch(
-    inpt: datapoints._InputTypeJIT, dtype: torch.dtype, scale: bool = False
-) -> datapoints._InputTypeJIT:
+def _to_dtype_tensor_dispatch(inpt: torch.Tensor, dtype: torch.dtype, scale: bool = False) -> torch.Tensor:
     # We don't need to unwrap and rewrap here, since Datapoint.to() preserves the type
     return inpt.to(dtype)
```
torchvision/transforms/v2/functional/_temporal.py

```diff
@@ -11,7 +11,7 @@ from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal
 @_register_explicit_noop(
     PIL.Image.Image, datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask, warn_passthrough=True
 )
-def uniform_temporal_subsample(inpt: datapoints._VideoTypeJIT, num_samples: int) -> datapoints._VideoTypeJIT:
+def uniform_temporal_subsample(inpt: torch.Tensor, num_samples: int) -> torch.Tensor:
     if torch.jit.is_scripting():
         return uniform_temporal_subsample_video(inpt, num_samples=num_samples)
 ...
```
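And the one-line change to the temporal dispatcher in use. Assuming the usual `[..., T, C, H, W]` video layout, `uniform_temporal_subsample` picks `num_samples` evenly spaced frames:

```python
import torch
from torchvision.transforms.v2 import functional as F

video = torch.rand(16, 3, 112, 112)  # T, C, H, W

clip = F.uniform_temporal_subsample(video, num_samples=4)
print(clip.shape)  # torch.Size([4, 3, 112, 112])
```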