Commit bf8d84b2 authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

support using prefetched files in the mapper

Summary: The default mapper may load "file_name" and "sem_seg_file_name" from `dataset_dict`. When these files are prefetched from manifold, the mapper no longer needs to load them because they have already been fetched. This diff adds two more fields for holding the prefetched data and makes the mapper work in both cases.

Reviewed By: newstzpz

Differential Revision: D26972340

fbshipit-source-id: 63f6dc809d321e149aa5adf9f92c3ace07cbf2a7
parent 6aec097e
...@@ -4,20 +4,35 @@ ...@@ -4,20 +4,35 @@
import copy import copy
import logging import logging
from io import BytesIO
import numpy as np import numpy as np
import torch import torch
from d2go.utils.helper import retryable
from detectron2.data import detection_utils as utils, transforms as T from detectron2.data import detection_utils as utils, transforms as T
from detectron2.data.transforms.augmentation import ( from detectron2.data.transforms.augmentation import (
AugInput, AugInput,
AugmentationList, AugmentationList,
) )
from d2go.utils.helper import retryable from PIL import Image
from .build import D2GO_DATA_MAPPER_REGISTRY from .build import D2GO_DATA_MAPPER_REGISTRY
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Optional `dataset_dict` keys under which already-fetched file contents are
# stored. When a key is present, the mapper passes its value to
# `read_image_with_prefetch` as `prefetched` instead of relying only on the path.
# Key for the prefetched counterpart of dataset_dict["file_name"].
PREFETCHED_FILE_NAME = "prefetch_image"
# Key for the prefetched counterpart of dataset_dict["sem_seg_file_name"].
PREFETCHED_SEM_SEG_FILE_NAME = "prefetch_sem_seg"
def read_image_with_prefetch(file_name, format=None, prefetched=None):
    """Read an image, decoding from prefetched data when it is supplied.

    Args:
        file_name: Path of the image; only used when ``prefetched`` is None,
            in which case it is forwarded to ``utils.read_image``.
        format: Target image format, forwarded unchanged to
            ``utils.read_image`` or ``utils.convert_PIL_to_numpy``.
        prefetched: Optional object exposing ``.numpy()`` whose buffer holds
            the encoded image file contents (presumably a 1-D uint8
            ``torch.Tensor`` -- TODO confirm against the prefetching code).

    Returns:
        The result of ``utils.read_image`` (no prefetched data) or of
        ``utils.convert_PIL_to_numpy`` (prefetched data).
    """
    if prefetched is not None:
        # Decode straight from memory instead of opening `file_name`.
        encoded = BytesIO(prefetched.numpy().view())
        pil_image = Image.open(encoded)
        # work around this bug: https://github.com/python-pillow/Pillow/issues/3973
        pil_image = utils._apply_exif_orientation(pil_image)
        return utils.convert_PIL_to_numpy(pil_image, format)

    # Nothing was prefetched: fall back to reading from the path.
    return utils.read_image(file_name, format)
@D2GO_DATA_MAPPER_REGISTRY.register() @D2GO_DATA_MAPPER_REGISTRY.register()
class D2GoDatasetMapper(object): class D2GoDatasetMapper(object):
...@@ -161,8 +176,10 @@ class D2GoDatasetMapper(object): ...@@ -161,8 +176,10 @@ class D2GoDatasetMapper(object):
dataset_dict["instances"] = utils.filter_empty_instances(instances) dataset_dict["instances"] = utils.filter_empty_instances(instances)
if "sem_seg_file_name" in dataset_dict: if "sem_seg_file_name" in dataset_dict:
sem_seg_gt = utils.read_image( sem_seg_gt = read_image_with_prefetch(
dataset_dict.pop("sem_seg_file_name"), "L" dataset_dict.pop("sem_seg_file_name"),
"L",
prefetched=dataset_dict.get(PREFETCHED_SEM_SEG_FILE_NAME, None),
).squeeze(2) ).squeeze(2)
sem_seg_gt = transforms.apply_segmentation(sem_seg_gt) sem_seg_gt = transforms.apply_segmentation(sem_seg_gt)
sem_seg_gt = torch.as_tensor(sem_seg_gt.astype("long")) sem_seg_gt = torch.as_tensor(sem_seg_gt.astype("long"))
...@@ -242,7 +259,11 @@ class D2GoDatasetMapper(object): ...@@ -242,7 +259,11 @@ class D2GoDatasetMapper(object):
def _read_image(self, dataset_dict, format=None): def _read_image(self, dataset_dict, format=None):
if not (self.image_loader and self.image_loader.support(dataset_dict)): if not (self.image_loader and self.image_loader.support(dataset_dict)):
# fallback to use D2's read_image # fallback to use D2's read_image
image = utils.read_image(dataset_dict["file_name"], format=format) image = read_image_with_prefetch(
dataset_dict["file_name"],
format=format,
prefetched=dataset_dict.get(PREFETCHED_FILE_NAME),
)
if self.backfill_size: if self.backfill_size:
h, w, _ = image.shape h, w, _ = image.shape
dataset_dict["width"] = w dataset_dict["width"] = w
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment