Unverified Commit 5bcf9f68 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

use the same base class for Label and OneHotLabel (#6480)

* use the same base class for Label and OneHotLabel

* mypy

* rename typevar
parent 39fe34a2
from __future__ import annotations from __future__ import annotations
from typing import Any, Optional, Sequence, Union from typing import Any, Optional, Sequence, Type, TypeVar, Union
import torch import torch
from torch.utils._pytree import tree_map from torch.utils._pytree import tree_map
...@@ -8,40 +8,45 @@ from torch.utils._pytree import tree_map ...@@ -8,40 +8,45 @@ from torch.utils._pytree import tree_map
from ._feature import _Feature from ._feature import _Feature
class Label(_Feature): L = TypeVar("L", bound="_LabelBase")
class _LabelBase(_Feature):
categories: Optional[Sequence[str]] categories: Optional[Sequence[str]]
def __new__( def __new__(
cls, cls: Type[L],
data: Any, data: Any,
*, *,
categories: Optional[Sequence[str]] = None, categories: Optional[Sequence[str]] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
device: Optional[Union[torch.device, str, int]] = None, device: Optional[Union[torch.device, str, int]] = None,
requires_grad: bool = False, requires_grad: bool = False,
) -> Label: ) -> L:
label = super().__new__(cls, data, dtype=dtype, device=device, requires_grad=requires_grad) label_base = super().__new__(cls, data, dtype=dtype, device=device, requires_grad=requires_grad)
label.categories = categories label_base.categories = categories
return label return label_base
@classmethod @classmethod
def new_like(cls, other: Label, data: Any, *, categories: Optional[Sequence[str]] = None, **kwargs: Any) -> Label: def new_like(cls: Type[L], other: L, data: Any, *, categories: Optional[Sequence[str]] = None, **kwargs: Any) -> L:
return super().new_like( return super().new_like(
other, data, categories=categories if categories is not None else other.categories, **kwargs other, data, categories=categories if categories is not None else other.categories, **kwargs
) )
@classmethod @classmethod
def from_category( def from_category(
cls, cls: Type[L],
category: str, category: str,
*, *,
categories: Sequence[str], categories: Sequence[str],
**kwargs: Any, **kwargs: Any,
) -> Label: ) -> L:
return cls(categories.index(category), categories=categories, **kwargs) return cls(categories.index(category), categories=categories, **kwargs)
class Label(_LabelBase):
def to_categories(self) -> Any: def to_categories(self) -> Any:
if self.categories is None: if self.categories is None:
raise RuntimeError("Label does not have categories") raise RuntimeError("Label does not have categories")
...@@ -49,9 +54,7 @@ class Label(_Feature): ...@@ -49,9 +54,7 @@ class Label(_Feature):
return tree_map(lambda idx: self.categories[idx], self.tolist()) return tree_map(lambda idx: self.categories[idx], self.tolist())
class OneHotLabel(_Feature): class OneHotLabel(_LabelBase):
categories: Optional[Sequence[str]]
def __new__( def __new__(
cls, cls,
data: Any, data: Any,
...@@ -61,19 +64,11 @@ class OneHotLabel(_Feature): ...@@ -61,19 +64,11 @@ class OneHotLabel(_Feature):
device: Optional[Union[torch.device, str, int]] = None, device: Optional[Union[torch.device, str, int]] = None,
requires_grad: bool = False, requires_grad: bool = False,
) -> OneHotLabel: ) -> OneHotLabel:
one_hot_label = super().__new__(cls, data, dtype=dtype, device=device, requires_grad=requires_grad) one_hot_label = super().__new__(
cls, data, categories=categories, dtype=dtype, device=device, requires_grad=requires_grad
)
if categories is not None and len(categories) != one_hot_label.shape[-1]: if categories is not None and len(categories) != one_hot_label.shape[-1]:
raise ValueError() raise ValueError()
one_hot_label.categories = categories
return one_hot_label return one_hot_label
@classmethod
def new_like(
cls, other: OneHotLabel, data: Any, *, categories: Optional[Sequence[str]] = None, **kwargs: Any
) -> OneHotLabel:
return super().new_like(
other, data, categories=categories if categories is not None else other.categories, **kwargs
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment