Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
d7e5b6a1
Unverified
Commit
d7e5b6a1
authored
Jan 20, 2023
by
Nicolas Hug
Committed by
GitHub
Jan 20, 2023
Browse files
Let Normalize() and RandomPhotometricDistort return datapoints instead of tensors (#7113)
parent
c06d52b1
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
22 additions
and
31 deletions
+22
-31
test/prototype_transforms_dispatcher_infos.py
test/prototype_transforms_dispatcher_infos.py
+0
-1
test/test_prototype_transforms_functional.py
test/test_prototype_transforms_functional.py
+1
-13
torchvision/prototype/datapoints/_image.py
torchvision/prototype/datapoints/_image.py
+4
-0
torchvision/prototype/datapoints/_video.py
torchvision/prototype/datapoints/_video.py
+4
-0
torchvision/prototype/transforms/_color.py
torchvision/prototype/transforms/_color.py
+5
-5
torchvision/prototype/transforms/functional/_misc.py
torchvision/prototype/transforms/functional/_misc.py
+8
-12
No files found.
test/prototype_transforms_dispatcher_infos.py
View file @
d7e5b6a1
...
...
@@ -426,7 +426,6 @@ DISPATCHER_INFOS = [
datapoints
.
Video
:
F
.
normalize_video
,
},
test_marks
=
[
skip_dispatch_feature
,
xfail_jit_python_scalar_arg
(
"mean"
),
xfail_jit_python_scalar_arg
(
"std"
),
],
...
...
test/test_prototype_transforms_functional.py
View file @
d7e5b6a1
...
...
@@ -13,7 +13,7 @@ import torch
import
torchvision.prototype.transforms.utils
from
common_utils
import
cache
,
cpu_and_gpu
,
needs_cuda
,
set_rng_seed
from
prototype_common_utils
import
assert_close
,
make_bounding_boxes
,
make_image
,
parametrized_error_message
from
prototype_common_utils
import
assert_close
,
make_bounding_boxes
,
parametrized_error_message
from
prototype_transforms_dispatcher_infos
import
DISPATCHER_INFOS
from
prototype_transforms_kernel_infos
import
KERNEL_INFOS
from
torch.utils._pytree
import
tree_map
...
...
@@ -1185,18 +1185,6 @@ def test_correctness_gaussian_blur_image_tensor(device, spatial_size, dt, ksize,
torch
.
testing
.
assert_close
(
out
,
true_out
,
rtol
=
0.0
,
atol
=
1.0
,
msg
=
f
"
{
ksize
}
,
{
sigma
}
"
)
def
test_normalize_output_type
():
inpt
=
torch
.
rand
(
1
,
3
,
32
,
32
)
output
=
F
.
normalize
(
inpt
,
mean
=
[
0.5
,
0.5
,
0.5
],
std
=
[
1.0
,
1.0
,
1.0
])
assert
type
(
output
)
is
torch
.
Tensor
torch
.
testing
.
assert_close
(
inpt
-
0.5
,
output
)
inpt
=
make_image
(
color_space
=
datapoints
.
ColorSpace
.
RGB
)
output
=
F
.
normalize
(
inpt
,
mean
=
[
0.5
,
0.5
,
0.5
],
std
=
[
1.0
,
1.0
,
1.0
])
assert
type
(
output
)
is
torch
.
Tensor
torch
.
testing
.
assert_close
(
inpt
-
0.5
,
output
)
@
pytest
.
mark
.
parametrize
(
"inpt"
,
[
...
...
torchvision/prototype/datapoints/_image.py
View file @
d7e5b6a1
...
...
@@ -289,6 +289,10 @@ class Image(Datapoint):
)
return
Image
.
wrap_like
(
self
,
output
)
def
normalize
(
self
,
mean
:
List
[
float
],
std
:
List
[
float
],
inplace
:
bool
=
False
)
->
Image
:
output
=
self
.
_F
.
normalize_image_tensor
(
self
.
as_subclass
(
torch
.
Tensor
),
mean
=
mean
,
std
=
std
,
inplace
=
inplace
)
return
Image
.
wrap_like
(
self
,
output
)
ImageType
=
Union
[
torch
.
Tensor
,
PIL
.
Image
.
Image
,
Image
]
ImageTypeJIT
=
torch
.
Tensor
...
...
torchvision/prototype/datapoints/_video.py
View file @
d7e5b6a1
...
...
@@ -241,6 +241,10 @@ class Video(Datapoint):
output
=
self
.
_F
.
gaussian_blur_video
(
self
.
as_subclass
(
torch
.
Tensor
),
kernel_size
=
kernel_size
,
sigma
=
sigma
)
return
Video
.
wrap_like
(
self
,
output
)
def
normalize
(
self
,
mean
:
List
[
float
],
std
:
List
[
float
],
inplace
:
bool
=
False
)
->
Video
:
output
=
self
.
_F
.
normalize_video
(
self
.
as_subclass
(
torch
.
Tensor
),
mean
=
mean
,
std
=
std
,
inplace
=
inplace
)
return
Video
.
wrap_like
(
self
,
output
)
VideoType
=
Union
[
torch
.
Tensor
,
Video
]
VideoTypeJIT
=
torch
.
Tensor
...
...
torchvision/prototype/transforms/_color.py
View file @
d7e5b6a1
...
...
@@ -82,6 +82,7 @@ class ColorJitter(Transform):
return
output
# TODO: This class seems to be untested
class
RandomPhotometricDistort
(
Transform
):
_transformed_types
=
(
datapoints
.
Image
,
...
...
@@ -119,15 +120,14 @@ class RandomPhotometricDistort(Transform):
def
_permute_channels
(
self
,
inpt
:
Union
[
datapoints
.
ImageType
,
datapoints
.
VideoType
],
permutation
:
torch
.
Tensor
)
->
Union
[
datapoints
.
ImageType
,
datapoints
.
VideoType
]:
if
isinstance
(
inpt
,
PIL
.
Image
.
Image
):
orig_inpt
=
inpt
if
isinstance
(
orig_inpt
,
PIL
.
Image
.
Image
):
inpt
=
F
.
pil_to_tensor
(
inpt
)
output
=
inpt
[...,
permutation
,
:,
:]
if
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
)):
output
=
inpt
.
wrap_like
(
inpt
,
output
,
color_space
=
datapoints
.
ColorSpace
.
OTHER
)
# type: ignore[arg-type]
elif
isinstance
(
inpt
,
PIL
.
Image
.
Image
):
if
isinstance
(
orig_inpt
,
PIL
.
Image
.
Image
):
output
=
F
.
to_image_pil
(
output
)
return
output
...
...
torchvision/prototype/transforms/functional/_misc.py
View file @
d7e5b6a1
...
...
@@ -60,19 +60,15 @@ def normalize(
)
->
torch
.
Tensor
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
normalize
)
if
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
)):
inpt
=
inpt
.
as_subclass
(
torch
.
Tensor
)
elif
not
is_simple_tensor
(
inpt
):
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
return
normalize_image_tensor
(
inpt
,
mean
=
mean
,
std
=
std
,
inplace
=
inplace
)
elif
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
)):
return
inpt
.
normalize
(
mean
=
mean
,
std
=
std
,
inplace
=
inplace
)
else
:
raise
TypeError
(
f
"Input can either be a plain tensor or an `Image` or `Video` datapoint, "
f
"but got
{
type
(
inpt
)
}
instead."
f
"Input can either be a plain tensor or an `Image` or `Video` datapoint, "
f
"but got
{
type
(
inpt
)
}
instead."
)
# Image or Video type should not be retained after normalization due to unknown data range
# Thus we return Tensor for input Image
return
normalize_image_tensor
(
inpt
,
mean
=
mean
,
std
=
std
,
inplace
=
inplace
)
def
_get_gaussian_kernel1d
(
kernel_size
:
int
,
sigma
:
float
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
)
->
torch
.
Tensor
:
lim
=
(
kernel_size
-
1
)
/
(
2.0
*
math
.
sqrt
(
2.0
)
*
sigma
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment