Unverified Commit 7a1281a6 authored by vfdev's avatar vfdev Committed by GitHub
Browse files

[proto] Clean-up Label.to_categories (#6419)

* [proto] Clean-up Label.to_categories

* Fixed flake8
parent 4db84b04
from __future__ import annotations from __future__ import annotations
from typing import Any, cast, Optional, Sequence, Union from typing import Any, Optional, Sequence, Union
import torch import torch
from torchvision.prototype.utils._internal import apply_recursively from torch.utils._pytree import tree_map
from ._feature import _Feature from ._feature import _Feature
...@@ -43,10 +43,10 @@ class Label(_Feature): ...@@ -43,10 +43,10 @@ class Label(_Feature):
return cls(categories.index(category), categories=categories, **kwargs) return cls(categories.index(category), categories=categories, **kwargs)
def to_categories(self) -> Any: def to_categories(self) -> Any:
if not self.categories: if self.categories is None:
raise RuntimeError() raise RuntimeError("Label does not have categories")
return apply_recursively(lambda idx: cast(Sequence[str], self.categories)[idx], self.tolist()) return tree_map(lambda idx: self.categories[idx], self.tolist())
class OneHotLabel(_Feature): class OneHotLabel(_Feature):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment