Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
020eafe1
Unverified
Commit
020eafe1
authored
Aug 24, 2022
by
vfdev
Committed by
GitHub
Aug 24, 2022
Browse files
Removed F.label_to_one_hot and added tests for LabelToOneHot (#6483)
parent
b6feccbc
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
21 additions
and
21 deletions
+21
-21
test/test_prototype_transforms.py
test/test_prototype_transforms.py
+11
-0
torchvision/prototype/transforms/_type_conversion.py
torchvision/prototype/transforms/_type_conversion.py
+9
-9
torchvision/prototype/transforms/functional/__init__.py
torchvision/prototype/transforms/functional/__init__.py
+1
-7
torchvision/prototype/transforms/functional/_type_conversion.py
...ision/prototype/transforms/functional/_type_conversion.py
+0
-5
No files found.
test/test_prototype_transforms.py
View file @
020eafe1
...
@@ -1593,3 +1593,14 @@ class TestLinearTransformation:
...
@@ -1593,3 +1593,14 @@ class TestLinearTransformation:
assert
isinstance
(
output
,
torch
.
Tensor
)
assert
isinstance
(
output
,
torch
.
Tensor
)
assert
output
.
unique
()
==
3
*
8
*
8
assert
output
.
unique
()
==
3
*
8
*
8
assert
output
.
dtype
==
inpt
.
dtype
assert
output
.
dtype
==
inpt
.
dtype
class
TestLabelToOneHot
:
def
test__transform
(
self
):
categories
=
[
"apple"
,
"pear"
,
"pineapple"
]
labels
=
features
.
Label
(
torch
.
tensor
([
0
,
1
,
2
,
1
]),
categories
=
categories
)
transform
=
transforms
.
LabelToOneHot
()
ohe_labels
=
transform
(
labels
)
assert
isinstance
(
ohe_labels
,
features
.
OneHotLabel
)
assert
ohe_labels
.
shape
==
(
4
,
3
)
assert
ohe_labels
.
categories
==
labels
.
categories
==
categories
torchvision/prototype/transforms/_type_conversion.py
View file @
020eafe1
...
@@ -4,6 +4,7 @@ import numpy as np
...
@@ -4,6 +4,7 @@ import numpy as np
import
PIL.Image
import
PIL.Image
import
torch
import
torch
from
torch.nn.functional
import
one_hot
from
torchvision.prototype
import
features
from
torchvision.prototype
import
features
from
torchvision.prototype.transforms
import
functional
as
F
,
Transform
from
torchvision.prototype.transforms
import
functional
as
F
,
Transform
...
@@ -20,19 +21,18 @@ class DecodeImage(Transform):
...
@@ -20,19 +21,18 @@ class DecodeImage(Transform):
class
LabelToOneHot
(
Transform
):
class
LabelToOneHot
(
Transform
):
_transformed_types
=
(
features
.
Label
,)
def
__init__
(
self
,
num_categories
:
int
=
-
1
):
def
__init__
(
self
,
num_categories
:
int
=
-
1
):
super
().
__init__
()
super
().
__init__
()
self
.
num_categories
=
num_categories
self
.
num_categories
=
num_categories
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
def
_transform
(
self
,
inpt
:
features
.
Label
,
params
:
Dict
[
str
,
Any
])
->
features
.
OneHotLabel
:
if
isinstance
(
inpt
,
features
.
Label
):
num_categories
=
self
.
num_categories
num_categories
=
self
.
num_categories
if
num_categories
==
-
1
and
inpt
.
categories
is
not
None
:
if
num_categories
==
-
1
and
inpt
.
categories
is
not
None
:
num_categories
=
len
(
inpt
.
categories
)
num_categories
=
len
(
inpt
.
categories
)
output
=
one_hot
(
inpt
,
num_classes
=
num_categories
)
output
=
F
.
label_to_one_hot
(
inpt
,
num_categories
=
num_categories
)
return
features
.
OneHotLabel
(
output
,
categories
=
inpt
.
categories
)
return
features
.
OneHotLabel
(
output
,
categories
=
inpt
.
categories
)
else
:
return
inpt
def
extra_repr
(
self
)
->
str
:
def
extra_repr
(
self
)
->
str
:
if
self
.
num_categories
==
-
1
:
if
self
.
num_categories
==
-
1
:
...
...
torchvision/prototype/transforms/functional/__init__.py
View file @
020eafe1
...
@@ -106,12 +106,6 @@ from ._geometry import (
...
@@ -106,12 +106,6 @@ from ._geometry import (
vertical_flip_segmentation_mask
,
vertical_flip_segmentation_mask
,
)
)
from
._misc
import
gaussian_blur
,
gaussian_blur_image_pil
,
gaussian_blur_image_tensor
,
normalize
,
normalize_image_tensor
from
._misc
import
gaussian_blur
,
gaussian_blur_image_pil
,
gaussian_blur_image_tensor
,
normalize
,
normalize_image_tensor
from
._type_conversion
import
(
from
._type_conversion
import
decode_image_with_pil
,
decode_video_with_av
,
to_image_pil
,
to_image_tensor
decode_image_with_pil
,
decode_video_with_av
,
label_to_one_hot
,
to_image_pil
,
to_image_tensor
,
)
from
._deprecated
import
rgb_to_grayscale
,
to_grayscale
# usort: skip
from
._deprecated
import
rgb_to_grayscale
,
to_grayscale
# usort: skip
torchvision/prototype/transforms/functional/_type_conversion.py
View file @
020eafe1
...
@@ -4,7 +4,6 @@ from typing import Any, Dict, Optional, Tuple, Union
...
@@ -4,7 +4,6 @@ from typing import Any, Dict, Optional, Tuple, Union
import
numpy
as
np
import
numpy
as
np
import
PIL.Image
import
PIL.Image
import
torch
import
torch
from
torch.nn.functional
import
one_hot
from
torchvision.io.video
import
read_video
from
torchvision.io.video
import
read_video
from
torchvision.prototype.utils._internal
import
ReadOnlyTensorBuffer
from
torchvision.prototype.utils._internal
import
ReadOnlyTensorBuffer
from
torchvision.transforms
import
functional
as
_F
from
torchvision.transforms
import
functional
as
_F
...
@@ -22,10 +21,6 @@ def decode_video_with_av(encoded_video: torch.Tensor) -> Tuple[torch.Tensor, tor
...
@@ -22,10 +21,6 @@ def decode_video_with_av(encoded_video: torch.Tensor) -> Tuple[torch.Tensor, tor
return
read_video
(
ReadOnlyTensorBuffer
(
encoded_video
))
# type: ignore[arg-type]
return
read_video
(
ReadOnlyTensorBuffer
(
encoded_video
))
# type: ignore[arg-type]
def
label_to_one_hot
(
label
:
torch
.
Tensor
,
*
,
num_categories
:
int
)
->
torch
.
Tensor
:
return
one_hot
(
label
,
num_classes
=
num_categories
)
# type: ignore[no-any-return]
def
to_image_tensor
(
image
:
Union
[
torch
.
Tensor
,
PIL
.
Image
.
Image
,
np
.
ndarray
],
copy
:
bool
=
False
)
->
torch
.
Tensor
:
def
to_image_tensor
(
image
:
Union
[
torch
.
Tensor
,
PIL
.
Image
.
Image
,
np
.
ndarray
],
copy
:
bool
=
False
)
->
torch
.
Tensor
:
if
isinstance
(
image
,
np
.
ndarray
):
if
isinstance
(
image
,
np
.
ndarray
):
image
=
torch
.
from_numpy
(
image
)
image
=
torch
.
from_numpy
(
image
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment