OpenDAS / vision · Commits · d23a6e16

More mypy fixes/ignores (#8412)

Unverified commit d23a6e16, authored May 07, 2024 by Nicolas Hug, committed by GitHub on May 07, 2024.
Parent: f766d7ac

Showing 3 changed files with 10 additions and 10 deletions:

    torchvision/transforms/v2/_auto_augment.py          +8 −8
    torchvision/transforms/v2/functional/_color.py      +1 −1
    torchvision/transforms/v2/functional/_geometry.py   +1 −1
torchvision/transforms/v2/_auto_augment.py

 import math
-from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
+from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, Union
 import PIL.Image
 import torch

@@ -94,6 +94,8 @@ class _AutoAugmentBase(Transform):
         interpolation: Union[InterpolationMode, int],
         fill: Dict[Union[Type, str], _FillTypeJIT],
     ) -> ImageOrVideo:
+        # Note: this cast is wrong and is only here to make mypy happy (it disagrees with torchscript)
+        image = cast(torch.Tensor, image)
         fill_ = _get_fill(fill, type(image))
         if transform_id == "Identity":
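Note: typing.cast only affects the type checker, it does nothing at runtime, which is why the added comment can call the cast "wrong" yet harmless. A minimal standalone sketch (the function name is hypothetical, not torchvision code):

from typing import cast

import torch


def as_tensor_for_mypy(image: object) -> torch.Tensor:
    # cast() is a no-op at runtime: it returns `image` unchanged and only
    # tells mypy to treat the value as a torch.Tensor from here on.
    return cast(torch.Tensor, image)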
@@ -322,7 +324,7 @@ class AutoAugment(_AutoAugmentBase):
     def forward(self, *inputs: Any) -> Any:
         flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
-        height, width = get_size(image_or_video)
+        height, width = get_size(image_or_video)  # type: ignore[arg-type]
         policy = self._policies[int(torch.randint(len(self._policies), ()))]
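The same pattern recurs in the other forward() overrides below: the ignore is scoped to a single mypy error code rather than silencing the whole line. A minimal sketch of the difference (hypothetical code, get_size here is just a stand-in):

from typing import List


def get_size(image: List[int]) -> List[int]:
    return image


thing: object = [1, 2]

# Blanket ignore: hides every error mypy could report on this line.
a = get_size(thing)  # type: ignore

# Scoped ignore: only the arg-type complaint is silenced; any unrelated
# error on the same line would still be reported.
b = get_size(thing)  # type: ignore[arg-type]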
@@ -411,7 +413,7 @@ class RandAugment(_AutoAugmentBase):
     def forward(self, *inputs: Any) -> Any:
         flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
-        height, width = get_size(image_or_video)
+        height, width = get_size(image_or_video)  # type: ignore[arg-type]
         for _ in range(self.num_ops):
             transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)

@@ -480,7 +482,7 @@ class TrivialAugmentWide(_AutoAugmentBase):
     def forward(self, *inputs: Any) -> Any:
         flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
-        height, width = get_size(image_or_video)
+        height, width = get_size(image_or_video)  # type: ignore[arg-type]
         transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)

@@ -572,7 +574,7 @@ class AugMix(_AutoAugmentBase):
     def forward(self, *inputs: Any) -> Any:
         flat_inputs_with_spec, orig_image_or_video = self._flatten_and_extract_image_or_video(inputs)
-        height, width = get_size(orig_image_or_video)
+        height, width = get_size(orig_image_or_video)  # type: ignore[arg-type]
         if isinstance(orig_image_or_video, torch.Tensor):
             image_or_video = orig_image_or_video

@@ -613,9 +615,7 @@ class AugMix(_AutoAugmentBase):
                 else:
                     magnitude = 0.0
-                aug = self._apply_image_or_video_transform(
-                    aug, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill
-                )  # type: ignore[assignment]
+                aug = self._apply_image_or_video_transform(
+                    aug, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill
+                )
                 mix.add_(combined_weights[:, i].reshape(batch_dims) * aug)
         mix = mix.reshape(orig_dims).to(dtype=image_or_video.dtype)
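The last hunk drops a "# type: ignore[assignment]" that no longer suppresses anything. Whether stale ignores are themselves reported depends on configuration; a minimal sketch of the behaviour, assuming mypy runs with warn_unused_ignores enabled (demo code, not part of this commit):

def parse(raw: str) -> int:
    return int(raw)


# The ignore below is unnecessary: the assignment already type-checks.
# With warn_unused_ignores on, mypy reports something like:
#     error: Unused "type: ignore" comment
value: int = parse("42")  # type: ignore[assignment]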
torchvision/transforms/v2/functional/_color.py

@@ -730,7 +730,7 @@ def permute_channels_image(image: torch.Tensor, permutation: List[int]) -> torch.Tensor:
 @_register_kernel_internal(permute_channels, PIL.Image.Image)
-def _permute_channels_image_pil(image: PIL.Image.Image, permutation: List[int]) -> PIL.Image:
+def _permute_channels_image_pil(image: PIL.Image.Image, permutation: List[int]) -> PIL.Image.Image:
     return to_pil_image(permute_channels_image(pil_to_tensor(image), permutation=permutation))
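The fix here is that PIL.Image is a module while PIL.Image.Image is the image class, so only the latter is a valid annotation. A minimal standalone sketch (not the torchvision code):

import PIL.Image


# Not valid: PIL.Image is the module, and mypy rejects a module used as a type
# (roughly: Module "PIL.Image" is not valid as a type  [valid-type]).
# def load_bad(path: str) -> PIL.Image: ...

# Valid: PIL.Image.Image is the class that all Pillow images inherit from.
def load(path: str) -> PIL.Image.Image:
    return PIL.Image.open(path)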
torchvision/transforms/v2/functional/_geometry.py

@@ -113,7 +113,7 @@ def vertical_flip_image(image: torch.Tensor) -> torch.Tensor:
 @_register_kernel_internal(vertical_flip, PIL.Image.Image)
-def _vertical_flip_image_pil(image: PIL.Image) -> PIL.Image:
+def _vertical_flip_image_pil(image: PIL.Image.Image) -> PIL.Image.Image:
     return _FP.vflip(image)
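Here both the parameter and the return annotation move from the module to the class. An equivalent spelling that keeps annotations short is to import the Image module directly; a sketch, not the torchvision code:

from PIL import Image


def vflip_pil(image: Image.Image) -> Image.Image:
    # Image.Image is the class; the bare module (PIL.Image) is not a valid type.
    # Image.Transpose requires Pillow >= 9.1.
    return image.transpose(Image.Transpose.FLIP_TOP_BOTTOM)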