Commit 641fdd9f, authored Aug 09, 2023 by Philip Meier, committed by GitHub on Aug 09, 2023
remove custom type definitions from datapoints module (#7814)
parent 6b020798

Showing 20 changed files with 137 additions and 177 deletions (+137 -177)
torchvision/datapoints/__init__.py                     +3   -3
torchvision/datapoints/_datapoint.py                   +1   -8
torchvision/datapoints/_image.py                       +0   -6
torchvision/datapoints/_video.py                       +0   -6
torchvision/prototype/transforms/_augment.py           +5   -5
torchvision/prototype/transforms/_geometry.py          +2   -2
torchvision/prototype/transforms/_misc.py              +2   -6
torchvision/transforms/v2/_auto_augment.py             +14  -10
torchvision/transforms/v2/_color.py                    +1   -3
torchvision/transforms/v2/_geometry.py                 +8   -10
torchvision/transforms/v2/_misc.py                     +1   -3
torchvision/transforms/v2/_temporal.py                 +1   -2
torchvision/transforms/v2/_utils.py                    +2   -3
torchvision/transforms/v2/functional/_augment.py       +2   -4
torchvision/transforms/v2/functional/_color.py         +14  -16
torchvision/transforms/v2/functional/_deprecated.py    +2   -3
torchvision/transforms/v2/functional/_geometry.py      +64  -66
torchvision/transforms/v2/functional/_meta.py          +9   -9
torchvision/transforms/v2/functional/_misc.py          +5   -11
torchvision/transforms/v2/functional/_temporal.py      +1   -1
torchvision/datapoints/__init__.py (+3 -3)

```diff
 from torchvision import _BETA_TRANSFORMS_WARNING, _WARN_ABOUT_BETA_TRANSFORMS

 from ._bounding_box import BoundingBoxes, BoundingBoxFormat
-from ._datapoint import _FillType, _FillTypeJIT, _InputType, _InputTypeJIT, Datapoint
-from ._image import _ImageType, _ImageTypeJIT, _TensorImageType, _TensorImageTypeJIT, Image
+from ._datapoint import Datapoint
+from ._image import Image
 from ._mask import Mask
-from ._video import _TensorVideoType, _TensorVideoTypeJIT, _VideoType, _VideoTypeJIT, Video
+from ._video import Video

 if _WARN_ABOUT_BETA_TRANSFORMS:
     import warnings
 ...
```
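For downstream code that imported the removed aliases, a minimal migration sketch follows. The right-hand sides are copied from the alias definitions deleted in the diffs below; the un-prefixed names on the left are hypothetical user-side names, not torchvision API:

```python
# Hypothetical user module. Before this commit, private aliases leaked
# from the public namespace:
#
#     from torchvision.datapoints import _FillType, _ImageType  # no longer works
#
# Afterwards, the unions are spelled out with public names only.
from typing import List, Optional, Sequence, Union

import PIL.Image
import torch
from torchvision import datapoints

FillType = Union[int, float, Sequence[int], Sequence[float], None]  # was datapoints._FillType
FillTypeJIT = Optional[List[float]]                                 # was datapoints._FillTypeJIT
ImageType = Union[torch.Tensor, PIL.Image.Image, datapoints.Image]  # was datapoints._ImageType
```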
torchvision/datapoints/_datapoint.py (+1 -8)

```diff
 from __future__ import annotations

-from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union
+from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union

 import PIL.Image
 import torch
 from torch._C import DisableTorchFunctionSubclass
 from torch.types import _device, _dtype, _size

 D = TypeVar("D", bound="Datapoint")

-_FillType = Union[int, float, Sequence[int], Sequence[float], None]
-_FillTypeJIT = Optional[List[float]]


 class Datapoint(torch.Tensor):
 ...

@@ -132,7 +129,3 @@ class Datapoint(torch.Tensor):
         # `BoundingBoxes.format` and `BoundingBoxes.canvas_size`, which are immutable and thus implicitly deep-copied by
         # `BoundingBoxes.clone()`.
         return self.detach().clone().requires_grad_(self.requires_grad)  # type: ignore[return-value]
-
-
-_InputType = Union[torch.Tensor, PIL.Image.Image, Datapoint]
-_InputTypeJIT = torch.Tensor
```
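The removed `*JIT` variants all collapsed to plain `torch.Tensor` because a scripted function only ever receives tensors; PIL images and datapoint subclasses are not valid annotations under TorchScript. A minimal sketch of that constraint (illustrative function, not torchvision code):

```python
import torch

# Annotating with `torch.Tensor` (the old `_InputTypeJIT`) keeps the
# function scriptable; a union with PIL.Image.Image would not be.
def scale_values(inpt: torch.Tensor, factor: float) -> torch.Tensor:
    return inpt * factor

scripted = torch.jit.script(scale_values)
print(scripted(torch.ones(2), 2.0))  # tensor([2., 2.])
```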
torchvision/datapoints/_image.py (+0 -6)

```diff
@@ -45,9 +45,3 @@ class Image(Datapoint):
     def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
         return self._make_repr()
-
-
-_ImageType = Union[torch.Tensor, PIL.Image.Image, Image]
-_ImageTypeJIT = torch.Tensor
-_TensorImageType = Union[torch.Tensor, Image]
-_TensorImageTypeJIT = torch.Tensor
```
torchvision/datapoints/_video.py (+0 -6)

```diff
@@ -35,9 +35,3 @@ class Video(Datapoint):
     def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
         return self._make_repr()
-
-
-_VideoType = Union[torch.Tensor, Video]
-_VideoTypeJIT = torch.Tensor
-_TensorVideoType = Union[torch.Tensor, Video]
-_TensorVideoTypeJIT = torch.Tensor
```
torchvision/prototype/transforms/_augment.py (+5 -5)

```diff
@@ -26,15 +26,15 @@ class SimpleCopyPaste(Transform):
     def _copy_paste(
         self,
-        image: datapoints._TensorImageType,
+        image: Union[torch.Tensor, datapoints.Image],
         target: Dict[str, Any],
-        paste_image: datapoints._TensorImageType,
+        paste_image: Union[torch.Tensor, datapoints.Image],
         paste_target: Dict[str, Any],
         random_selection: torch.Tensor,
         blending: bool,
         resize_interpolation: F.InterpolationMode,
         antialias: Optional[bool],
-    ) -> Tuple[datapoints._TensorImageType, Dict[str, Any]]:
+    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
         paste_masks = paste_target["masks"].wrap_like(paste_target["masks"], paste_target["masks"][random_selection])
         paste_boxes = paste_target["boxes"].wrap_like(paste_target["boxes"], paste_target["boxes"][random_selection])
 ...

@@ -106,7 +106,7 @@ class SimpleCopyPaste(Transform):
     def _extract_image_targets(
         self, flat_sample: List[Any]
-    ) -> Tuple[List[datapoints._TensorImageType], List[Dict[str, Any]]]:
+    ) -> Tuple[List[Union[torch.Tensor, datapoints.Image]], List[Dict[str, Any]]]:
         # fetch all images, bboxes, masks and labels from unstructured input
         # with List[image], List[BoundingBoxes], List[Mask], List[Label]
         images, bboxes, masks, labels = [], [], [], []
 ...

@@ -137,7 +137,7 @@ class SimpleCopyPaste(Transform):
     def _insert_outputs(
         self,
         flat_sample: List[Any],
-        output_images: List[datapoints._TensorImageType],
+        output_images: List[torch.Tensor],
         output_targets: List[Dict[str, Any]],
     ) -> None:
         c0, c1, c2, c3 = 0, 0, 0, 0
 ...
```
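The inlined `Union[torch.Tensor, datapoints.Image]` accepts exactly what the old `_TensorImageType` alias covered, since `datapoints.Image` subclasses `torch.Tensor`. A quick check (illustrative, not part of the diff):

```python
import torch
from torchvision import datapoints

img = datapoints.Image(torch.rand(3, 4, 4))
print(isinstance(img, torch.Tensor))                  # True: subclass relationship
print(type(img.as_subclass(torch.Tensor)).__name__)   # Tensor
```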
torchvision/prototype/transforms/_geometry.py (+2 -2)

```diff
@@ -6,7 +6,7 @@ import torch
 from torchvision import datapoints
 from torchvision.prototype.datapoints import Label, OneHotLabel
 from torchvision.transforms.v2 import functional as F, Transform
-from torchvision.transforms.v2._utils import _get_fill, _setup_fill_arg, _setup_size
+from torchvision.transforms.v2._utils import _FillType, _get_fill, _setup_fill_arg, _setup_size
 from torchvision.transforms.v2.utils import get_bounding_boxes, has_any, is_simple_tensor, query_size
 ...

@@ -14,7 +14,7 @@ class FixedSizeCrop(Transform):
     def __init__(
         self,
         size: Union[int, Sequence[int]],
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
         padding_mode: str = "constant",
     ) -> None:
         super().__init__()
 ...
```
torchvision/prototype/transforms/_misc.py (+2 -6)

```diff
@@ -39,9 +39,7 @@ class PermuteDimensions(Transform):
         )
         self.dims = dims

-    def _transform(
-        self, inpt: Union[datapoints._TensorImageType, datapoints._TensorVideoType], params: Dict[str, Any]
-    ) -> torch.Tensor:
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor:
         dims = self.dims[type(inpt)]
         if dims is None:
             return inpt.as_subclass(torch.Tensor)
 ...

@@ -63,9 +61,7 @@ class TransposeDimensions(Transform):
         )
         self.dims = dims

-    def _transform(
-        self, inpt: Union[datapoints._TensorImageType, datapoints._TensorVideoType], params: Dict[str, Any]
-    ) -> torch.Tensor:
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor:
         dims = self.dims[type(inpt)]
         if dims is None:
             return inpt.as_subclass(torch.Tensor)
 ...
```
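Both `_transform` overrides select behavior from the runtime type of `inpt` via `self.dims[type(inpt)]`, which is why the annotation could be loosened to `Any` without changing behavior. A hedged stand-alone sketch of that type-keyed lookup (the `dims` table and `permute_like` helper are hypothetical, not torchvision code):

```python
from typing import Any, Dict, Optional, Tuple

import torch
from torchvision import datapoints

# Behavior keyed on the concrete input type, as in PermuteDimensions.
dims: Dict[type, Optional[Tuple[int, int, int]]] = {
    datapoints.Image: (1, 2, 0),  # CHW -> HWC
    datapoints.Video: None,       # None: just strip the subclass
}

def permute_like(inpt: Any) -> torch.Tensor:
    d = dims[type(inpt)]
    if d is None:
        return inpt.as_subclass(torch.Tensor)
    return inpt.permute(d).as_subclass(torch.Tensor)

print(permute_like(datapoints.Image(torch.rand(3, 4, 5))).shape)  # torch.Size([4, 5, 3])
```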
torchvision/transforms/v2/_auto_augment.py (+14 -10)

```diff
@@ -10,17 +10,21 @@ from torchvision.transforms import _functional_tensor as _FT
 from torchvision.transforms.v2 import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
 from torchvision.transforms.v2.functional._geometry import _check_interpolation
 from torchvision.transforms.v2.functional._meta import get_size
+from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT

 from ._utils import _get_fill, _setup_fill_arg
 from .utils import check_type, is_simple_tensor

+ImageOrVideo = Union[torch.Tensor, PIL.Image.Image, datapoints.Image, datapoints.Video]
+

 class _AutoAugmentBase(Transform):
     def __init__(
         self,
         *,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
     ) -> None:
         super().__init__()
         self.interpolation = _check_interpolation(interpolation)
 ...

@@ -35,7 +39,7 @@ class _AutoAugmentBase(Transform):
         self,
         inputs: Any,
         unsupported_types: Tuple[Type, ...] = (datapoints.BoundingBoxes, datapoints.Mask),
-    ) -> Tuple[
-        Tuple[List[Any], TreeSpec, int], Union[datapoints._ImageType, datapoints._VideoType]
-    ]:
+    ) -> Tuple[Tuple[List[Any], TreeSpec, int], ImageOrVideo]:
         flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])
         needs_transform_list = self._needs_transform_list(flat_inputs)
 ...

@@ -68,7 +72,7 @@ class _AutoAugmentBase(Transform):
     def _unflatten_and_insert_image_or_video(
         self,
         flat_inputs_with_spec: Tuple[List[Any], TreeSpec, int],
-        image_or_video: Union[datapoints._ImageType, datapoints._VideoType],
+        image_or_video: ImageOrVideo,
     ) -> Any:
         flat_inputs, spec, idx = flat_inputs_with_spec
         flat_inputs[idx] = image_or_video
 ...

@@ -76,12 +80,12 @@ class _AutoAugmentBase(Transform):
     def _apply_image_or_video_transform(
         self,
-        image: Union[datapoints._ImageType, datapoints._VideoType],
+        image: ImageOrVideo,
         transform_id: str,
         magnitude: float,
         interpolation: Union[InterpolationMode, int],
-        fill: Dict[Union[Type, str], datapoints._FillTypeJIT],
-    ) -> Union[datapoints._ImageType, datapoints._VideoType]:
+        fill: Dict[Union[Type, str], _FillTypeJIT],
+    ) -> ImageOrVideo:
         fill_ = _get_fill(fill, type(image))

         if transform_id == "Identity":
 ...

@@ -214,7 +218,7 @@ class AutoAugment(_AutoAugmentBase):
         self,
         policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
     ) -> None:
         super().__init__(interpolation=interpolation, fill=fill)
         self.policy = policy
 ...

@@ -394,7 +398,7 @@ class RandAugment(_AutoAugmentBase):
         magnitude: int = 9,
         num_magnitude_bins: int = 31,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
     ) -> None:
         super().__init__(interpolation=interpolation, fill=fill)
         self.num_ops = num_ops
 ...

@@ -467,7 +471,7 @@ class TrivialAugmentWide(_AutoAugmentBase):
         self,
         num_magnitude_bins: int = 31,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
     ):
         super().__init__(interpolation=interpolation, fill=fill)
         self.num_magnitude_bins = num_magnitude_bins
 ...

@@ -550,7 +554,7 @@ class AugMix(_AutoAugmentBase):
         alpha: float = 1.0,
         all_ops: bool = True,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
     ) -> None:
         super().__init__(interpolation=interpolation, fill=fill)
         self._PARAMETER_MAX = 10
 ...
```
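Call sites are untouched by the retyping: `fill` still accepts a plain number or a per-type dict, only the annotation's import location changed. A usage sketch (assumes torchvision >= 0.16 with the beta v2 API):

```python
import torch
from torchvision import datapoints
from torchvision.transforms import v2

# Same constructor arguments as before the commit.
aug = v2.AutoAugment(interpolation=v2.InterpolationMode.NEAREST, fill=128)
img = datapoints.Image(torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8))
print(aug(img).shape)  # torch.Size([3, 32, 32])
```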
torchvision/transforms/v2/_color.py (+1 -3)

```diff
@@ -261,9 +261,7 @@ class RandomPhotometricDistort(Transform):
         params["channel_permutation"] = torch.randperm(num_channels) if torch.rand(1) < self.p else None
         return params

-    def _transform(
-        self, inpt: Union[datapoints._ImageType, datapoints._VideoType], params: Dict[str, Any]
-    ) -> Union[datapoints._ImageType, datapoints._VideoType]:
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         if params["brightness_factor"] is not None:
             inpt = F.adjust_brightness(inpt, brightness_factor=params["brightness_factor"])
         if params["contrast_factor"] is not None and params["contrast_before"]:
 ...
```
torchvision/transforms/v2/_geometry.py (+8 -10)

```diff
@@ -11,6 +11,7 @@ from torchvision.ops.boxes import box_iou
 from torchvision.transforms.functional import _get_perspective_coeffs
 from torchvision.transforms.v2 import functional as F, InterpolationMode, Transform
 from torchvision.transforms.v2.functional._geometry import _check_interpolation
+from torchvision.transforms.v2.functional._utils import _FillType

 from ._transform import _RandomApplyTransform
 from ._utils import (
 ...

@@ -311,9 +312,6 @@ class RandomResizedCrop(Transform):
         )


-ImageOrVideoTypeJIT = Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]
-
-
 class FiveCrop(Transform):
     """[BETA] Crop the image or video into four corners and the central crop.
 ...

@@ -459,7 +457,7 @@ class Pad(Transform):
     def __init__(
         self,
         padding: Union[int, Sequence[int]],
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
         padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
     ) -> None:
         super().__init__()
 ...

@@ -514,7 +512,7 @@ class RandomZoomOut(_RandomApplyTransform):
     def __init__(
         self,
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
         side_range: Sequence[float] = (1.0, 4.0),
         p: float = 0.5,
     ) -> None:
 ...

@@ -592,7 +590,7 @@ class RandomRotation(Transform):
         interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
         expand: bool = False,
         center: Optional[List[float]] = None,
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
     ) -> None:
         super().__init__()
         self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))
 ...

@@ -674,7 +672,7 @@ class RandomAffine(Transform):
         scale: Optional[Sequence[float]] = None,
         shear: Optional[Union[int, float, Sequence[float]]] = None,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
         center: Optional[List[float]] = None,
     ) -> None:
         super().__init__()
 ...

@@ -812,7 +810,7 @@ class RandomCrop(Transform):
         size: Union[int, Sequence[int]],
         padding: Optional[Union[int, Sequence[int]]] = None,
         pad_if_needed: bool = False,
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
         padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
     ) -> None:
         super().__init__()
 ...

@@ -931,7 +929,7 @@ class RandomPerspective(_RandomApplyTransform):
         distortion_scale: float = 0.5,
         p: float = 0.5,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
     ) -> None:
         super().__init__(p=p)
 ...

@@ -1033,7 +1031,7 @@ class ElasticTransform(Transform):
         alpha: Union[float, Sequence[float]] = 50.0,
         sigma: Union[float, Sequence[float]] = 5.0,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
     ) -> None:
         super().__init__()
         self.alpha = _setup_float_or_seq(alpha, "alpha", 2)
 ...
```
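All of these `fill` parameters share one shape: either a single fill value for every input type, or a dict keyed by datapoint type. A usage sketch of the dict form (beta v2 API assumed; the specific fill values are arbitrary):

```python
import torch
from torchvision import datapoints
from torchvision.transforms import v2

# Pad images with mid-gray and masks with 0, selected per input type.
transform = v2.Pad(padding=4, fill={datapoints.Image: 127, datapoints.Mask: 0})
img = datapoints.Image(torch.zeros(3, 8, 8, dtype=torch.uint8))
print(transform(img).shape)  # torch.Size([3, 16, 16])
```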
torchvision/transforms/v2/_misc.py (+1 -3)

```diff
@@ -169,9 +169,7 @@ class Normalize(Transform):
         if has_any(sample, PIL.Image.Image):
             raise TypeError(f"{type(self).__name__}() does not support PIL images.")

-    def _transform(
-        self, inpt: Union[datapoints._TensorImageType, datapoints._TensorVideoType], params: Dict[str, Any]
-    ) -> Any:
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return F.normalize(inpt, mean=self.mean, std=self.std, inplace=self.inplace)
 ...
```
torchvision/transforms/v2/_temporal.py (+1 -2)

```diff
 from typing import Any, Dict

 import torch

-from torchvision import datapoints
 from torchvision.transforms.v2 import functional as F, Transform
 ...

@@ -25,5 +24,5 @@ class UniformTemporalSubsample(Transform):
         super().__init__()
         self.num_samples = num_samples

-    def _transform(self, inpt: datapoints._VideoType, params: Dict[str, Any]) -> datapoints._VideoType:
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return F.uniform_temporal_subsample(inpt, self.num_samples)
```
torchvision/transforms/v2/_utils.py (+2 -3)

```diff
@@ -5,9 +5,8 @@ from typing import Any, Callable, Dict, Literal, Optional, Sequence, Type, Union
 import torch

-from torchvision import datapoints
-from torchvision.datapoints._datapoint import _FillType, _FillTypeJIT
 from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size  # noqa: F401
+from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT


 def _setup_float_or_seq(arg: Union[float, Sequence[float]], name: str, req_size: int = 2) -> Sequence[float]:
 ...

@@ -36,7 +35,7 @@ def _check_fill_arg(fill: Union[_FillType, Dict[Union[Type, str], _FillType]]) -
         raise TypeError("Got inappropriate fill arg, only Numbers, tuples, lists and dicts are allowed.")

-def _convert_fill_arg(fill: datapoints._FillType) -> datapoints._FillTypeJIT:
+def _convert_fill_arg(fill: _FillType) -> _FillTypeJIT:
     # Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517
     # So, we can't reassign fill to 0
     # if fill is None:
 ...
```
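`_convert_fill_arg` bridges the two aliases: user-facing `_FillType` values are normalized into the `Optional[List[float]]` shape that TorchScript kernels accept. A hedged sketch of that normalization, with public names substituted for the private aliases; this mirrors the converted signature above, not torchvision's exact implementation:

```python
from typing import List, Optional, Sequence, Union

FillType = Union[int, float, Sequence[int], Sequence[float], None]
FillTypeJIT = Optional[List[float]]

def convert_fill_arg(fill: FillType) -> FillTypeJIT:
    if fill is None:
        return None  # fill=0 and fill=None differ, so None is preserved
    if isinstance(fill, (int, float)):
        return [float(fill)]
    return [float(v) for v in fill]

print(convert_fill_arg(0))          # [0.0]
print(convert_fill_arg((1, 2, 3)))  # [1.0, 2.0, 3.0]
print(convert_fill_arg(None))       # None
```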
torchvision/transforms/v2/functional/_augment.py (+2 -4)

```diff
-from typing import Union
-
 import PIL.Image
 import torch
 ...

@@ -12,14 +10,14 @@ from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal
 @_register_explicit_noop(datapoints.Mask, datapoints.BoundingBoxes, warn_passthrough=True)
 def erase(
-    inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT],
+    inpt: torch.Tensor,
     i: int,
     j: int,
     h: int,
     w: int,
     v: torch.Tensor,
     inplace: bool = False,
-) -> Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]:
+) -> torch.Tensor:
     if torch.jit.is_scripting():
         return erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
 ...
```
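`erase` shows the dispatch pattern this whole commit retypes: under TorchScript, jump straight to the tensor kernel; in eager mode, select a kernel registered for the input's runtime type. A simplified sketch of that pattern, where `KERNELS`, `register`, and `invert_demo` are hypothetical stand-ins for the private `_register_kernel_internal` / `_get_kernel` machinery:

```python
from typing import Any, Callable, Dict

import torch

KERNELS: Dict[type, Callable[..., Any]] = {}

def register(dispatch_type: type):
    def deco(kernel: Callable) -> Callable:
        KERNELS[dispatch_type] = kernel
        return kernel
    return deco

def invert_demo_image_tensor(image: torch.Tensor) -> torch.Tensor:
    return 255 - image

@register(torch.Tensor)
def _invert_tensor(inpt: torch.Tensor) -> torch.Tensor:
    return invert_demo_image_tensor(inpt)

def invert_demo(inpt: torch.Tensor) -> torch.Tensor:
    if torch.jit.is_scripting():
        # TorchScript sees a plain tensor function, no registry lookup
        return invert_demo_image_tensor(inpt)
    return KERNELS[type(inpt)](inpt)

print(invert_demo(torch.full((1, 2, 2), 200, dtype=torch.uint8)))  # values of 55
```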
torchvision/transforms/v2/functional/_color.py (+14 -16)

```diff
-from typing import List, Union
+from typing import List

 import PIL.Image
 import torch
 ...

@@ -16,9 +16,7 @@ from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal
 @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask, datapoints.Video)
-def rgb_to_grayscale(
-    inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT], num_output_channels: int = 1
-) -> Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]:
+def rgb_to_grayscale(inpt: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor:
     if torch.jit.is_scripting():
         return rgb_to_grayscale_image_tensor(inpt, num_output_channels=num_output_channels)
 ...

@@ -73,7 +71,7 @@ def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:
 @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
-def adjust_brightness(inpt: datapoints._InputTypeJIT, brightness_factor: float) -> datapoints._InputTypeJIT:
+def adjust_brightness(inpt: torch.Tensor, brightness_factor: float) -> torch.Tensor:
     if torch.jit.is_scripting():
         return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor)
 ...

@@ -110,7 +108,7 @@ def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> torch.Tensor:
 @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
-def adjust_saturation(inpt: datapoints._InputTypeJIT, saturation_factor: float) -> datapoints._InputTypeJIT:
+def adjust_saturation(inpt: torch.Tensor, saturation_factor: float) -> torch.Tensor:
     if torch.jit.is_scripting():
         return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor)
 ...

@@ -149,7 +147,7 @@ def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> torch.Tensor:
 @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
-def adjust_contrast(inpt: datapoints._InputTypeJIT, contrast_factor: float) -> datapoints._InputTypeJIT:
+def adjust_contrast(inpt: torch.Tensor, contrast_factor: float) -> torch.Tensor:
     if torch.jit.is_scripting():
         return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor)
 ...

@@ -188,7 +186,7 @@ def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch.Tensor:
 @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
-def adjust_sharpness(inpt: datapoints._InputTypeJIT, sharpness_factor: float) -> datapoints._InputTypeJIT:
+def adjust_sharpness(inpt: torch.Tensor, sharpness_factor: float) -> torch.Tensor:
     if torch.jit.is_scripting():
         return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor)
 ...

@@ -261,7 +259,7 @@ def adjust_sharpness_video(video: torch.Tensor, sharpness_factor: float) -> torch.Tensor:
 @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
-def adjust_hue(inpt: datapoints._InputTypeJIT, hue_factor: float) -> datapoints._InputTypeJIT:
+def adjust_hue(inpt: torch.Tensor, hue_factor: float) -> torch.Tensor:
     if torch.jit.is_scripting():
         return adjust_hue_image_tensor(inpt, hue_factor=hue_factor)
 ...

@@ -373,7 +371,7 @@ def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor:
 @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
-def adjust_gamma(inpt: datapoints._InputTypeJIT, gamma: float, gain: float = 1) -> datapoints._InputTypeJIT:
+def adjust_gamma(inpt: torch.Tensor, gamma: float, gain: float = 1) -> torch.Tensor:
     if torch.jit.is_scripting():
         return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain)
 ...

@@ -413,7 +411,7 @@ def adjust_gamma_video(video: torch.Tensor, gamma: float, gain: float = 1) -> torch.Tensor:
 @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
-def posterize(inpt: datapoints._InputTypeJIT, bits: int) -> datapoints._InputTypeJIT:
+def posterize(inpt: torch.Tensor, bits: int) -> torch.Tensor:
     if torch.jit.is_scripting():
         return posterize_image_tensor(inpt, bits=bits)
 ...

@@ -447,7 +445,7 @@ def posterize_video(video: torch.Tensor, bits: int) -> torch.Tensor:
 @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
-def solarize(inpt: datapoints._InputTypeJIT, threshold: float) -> datapoints._InputTypeJIT:
+def solarize(inpt: torch.Tensor, threshold: float) -> torch.Tensor:
     if torch.jit.is_scripting():
         return solarize_image_tensor(inpt, threshold=threshold)
 ...

@@ -475,7 +473,7 @@ def solarize_video(video: torch.Tensor, threshold: float) -> torch.Tensor:
 @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
-def autocontrast(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
+def autocontrast(inpt: torch.Tensor) -> torch.Tensor:
     if torch.jit.is_scripting():
         return autocontrast_image_tensor(inpt)
 ...

@@ -525,7 +523,7 @@ def autocontrast_video(video: torch.Tensor) -> torch.Tensor:
 @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
-def equalize(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
+def equalize(inpt: torch.Tensor) -> torch.Tensor:
     if torch.jit.is_scripting():
         return equalize_image_tensor(inpt)
 ...

@@ -615,7 +613,7 @@ def equalize_video(video: torch.Tensor) -> torch.Tensor:
 @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
-def invert(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
+def invert(inpt: torch.Tensor) -> torch.Tensor:
     if torch.jit.is_scripting():
         return invert_image_tensor(inpt)
 ...

@@ -646,7 +644,7 @@ def invert_video(video: torch.Tensor) -> torch.Tensor:
 @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
-def permute_channels(inpt: datapoints._InputTypeJIT, permutation: List[int]) -> datapoints._InputTypeJIT:
+def permute_channels(inpt: torch.Tensor, permutation: List[int]) -> torch.Tensor:
     """Permute the channels of the input according to the given permutation.

     This function supports plain :class:`~torch.Tensor`'s, :class:`PIL.Image.Image`'s, and ...
```
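Every color dispatcher above carries `@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)`: color adjustments have no meaning for boxes or masks, so those inputs pass through unchanged. A hedged reimplementation of the idea; `explicit_noop` is a hypothetical illustration, and the real private helper differs in detail:

```python
import warnings
from typing import Any, Callable

def explicit_noop(*noop_types: type, warn_passthrough: bool = False) -> Callable:
    def decorator(dispatcher: Callable) -> Callable:
        def wrapper(inpt: Any, *args: Any, **kwargs: Any) -> Any:
            if isinstance(inpt, noop_types):
                if warn_passthrough:
                    warnings.warn(f"{dispatcher.__name__}() is a no-op for {type(inpt).__name__}")
                return inpt  # pass through untouched
            return dispatcher(inpt, *args, **kwargs)
        return wrapper
    return decorator

@explicit_noop(bytes)  # stand-in for datapoints.BoundingBoxes / Mask
def double(inpt: Any) -> Any:
    return inpt * 2

print(double(3))       # 6
print(double(b"box"))  # b'box', unchanged by the noop path
```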
torchvision/transforms/v2/functional/_deprecated.py (+2 -3)

```diff
 import warnings
-from typing import Any, List, Union
+from typing import Any, List

 import torch

-from torchvision import datapoints
 from torchvision.transforms import functional as _F
 ...

@@ -16,7 +15,7 @@ def to_tensor(inpt: Any) -> torch.Tensor:
     return _F.to_tensor(inpt)


-def get_image_size(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]) -> List[int]:
+def get_image_size(inpt: torch.Tensor) -> List[int]:
     warnings.warn(
         "The function `get_image_size(...)` is deprecated and will be removed in a future release. "
         "Instead, please use `get_size(...)` which returns `[h, w]` instead of `[w, h]`."
 ...
```
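The deprecation message states the only behavioral difference: `get_size` returns `[h, w]` where `get_image_size` returned `[w, h]`. A quick check:

```python
import torch
from torchvision.transforms.v2 import functional as F

img = torch.rand(3, 480, 640)  # CHW: height 480, width 640
print(F.get_size(img))         # [480, 640]
```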
torchvision/transforms/v2/functional/_geometry.py (+64 -66)

```diff
@@ -25,7 +25,13 @@ from torchvision.utils import _log_api_usage_once
 from ._meta import clamp_bounding_boxes, convert_format_bounding_boxes, get_size_image_pil
-from ._utils import _get_kernel, _register_explicit_noop, _register_five_ten_crop_kernel, _register_kernel_internal
+from ._utils import (
+    _FillTypeJIT,
+    _get_kernel,
+    _register_explicit_noop,
+    _register_five_ten_crop_kernel,
+    _register_kernel_internal,
+)


 def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode:
 ...

@@ -39,7 +45,7 @@ def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode:
     return interpolation


-def horizontal_flip(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
+def horizontal_flip(inpt: torch.Tensor) -> torch.Tensor:
     if torch.jit.is_scripting():
         return horizontal_flip_image_tensor(inpt)
 ...

@@ -95,7 +101,7 @@ def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor:
     return horizontal_flip_image_tensor(video)


-def vertical_flip(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
+def vertical_flip(inpt: torch.Tensor) -> torch.Tensor:
     if torch.jit.is_scripting():
         return vertical_flip_image_tensor(inpt)
 ...

@@ -171,12 +177,12 @@ def _compute_resized_output_size(
 def resize(
-    inpt: datapoints._InputTypeJIT,
+    inpt: torch.Tensor,
     size: List[int],
     interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
     max_size: Optional[int] = None,
     antialias: Optional[Union[str, bool]] = "warn",
-) -> datapoints._InputTypeJIT:
+) -> torch.Tensor:
     if torch.jit.is_scripting():
         return resize_image_tensor(inpt, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias)
 ...

@@ -364,15 +370,15 @@ def resize_video(
 def affine(
-    inpt: datapoints._InputTypeJIT,
+    inpt: torch.Tensor,
     angle: Union[int, float],
     translate: List[float],
     scale: float,
     shear: List[float],
     interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
     center: Optional[List[float]] = None,
-) -> datapoints._InputTypeJIT:
+) -> torch.Tensor:
     if torch.jit.is_scripting():
         return affine_image_tensor(
             inpt,
 ...

@@ -549,9 +555,7 @@ def _compute_affine_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]:
     return int(size[0]), int(size[1])  # w, h


-def _apply_grid_transform(
-    img: torch.Tensor, grid: torch.Tensor, mode: str, fill: datapoints._FillTypeJIT
-) -> torch.Tensor:
+def _apply_grid_transform(img: torch.Tensor, grid: torch.Tensor, mode: str, fill: _FillTypeJIT) -> torch.Tensor:
     # We are using context knowledge that grid should have float dtype
     fp = img.dtype == grid.dtype
 ...

@@ -592,7 +596,7 @@ def _assert_grid_transform_inputs(
     image: torch.Tensor,
     matrix: Optional[List[float]],
     interpolation: str,
-    fill: datapoints._FillTypeJIT,
+    fill: _FillTypeJIT,
     supported_interpolation_modes: List[str],
     coeffs: Optional[List[float]] = None,
 ) -> None:
 ...

@@ -657,7 +661,7 @@ def affine_image_tensor(
     scale: float,
     shear: List[float],
     interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
     center: Optional[List[float]] = None,
 ) -> torch.Tensor:
     interpolation = _check_interpolation(interpolation)
 ...

@@ -709,7 +713,7 @@ def affine_image_pil(
     scale: float,
     shear: List[float],
     interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
     center: Optional[List[float]] = None,
 ) -> PIL.Image.Image:
     interpolation = _check_interpolation(interpolation)
 ...

@@ -868,7 +872,7 @@ def affine_mask(
     translate: List[float],
     scale: float,
     shear: List[float],
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
     center: Optional[List[float]] = None,
 ) -> torch.Tensor:
     if mask.ndim < 3:
 ...

@@ -901,7 +905,7 @@ def _affine_mask_dispatch(
     translate: List[float],
     scale: float,
     shear: List[float],
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
     center: Optional[List[float]] = None,
     **kwargs,
 ) -> datapoints.Mask:
 ...

@@ -925,7 +929,7 @@ def affine_video(
     scale: float,
     shear: List[float],
     interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
     center: Optional[List[float]] = None,
 ) -> torch.Tensor:
     return affine_image_tensor(
 ...

@@ -941,13 +945,13 @@ def affine_video(
 def rotate(
-    inpt: datapoints._InputTypeJIT,
+    inpt: torch.Tensor,
     angle: float,
     interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
     expand: bool = False,
     center: Optional[List[float]] = None,
-    fill: datapoints._FillTypeJIT = None,
-) -> datapoints._InputTypeJIT:
+    fill: _FillTypeJIT = None,
+) -> torch.Tensor:
     if torch.jit.is_scripting():
         return rotate_image_tensor(
             inpt, angle=angle, interpolation=interpolation, expand=expand, fill=fill, center=center
 ...

@@ -967,7 +971,7 @@ def rotate_image_tensor(
     interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
     expand: bool = False,
     center: Optional[List[float]] = None,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
 ) -> torch.Tensor:
     interpolation = _check_interpolation(interpolation)
 ...

@@ -1012,7 +1016,7 @@ def rotate_image_pil(
     interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
     expand: bool = False,
     center: Optional[List[float]] = None,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
 ) -> PIL.Image.Image:
     interpolation = _check_interpolation(interpolation)
 ...

@@ -1068,7 +1072,7 @@ def rotate_mask(
     angle: float,
     expand: bool = False,
     center: Optional[List[float]] = None,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
 ) -> torch.Tensor:
     if mask.ndim < 3:
         mask = mask.unsqueeze(0)
 ...

@@ -1097,7 +1101,7 @@ def _rotate_mask_dispatch(
     angle: float,
     expand: bool = False,
     center: Optional[List[float]] = None,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
     **kwargs,
 ) -> datapoints.Mask:
     output = rotate_mask(inpt.as_subclass(torch.Tensor), angle=angle, expand=expand, fill=fill, center=center)
 ...

@@ -1111,17 +1115,17 @@ def rotate_video(
     interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
     expand: bool = False,
     center: Optional[List[float]] = None,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
 ) -> torch.Tensor:
     return rotate_image_tensor(video, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)


 def pad(
-    inpt: datapoints._InputTypeJIT,
+    inpt: torch.Tensor,
     padding: List[int],
     fill: Optional[Union[int, float, List[float]]] = None,
     padding_mode: str = "constant",
-) -> datapoints._InputTypeJIT:
+) -> torch.Tensor:
     if torch.jit.is_scripting():
         return pad_image_tensor(inpt, padding=padding, fill=fill, padding_mode=padding_mode)
 ...

@@ -1336,7 +1340,7 @@ def pad_video(
     return pad_image_tensor(video, padding, fill=fill, padding_mode=padding_mode)


-def crop(inpt: datapoints._InputTypeJIT, top: int, left: int, height: int, width: int) -> datapoints._InputTypeJIT:
+def crop(inpt: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
     if torch.jit.is_scripting():
         return crop_image_tensor(inpt, top=top, left=left, height=height, width=width)
 ...

@@ -1423,13 +1427,13 @@ def crop_video(video: torch.Tensor, top: int, left: int, height: int, width: int
 def perspective(
-    inpt: datapoints._InputTypeJIT,
+    inpt: torch.Tensor,
     startpoints: Optional[List[List[int]]],
     endpoints: Optional[List[List[int]]],
     interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
     coefficients: Optional[List[float]] = None,
-) -> datapoints._InputTypeJIT:
+) -> torch.Tensor:
     if torch.jit.is_scripting():
         return perspective_image_tensor(
             inpt,
 ...

@@ -1507,7 +1511,7 @@ def perspective_image_tensor(
     startpoints: Optional[List[List[int]]],
     endpoints: Optional[List[List[int]]],
     interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
     coefficients: Optional[List[float]] = None,
 ) -> torch.Tensor:
     perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients)
 ...

@@ -1554,7 +1558,7 @@ def perspective_image_pil(
     startpoints: Optional[List[List[int]]],
     endpoints: Optional[List[List[int]]],
     interpolation: Union[InterpolationMode, int] = InterpolationMode.BICUBIC,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
     coefficients: Optional[List[float]] = None,
 ) -> PIL.Image.Image:
     perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients)
 ...

@@ -1679,7 +1683,7 @@ def perspective_mask(
     mask: torch.Tensor,
     startpoints: Optional[List[List[int]]],
     endpoints: Optional[List[List[int]]],
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
     coefficients: Optional[List[float]] = None,
 ) -> torch.Tensor:
     if mask.ndim < 3:
 ...

@@ -1703,7 +1707,7 @@ def _perspective_mask_dispatch(
     inpt: datapoints.Mask,
     startpoints: Optional[List[List[int]]],
     endpoints: Optional[List[List[int]]],
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
     coefficients: Optional[List[float]] = None,
     **kwargs,
 ) -> datapoints.Mask:
 ...

@@ -1723,7 +1727,7 @@ def perspective_video(
     startpoints: Optional[List[List[int]]],
     endpoints: Optional[List[List[int]]],
     interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
     coefficients: Optional[List[float]] = None,
 ) -> torch.Tensor:
     return perspective_image_tensor(
 ...

@@ -1732,11 +1736,11 @@ def perspective_video(
 def elastic(
-    inpt: datapoints._InputTypeJIT,
+    inpt: torch.Tensor,
     displacement: torch.Tensor,
     interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-    fill: datapoints._FillTypeJIT = None,
-) -> datapoints._InputTypeJIT:
+    fill: _FillTypeJIT = None,
+) -> torch.Tensor:
     if torch.jit.is_scripting():
         return elastic_image_tensor(inpt, displacement=displacement, interpolation=interpolation, fill=fill)
 ...

@@ -1755,7 +1759,7 @@ def elastic_image_tensor(
     image: torch.Tensor,
     displacement: torch.Tensor,
     interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
 ) -> torch.Tensor:
     interpolation = _check_interpolation(interpolation)
 ...

@@ -1812,7 +1816,7 @@ def elastic_image_pil(
     image: PIL.Image.Image,
     displacement: torch.Tensor,
     interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
 ) -> PIL.Image.Image:
     t_img = pil_to_tensor(image)
     output = elastic_image_tensor(t_img, displacement, interpolation=interpolation, fill=fill)
 ...

@@ -1895,7 +1899,7 @@ def _elastic_bounding_boxes_dispatch(
 def elastic_mask(
     mask: torch.Tensor,
     displacement: torch.Tensor,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
 ) -> torch.Tensor:
     if mask.ndim < 3:
         mask = mask.unsqueeze(0)
 ...

@@ -1913,7 +1917,7 @@ def elastic_mask(
 @_register_kernel_internal(elastic, datapoints.Mask, datapoint_wrapper=False)
 def _elastic_mask_dispatch(
-    inpt: datapoints.Mask, displacement: torch.Tensor, fill: datapoints._FillTypeJIT = None, **kwargs
+    inpt: datapoints.Mask, displacement: torch.Tensor, fill: _FillTypeJIT = None, **kwargs
 ) -> datapoints.Mask:
     output = elastic_mask(inpt.as_subclass(torch.Tensor), displacement=displacement, fill=fill)
     return datapoints.Mask.wrap_like(inpt, output)
 ...

@@ -1924,12 +1928,12 @@ def elastic_video(
     video: torch.Tensor,
     displacement: torch.Tensor,
     interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
 ) -> torch.Tensor:
     return elastic_image_tensor(video, displacement, interpolation=interpolation, fill=fill)


-def center_crop(inpt: datapoints._InputTypeJIT, output_size: List[int]) -> datapoints._InputTypeJIT:
+def center_crop(inpt: torch.Tensor, output_size: List[int]) -> torch.Tensor:
     if torch.jit.is_scripting():
         return center_crop_image_tensor(inpt, output_size=output_size)
 ...

@@ -2049,7 +2053,7 @@ def center_crop_video(video: torch.Tensor, output_size: List[int]) -> torch.Tensor:
 def resized_crop(
-    inpt: datapoints._InputTypeJIT,
+    inpt: torch.Tensor,
     top: int,
     left: int,
     height: int,
 ...

@@ -2057,7 +2061,7 @@ def resized_crop(
     size: List[int],
     interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
     antialias: Optional[Union[str, bool]] = "warn",
-) -> datapoints._InputTypeJIT:
+) -> torch.Tensor:
     if torch.jit.is_scripting():
         return resized_crop_image_tensor(
             inpt,
 ...

@@ -2201,14 +2205,8 @@ def resized_crop_video(
 @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask, warn_passthrough=True)
 def five_crop(
-    inpt: datapoints._InputTypeJIT, size: List[int]
-) -> Tuple[
-    datapoints._InputTypeJIT,
-    datapoints._InputTypeJIT,
-    datapoints._InputTypeJIT,
-    datapoints._InputTypeJIT,
-    datapoints._InputTypeJIT,
-]:
+    inpt: torch.Tensor, size: List[int]
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     if torch.jit.is_scripting():
         return five_crop_image_tensor(inpt, size=size)
 ...

@@ -2280,18 +2278,18 @@ def five_crop_video(
 @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask, warn_passthrough=True)
 def ten_crop(
-    inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT], size: List[int], vertical_flip: bool = False
+    inpt: torch.Tensor, size: List[int], vertical_flip: bool = False
 ) -> Tuple[
-    datapoints._InputTypeJIT,
-    datapoints._InputTypeJIT,
-    datapoints._InputTypeJIT,
-    datapoints._InputTypeJIT,
-    datapoints._InputTypeJIT,
-    datapoints._InputTypeJIT,
-    datapoints._InputTypeJIT,
-    datapoints._InputTypeJIT,
-    datapoints._InputTypeJIT,
-    datapoints._InputTypeJIT,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
 ]:
     if torch.jit.is_scripting():
         return ten_crop_image_tensor(inpt, size=size, vertical_flip=vertical_flip)
 ...
```
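The rewritten `five_crop` annotation documents the contract directly: a 5-tuple of tensors, the four corners plus the center crop. A usage check:

```python
import torch
from torchvision.transforms.v2 import functional as F

img = torch.rand(3, 64, 64)
crops = F.five_crop(img, size=[32, 32])
print(len(crops), crops[0].shape)  # 5 torch.Size([3, 32, 32])
```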
torchvision/transforms/v2/functional/_meta.py (+9 -9)

```diff
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple

 import PIL.Image
 import torch
 ...

@@ -12,7 +12,7 @@ from ._utils import _get_kernel, _register_kernel_internal, _register_unsupported_type
 @_register_unsupported_type(datapoints.BoundingBoxes, datapoints.Mask)
-def get_dimensions(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]) -> List[int]:
+def get_dimensions(inpt: torch.Tensor) -> List[int]:
     if torch.jit.is_scripting():
         return get_dimensions_image_tensor(inpt)
 ...

@@ -45,7 +45,7 @@ def get_dimensions_video(video: torch.Tensor) -> List[int]:
 @_register_unsupported_type(datapoints.BoundingBoxes, datapoints.Mask)
-def get_num_channels(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]) -> int:
+def get_num_channels(inpt: torch.Tensor) -> int:
     if torch.jit.is_scripting():
         return get_num_channels_image_tensor(inpt)
 ...

@@ -81,7 +81,7 @@ def get_num_channels_video(video: torch.Tensor) -> int:
 get_image_num_channels = get_num_channels


-def get_size(inpt: datapoints._InputTypeJIT) -> List[int]:
+def get_size(inpt: torch.Tensor) -> List[int]:
     if torch.jit.is_scripting():
         return get_size_image_tensor(inpt)
 ...

@@ -124,7 +124,7 @@ def get_size_bounding_boxes(bounding_box: datapoints.BoundingBoxes) -> List[int]:
 @_register_unsupported_type(PIL.Image.Image, datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask)
-def get_num_frames(inpt: datapoints._VideoTypeJIT) -> int:
+def get_num_frames(inpt: torch.Tensor) -> int:
     if torch.jit.is_scripting():
         return get_num_frames_video(inpt)
 ...

@@ -201,11 +201,11 @@ def _convert_format_bounding_boxes(
 def convert_format_bounding_boxes(
-    inpt: datapoints._InputTypeJIT,
+    inpt: torch.Tensor,
     old_format: Optional[BoundingBoxFormat] = None,
     new_format: Optional[BoundingBoxFormat] = None,
     inplace: bool = False,
-) -> datapoints._InputTypeJIT:
+) -> torch.Tensor:
     # This being a kernel / dispatcher hybrid, we need an option to pass `old_format` explicitly for simple tensor
     # inputs as well as extract it from `datapoints.BoundingBoxes` inputs. However, putting a default value on
     # `old_format` means we also need to put one on `new_format` to have syntactically correct Python. Here we mimic the
 ...

@@ -252,10 +252,10 @@ def _clamp_bounding_boxes(
 def clamp_bounding_boxes(
-    inpt: datapoints._InputTypeJIT,
+    inpt: torch.Tensor,
     format: Optional[BoundingBoxFormat] = None,
     canvas_size: Optional[Tuple[int, int]] = None,
-) -> datapoints._InputTypeJIT:
+) -> torch.Tensor:
     if not torch.jit.is_scripting():
         _log_api_usage_once(clamp_bounding_boxes)
 ...
```
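The context comment calls `convert_format_bounding_boxes` a kernel/dispatcher hybrid: a plain tensor needs `old_format` passed explicitly, while a `BoundingBoxes` datapoint carries its own format. A usage sketch of the plain-tensor path:

```python
import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F

xyxy = torch.tensor([[2.0, 3.0, 10.0, 12.0]])
# Bare tensors carry no format metadata, so both formats are spelled out.
xywh = F.convert_format_bounding_boxes(
    xyxy,
    old_format=datapoints.BoundingBoxFormat.XYXY,
    new_format=datapoints.BoundingBoxFormat.XYWH,
)
print(xywh)  # tensor([[2., 3., 8., 9.]])
```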
torchvision/transforms/v2/functional/_misc.py (+5 -11)

```diff
 import math
-from typing import List, Optional, Union
+from typing import List, Optional

 import PIL.Image
 import torch
 ...

@@ -17,7 +17,7 @@ from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal
 @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
 @_register_unsupported_type(PIL.Image.Image)
 def normalize(
-    inpt: Union[datapoints._TensorImageTypeJIT, datapoints._TensorVideoTypeJIT],
+    inpt: torch.Tensor,
     mean: List[float],
     std: List[float],
     inplace: bool = False,
 ...

@@ -74,9 +74,7 @@ def normalize_video(video: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False) -> torch.Tensor:
 @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
-def gaussian_blur(
-    inpt: datapoints._InputTypeJIT, kernel_size: List[int], sigma: Optional[List[float]] = None
-) -> datapoints._InputTypeJIT:
+def gaussian_blur(inpt: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None) -> torch.Tensor:
     if torch.jit.is_scripting():
         return gaussian_blur_image_tensor(inpt, kernel_size=kernel_size, sigma=sigma)
 ...

@@ -185,9 +183,7 @@ def gaussian_blur_video(
 @_register_unsupported_type(PIL.Image.Image)
-def to_dtype(
-    inpt: datapoints._InputTypeJIT, dtype: torch.dtype = torch.float, scale: bool = False
-) -> datapoints._InputTypeJIT:
+def to_dtype(inpt: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
     if torch.jit.is_scripting():
         return to_dtype_image_tensor(inpt, dtype=dtype, scale=scale)
 ...

@@ -278,8 +274,6 @@ def to_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
 @_register_kernel_internal(to_dtype, datapoints.BoundingBoxes, datapoint_wrapper=False)
 @_register_kernel_internal(to_dtype, datapoints.Mask, datapoint_wrapper=False)
-def _to_dtype_tensor_dispatch(
-    inpt: datapoints._InputTypeJIT, dtype: torch.dtype, scale: bool = False
-) -> datapoints._InputTypeJIT:
+def _to_dtype_tensor_dispatch(inpt: torch.Tensor, dtype: torch.dtype, scale: bool = False) -> torch.Tensor:
     # We don't need to unwrap and rewrap here, since Datapoint.to() preserves the type
     return inpt.to(dtype)
```
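`to_dtype(..., scale=True)` both converts the dtype and rescales values into the new dtype's range; the `_to_dtype_tensor_dispatch` registration above ensures boxes and masks only get the dtype change. A quick check:

```python
import torch
from torchvision.transforms.v2 import functional as F

img = torch.tensor([[[0, 128, 255]]], dtype=torch.uint8)
print(F.to_dtype(img, torch.float32, scale=True))  # tensor([[[0.0000, 0.5020, 1.0000]]])
```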
torchvision/transforms/v2/functional/_temporal.py (+1 -1)

```diff
@@ -11,7 +11,7 @@ from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal
 @_register_explicit_noop(
     PIL.Image.Image, datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask, warn_passthrough=True
 )
-def uniform_temporal_subsample(inpt: datapoints._VideoTypeJIT, num_samples: int) -> datapoints._VideoTypeJIT:
+def uniform_temporal_subsample(inpt: torch.Tensor, num_samples: int) -> torch.Tensor:
     if torch.jit.is_scripting():
         return uniform_temporal_subsample_video(inpt, num_samples=num_samples)
 ...
```
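A usage check for the retyped dispatcher, subsampling 4 frames uniformly from a 16-frame clip (TCHW layout assumed):

```python
import torch
from torchvision.transforms.v2 import functional as F

video = torch.rand(16, 3, 8, 8)
print(F.uniform_temporal_subsample(video, 4).shape)  # torch.Size([4, 3, 8, 8])
```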