"docs/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "cd1728d09bd6d4288b5833249ba09b00529809f9"
Unverified Commit 874581cb authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

remove vanilla tensors from prototype datasets samples (#5018)

parent 0aa3717d
...@@ -32,23 +32,10 @@ from torchvision.prototype.datasets.utils._internal import ( ...@@ -32,23 +32,10 @@ from torchvision.prototype.datasets.utils._internal import (
getitem, getitem,
path_accessor, path_accessor,
) )
from torchvision.prototype.features import BoundingBox, Label from torchvision.prototype.features import BoundingBox, Label, Feature
from torchvision.prototype.features._feature import DEFAULT
from torchvision.prototype.utils._internal import FrozenMapping from torchvision.prototype.utils._internal import FrozenMapping
class CocoLabel(Label):
super_category: Optional[str]
@classmethod
def _parse_meta_data(
cls,
category: Optional[str] = DEFAULT, # type: ignore[assignment]
super_category: Optional[str] = DEFAULT, # type: ignore[assignment]
) -> Dict[str, Tuple[Any, Any]]:
return dict(category=(category, None), super_category=(super_category, None))
class Coco(Dataset): class Coco(Dataset):
def _make_info(self) -> DatasetInfo: def _make_info(self) -> DatasetInfo:
name = "coco" name = "coco"
...@@ -111,27 +98,24 @@ class Coco(Dataset): ...@@ -111,27 +98,24 @@ class Coco(Dataset):
categories = [self.info.categories[label] for label in labels] categories = [self.info.categories[label] for label in labels]
return dict( return dict(
# TODO: create a segmentation feature # TODO: create a segmentation feature
segmentations=torch.stack( segmentations=Feature(
[ torch.stack(
self._segmentation_to_mask(ann["segmentation"], is_crowd=ann["iscrowd"], image_size=image_size) [
for ann in anns self._segmentation_to_mask(ann["segmentation"], is_crowd=ann["iscrowd"], image_size=image_size)
] for ann in anns
]
)
), ),
areas=torch.tensor([ann["area"] for ann in anns]), areas=Feature([ann["area"] for ann in anns]),
crowds=torch.tensor([ann["iscrowd"] for ann in anns], dtype=torch.bool), crowds=Feature([ann["iscrowd"] for ann in anns], dtype=torch.bool),
bounding_boxes=BoundingBox( bounding_boxes=BoundingBox(
[ann["bbox"] for ann in anns], [ann["bbox"] for ann in anns],
format="xywh", format="xywh",
image_size=image_size, image_size=image_size,
), ),
labels=[ labels=Label(labels),
CocoLabel( categories=categories,
label, super_categories=[self.info.extra.category_to_super_category[category] for category in categories],
category=category,
super_category=self.info.extra.category_to_super_category[category],
)
for label, category in zip(labels, categories)
],
ann_ids=[ann["id"] for ann in anns], ann_ids=[ann["id"] for ann in anns],
) )
...@@ -141,7 +125,12 @@ class Coco(Dataset): ...@@ -141,7 +125,12 @@ class Coco(Dataset):
ann_ids=[ann["id"] for ann in anns], ann_ids=[ann["id"] for ann in anns],
) )
_ANN_DECODERS = OrderedDict([("instances", _decode_instances_anns), ("captions", _decode_captions_ann)]) _ANN_DECODERS = OrderedDict(
[
("instances", _decode_instances_anns),
("captions", _decode_captions_ann),
]
)
_META_FILE_PATTERN = re.compile( _META_FILE_PATTERN = re.compile(
fr"(?P<annotations>({'|'.join(_ANN_DECODERS.keys())}))_(?P<split>[a-zA-Z]+)(?P<year>\d+)[.]json" fr"(?P<annotations>({'|'.join(_ANN_DECODERS.keys())}))_(?P<split>[a-zA-Z]+)(?P<year>\d+)[.]json"
......
...@@ -12,7 +12,7 @@ DEFAULT = object() ...@@ -12,7 +12,7 @@ DEFAULT = object()
class Feature(torch.Tensor): class Feature(torch.Tensor):
_META_ATTRS: Set[str] _META_ATTRS: Set[str] = set()
_meta_data: Dict[str, Any] _meta_data: Dict[str, Any]
def __init_subclass__(cls): def __init_subclass__(cls):
......
...@@ -360,7 +360,13 @@ class Transform(nn.Module): ...@@ -360,7 +360,13 @@ class Transform(nn.Module):
else: else:
feature_type = type(sample) feature_type = type(sample)
if not self.supports(feature_type): if not self.supports(feature_type):
if not issubclass(feature_type, features.Feature) or feature_type in self.NO_OP_FEATURE_TYPES: if (
not issubclass(feature_type, features.Feature)
# issubclass is not a strict check, but also allows the type checked against. Thus, we need to
# check it separately
or feature_type is features.Feature
or feature_type in self.NO_OP_FEATURE_TYPES
):
return sample return sample
raise TypeError( raise TypeError(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment