Unverified Commit 5bcf9f68 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

use the same base class for Label and OneHotLabel (#6480)

* use the same base class for Label and OneHotLabel

* mypy

* rename typevar
parent 39fe34a2
from __future__ import annotations
from typing import Any, Optional, Sequence, Union
from typing import Any, Optional, Sequence, Type, TypeVar, Union
import torch
from torch.utils._pytree import tree_map
......@@ -8,40 +8,45 @@ from torch.utils._pytree import tree_map
from ._feature import _Feature
class Label(_Feature):
L = TypeVar("L", bound="_LabelBase")
class _LabelBase(_Feature):
categories: Optional[Sequence[str]]
def __new__(
cls,
cls: Type[L],
data: Any,
*,
categories: Optional[Sequence[str]] = None,
dtype: Optional[torch.dtype] = None,
device: Optional[Union[torch.device, str, int]] = None,
requires_grad: bool = False,
) -> Label:
label = super().__new__(cls, data, dtype=dtype, device=device, requires_grad=requires_grad)
) -> L:
label_base = super().__new__(cls, data, dtype=dtype, device=device, requires_grad=requires_grad)
label.categories = categories
label_base.categories = categories
return label
return label_base
@classmethod
def new_like(cls, other: Label, data: Any, *, categories: Optional[Sequence[str]] = None, **kwargs: Any) -> Label:
def new_like(cls: Type[L], other: L, data: Any, *, categories: Optional[Sequence[str]] = None, **kwargs: Any) -> L:
return super().new_like(
other, data, categories=categories if categories is not None else other.categories, **kwargs
)
@classmethod
def from_category(
cls,
cls: Type[L],
category: str,
*,
categories: Sequence[str],
**kwargs: Any,
) -> Label:
) -> L:
return cls(categories.index(category), categories=categories, **kwargs)
class Label(_LabelBase):
def to_categories(self) -> Any:
if self.categories is None:
raise RuntimeError("Label does not have categories")
......@@ -49,9 +54,7 @@ class Label(_Feature):
return tree_map(lambda idx: self.categories[idx], self.tolist())
class OneHotLabel(_Feature):
categories: Optional[Sequence[str]]
class OneHotLabel(_LabelBase):
def __new__(
cls,
data: Any,
......@@ -61,19 +64,11 @@ class OneHotLabel(_Feature):
device: Optional[Union[torch.device, str, int]] = None,
requires_grad: bool = False,
) -> OneHotLabel:
one_hot_label = super().__new__(cls, data, dtype=dtype, device=device, requires_grad=requires_grad)
one_hot_label = super().__new__(
cls, data, categories=categories, dtype=dtype, device=device, requires_grad=requires_grad
)
if categories is not None and len(categories) != one_hot_label.shape[-1]:
raise ValueError()
one_hot_label.categories = categories
return one_hot_label
@classmethod
def new_like(
cls, other: OneHotLabel, data: Any, *, categories: Optional[Sequence[str]] = None, **kwargs: Any
) -> OneHotLabel:
return super().new_like(
other, data, categories=categories if categories is not None else other.categories, **kwargs
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment