"src/array/cuda/array_scatter.hip" did not exist on "44089c8b4d4db4ca71e816e0de50dca972dbabdb"
Unverified Commit 020eafe1 authored by vfdev's avatar vfdev Committed by GitHub
Browse files

Removed F.label_to_one_hot and added tests for LabelToOneHot (#6483)

parent b6feccbc
...@@ -1593,3 +1593,14 @@ class TestLinearTransformation: ...@@ -1593,3 +1593,14 @@ class TestLinearTransformation:
assert isinstance(output, torch.Tensor) assert isinstance(output, torch.Tensor)
assert output.unique() == 3 * 8 * 8 assert output.unique() == 3 * 8 * 8
assert output.dtype == inpt.dtype assert output.dtype == inpt.dtype
class TestLabelToOneHot:
def test__transform(self):
categories = ["apple", "pear", "pineapple"]
labels = features.Label(torch.tensor([0, 1, 2, 1]), categories=categories)
transform = transforms.LabelToOneHot()
ohe_labels = transform(labels)
assert isinstance(ohe_labels, features.OneHotLabel)
assert ohe_labels.shape == (4, 3)
assert ohe_labels.categories == labels.categories == categories
...@@ -4,6 +4,7 @@ import numpy as np ...@@ -4,6 +4,7 @@ import numpy as np
import PIL.Image import PIL.Image
import torch import torch
from torch.nn.functional import one_hot
from torchvision.prototype import features from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F, Transform from torchvision.prototype.transforms import functional as F, Transform
...@@ -20,19 +21,18 @@ class DecodeImage(Transform): ...@@ -20,19 +21,18 @@ class DecodeImage(Transform):
class LabelToOneHot(Transform): class LabelToOneHot(Transform):
_transformed_types = (features.Label,)
def __init__(self, num_categories: int = -1): def __init__(self, num_categories: int = -1):
super().__init__() super().__init__()
self.num_categories = num_categories self.num_categories = num_categories
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: features.Label, params: Dict[str, Any]) -> features.OneHotLabel:
if isinstance(inpt, features.Label): num_categories = self.num_categories
num_categories = self.num_categories if num_categories == -1 and inpt.categories is not None:
if num_categories == -1 and inpt.categories is not None: num_categories = len(inpt.categories)
num_categories = len(inpt.categories) output = one_hot(inpt, num_classes=num_categories)
output = F.label_to_one_hot(inpt, num_categories=num_categories) return features.OneHotLabel(output, categories=inpt.categories)
return features.OneHotLabel(output, categories=inpt.categories)
else:
return inpt
def extra_repr(self) -> str: def extra_repr(self) -> str:
if self.num_categories == -1: if self.num_categories == -1:
......
...@@ -106,12 +106,6 @@ from ._geometry import ( ...@@ -106,12 +106,6 @@ from ._geometry import (
vertical_flip_segmentation_mask, vertical_flip_segmentation_mask,
) )
from ._misc import gaussian_blur, gaussian_blur_image_pil, gaussian_blur_image_tensor, normalize, normalize_image_tensor from ._misc import gaussian_blur, gaussian_blur_image_pil, gaussian_blur_image_tensor, normalize, normalize_image_tensor
from ._type_conversion import ( from ._type_conversion import decode_image_with_pil, decode_video_with_av, to_image_pil, to_image_tensor
decode_image_with_pil,
decode_video_with_av,
label_to_one_hot,
to_image_pil,
to_image_tensor,
)
from ._deprecated import rgb_to_grayscale, to_grayscale # usort: skip from ._deprecated import rgb_to_grayscale, to_grayscale # usort: skip
...@@ -4,7 +4,6 @@ from typing import Any, Dict, Optional, Tuple, Union ...@@ -4,7 +4,6 @@ from typing import Any, Dict, Optional, Tuple, Union
import numpy as np import numpy as np
import PIL.Image import PIL.Image
import torch import torch
from torch.nn.functional import one_hot
from torchvision.io.video import read_video from torchvision.io.video import read_video
from torchvision.prototype.utils._internal import ReadOnlyTensorBuffer from torchvision.prototype.utils._internal import ReadOnlyTensorBuffer
from torchvision.transforms import functional as _F from torchvision.transforms import functional as _F
...@@ -22,10 +21,6 @@ def decode_video_with_av(encoded_video: torch.Tensor) -> Tuple[torch.Tensor, tor ...@@ -22,10 +21,6 @@ def decode_video_with_av(encoded_video: torch.Tensor) -> Tuple[torch.Tensor, tor
return read_video(ReadOnlyTensorBuffer(encoded_video)) # type: ignore[arg-type] return read_video(ReadOnlyTensorBuffer(encoded_video)) # type: ignore[arg-type]
def label_to_one_hot(label: torch.Tensor, *, num_categories: int) -> torch.Tensor:
return one_hot(label, num_classes=num_categories) # type: ignore[no-any-return]
def to_image_tensor(image: Union[torch.Tensor, PIL.Image.Image, np.ndarray], copy: bool = False) -> torch.Tensor: def to_image_tensor(image: Union[torch.Tensor, PIL.Image.Image, np.ndarray], copy: bool = False) -> torch.Tensor:
if isinstance(image, np.ndarray): if isinstance(image, np.ndarray):
image = torch.from_numpy(image) image = torch.from_numpy(image)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment