Unverified Commit 874581cb authored by Philip Meier, committed by GitHub

remove vanilla tensors from prototype datasets samples (#5018)

parent 0aa3717d
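
The gist of the change, sketched below: values that dataset samples used to carry as vanilla `torch.Tensor`s are now wrapped in the prototype `Feature` type (or one of its subclasses such as `Label` or `BoundingBox`), so transforms can dispatch on the feature type. A minimal before/after sketch, assuming the prototype features API available on this branch:

```python
# Before/after pattern of this commit (a sketch; assumes
# torchvision.prototype.features as it exists on this branch).
import torch
from torchvision.prototype.features import Feature

anns = [{"area": 12.0, "iscrowd": 0}, {"area": 34.5, "iscrowd": 1}]

# Before: vanilla tensors in the sample dict.
before = dict(
    areas=torch.tensor([ann["area"] for ann in anns]),
    crowds=torch.tensor([ann["iscrowd"] for ann in anns], dtype=torch.bool),
)

# After: the same data wrapped in Feature. Feature subclasses torch.Tensor,
# so regular tensor ops keep working while the type carries the semantic tag.
after = dict(
    areas=Feature([ann["area"] for ann in anns]),
    crowds=Feature([ann["iscrowd"] for ann in anns], dtype=torch.bool),
)

assert isinstance(after["areas"], torch.Tensor)
```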
@@ -32,23 +32,10 @@ from torchvision.prototype.datasets.utils._internal import (
     getitem,
     path_accessor,
 )
-from torchvision.prototype.features import BoundingBox, Label
-from torchvision.prototype.features._feature import DEFAULT
+from torchvision.prototype.features import BoundingBox, Label, Feature
 from torchvision.prototype.utils._internal import FrozenMapping


-class CocoLabel(Label):
-    super_category: Optional[str]
-
-    @classmethod
-    def _parse_meta_data(
-        cls,
-        category: Optional[str] = DEFAULT,  # type: ignore[assignment]
-        super_category: Optional[str] = DEFAULT,  # type: ignore[assignment]
-    ) -> Dict[str, Tuple[Any, Any]]:
-        return dict(category=(category, None), super_category=(super_category, None))
-
-
 class Coco(Dataset):
     def _make_info(self) -> DatasetInfo:
         name = "coco"
@@ -111,27 +98,24 @@ class Coco(Dataset):
         categories = [self.info.categories[label] for label in labels]
         return dict(
             # TODO: create a segmentation feature
-            segmentations=torch.stack(
-                [
-                    self._segmentation_to_mask(ann["segmentation"], is_crowd=ann["iscrowd"], image_size=image_size)
-                    for ann in anns
-                ]
+            segmentations=Feature(
+                torch.stack(
+                    [
+                        self._segmentation_to_mask(ann["segmentation"], is_crowd=ann["iscrowd"], image_size=image_size)
+                        for ann in anns
+                    ]
+                )
             ),
-            areas=torch.tensor([ann["area"] for ann in anns]),
-            crowds=torch.tensor([ann["iscrowd"] for ann in anns], dtype=torch.bool),
+            areas=Feature([ann["area"] for ann in anns]),
+            crowds=Feature([ann["iscrowd"] for ann in anns], dtype=torch.bool),
             bounding_boxes=BoundingBox(
                 [ann["bbox"] for ann in anns],
                 format="xywh",
                 image_size=image_size,
             ),
-            labels=[
-                CocoLabel(
-                    label,
-                    category=category,
-                    super_category=self.info.extra.category_to_super_category[category],
-                )
-                for label, category in zip(labels, categories)
-            ],
+            labels=Label(labels),
+            categories=categories,
+            super_categories=[self.info.extra.category_to_super_category[category] for category in categories],
             ann_ids=[ann["id"] for ann in anns],
         )
@@ -141,7 +125,12 @@ class Coco(Dataset):
             ann_ids=[ann["id"] for ann in anns],
         )

-    _ANN_DECODERS = OrderedDict([("instances", _decode_instances_anns), ("captions", _decode_captions_ann)])
+    _ANN_DECODERS = OrderedDict(
+        [
+            ("instances", _decode_instances_anns),
+            ("captions", _decode_captions_ann),
+        ]
+    )

     _META_FILE_PATTERN = re.compile(
         fr"(?P<annotations>({'|'.join(_ANN_DECODERS.keys())}))_(?P<split>[a-zA-Z]+)(?P<year>\d+)[.]json"
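
With `CocoLabel` gone, the per-instance meta data it carried (`category`, `super_category`) now travels as plain lists next to a single `Label` tensor. A hypothetical consumer of the new sample layout; the keys below are a subset of the dict built in the hunk above, and the values are made up for illustration:

```python
# Hypothetical sample (keys mirror the hunk above, values invented).
from torchvision.prototype.features import Label

sample = dict(
    labels=Label([1, 18]),
    categories=["person", "dog"],
    super_categories=["person", "animal"],
)

# What CocoLabel used to bundle per instance is now zipped back together
# from the parallel lists when needed.
for label, category, super_category in zip(
    sample["labels"].tolist(), sample["categories"], sample["super_categories"]
):
    print(label, category, super_category)
```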
@@ -12,7 +12,7 @@ DEFAULT = object()


 class Feature(torch.Tensor):
-    _META_ATTRS: Set[str]
+    _META_ATTRS: Set[str] = set()
     _meta_data: Dict[str, Any]

     def __init_subclass__(cls):
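
The `= set()` default matters because a bare annotation is only a type hint and creates no class attribute; with the default, `Feature` itself (and, presumably, subclasses that declare no meta attributes) exposes an empty `_META_ATTRS` without special-casing. A standalone illustration of that Python behavior, independent of torchvision:

```python
# Plain-Python illustration (not torchvision code): an annotation alone does
# not create a class attribute, an assignment does.
from typing import Set


class OnlyAnnotated:
    META: Set[str]  # annotation only; no attribute exists at runtime


class WithDefault:
    META: Set[str] = set()  # class attribute exists and is an empty set


print(hasattr(OnlyAnnotated, "META"))  # False
print(WithDefault.META)  # set()
```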
@@ -360,7 +360,13 @@ class Transform(nn.Module):
         else:
             feature_type = type(sample)
         if not self.supports(feature_type):
-            if not issubclass(feature_type, features.Feature) or feature_type in self.NO_OP_FEATURE_TYPES:
+            if (
+                not issubclass(feature_type, features.Feature)
+                # issubclass is not a strict check, but also allows the type checked against. Thus, we need to
+                # check it separately
+                or feature_type is features.Feature
+                or feature_type in self.NO_OP_FEATURE_TYPES
+            ):
                 return sample

             raise TypeError(
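
The added comment points at a subtlety of `issubclass`: it is reflexive, so `features.Feature` itself passes the subclass check. Without the explicit `feature_type is features.Feature` clause, a sample value wrapped in the plain `Feature` base class (as the COCO hunk above now does for `segmentations`, `areas`, and `crowds`) would fall through to the `TypeError` below instead of being returned untouched. The check in isolation, with stand-in classes:

```python
# issubclass is reflexive: a class passes the check against itself.
class Feature:  # stand-in for features.Feature
    pass


class BoundingBox(Feature):  # stand-in for a concrete feature subclass
    pass


assert issubclass(BoundingBox, Feature)
assert issubclass(Feature, Feature)  # also True, hence the explicit identity check
```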