Commit 9c326bb3 authored by Owen Wang's avatar Owen Wang Committed by Facebook GitHub Bot
Browse files

allow reading .npy files for seg masks

Summary: Allow reading `.npy` format binary masks shaped (H, W,) in addition to `.png` image masks shaped (H, W, C).

Reviewed By: wat3rBro

Differential Revision: D30136542

fbshipit-source-id: 56df5a766ab15b6808a1327815857e5d38eac910
parent 8b03f9aa
...@@ -4,17 +4,19 @@ ...@@ -4,17 +4,19 @@
import copy import copy
import logging import logging
from io import BytesIO
import numpy as np import numpy as np
import torch import torch
from d2go.data.dataset_mappers.data_reading import (
read_image_with_prefetch,
read_sem_seg_file_with_prefetch,
)
from d2go.utils.helper import retryable from d2go.utils.helper import retryable
from detectron2.data import detection_utils as utils, transforms as T from detectron2.data import detection_utils as utils, transforms as T
from detectron2.data.transforms.augmentation import ( from detectron2.data.transforms.augmentation import (
AugInput, AugInput,
AugmentationList, AugmentationList,
) )
from PIL import Image
from .build import D2GO_DATA_MAPPER_REGISTRY from .build import D2GO_DATA_MAPPER_REGISTRY
...@@ -24,16 +26,6 @@ PREFETCHED_FILE_NAME = "prefetch_image" ...@@ -24,16 +26,6 @@ PREFETCHED_FILE_NAME = "prefetch_image"
PREFETCHED_SEM_SEG_FILE_NAME = "prefetch_sem_seg" PREFETCHED_SEM_SEG_FILE_NAME = "prefetch_sem_seg"
def read_image_with_prefetch(file_name, format=None, prefetched=None):
if prefetched is None:
return utils.read_image(file_name, format)
image = Image.open(BytesIO(prefetched.numpy().view()))
# work around this bug: https://github.com/python-pillow/Pillow/issues/3973
image = utils._apply_exif_orientation(image)
return utils.convert_PIL_to_numpy(image, format)
@D2GO_DATA_MAPPER_REGISTRY.register() @D2GO_DATA_MAPPER_REGISTRY.register()
class D2GoDatasetMapper(object): class D2GoDatasetMapper(object):
def __init__(self, cfg, is_train=True, image_loader=None, tfm_gens=None): def __init__(self, cfg, is_train=True, image_loader=None, tfm_gens=None):
...@@ -176,11 +168,12 @@ class D2GoDatasetMapper(object): ...@@ -176,11 +168,12 @@ class D2GoDatasetMapper(object):
dataset_dict["instances"] = utils.filter_empty_instances(instances) dataset_dict["instances"] = utils.filter_empty_instances(instances)
if "sem_seg_file_name" in dataset_dict: if "sem_seg_file_name" in dataset_dict:
sem_seg_gt = read_image_with_prefetch( sem_seg_gt = read_sem_seg_file_with_prefetch(
dataset_dict.pop("sem_seg_file_name"), dataset_dict.pop("sem_seg_file_name"),
"L",
prefetched=dataset_dict.get(PREFETCHED_SEM_SEG_FILE_NAME, None), prefetched=dataset_dict.get(PREFETCHED_SEM_SEG_FILE_NAME, None),
).squeeze(2) )
if len(sem_seg_gt.shape) > 2:
sem_seg_gt = sem_seg_gt.squeeze(2)
sem_seg_gt = transforms.apply_segmentation(sem_seg_gt) sem_seg_gt = transforms.apply_segmentation(sem_seg_gt)
sem_seg_gt = torch.as_tensor(sem_seg_gt.astype("long")) sem_seg_gt = torch.as_tensor(sem_seg_gt.astype("long"))
dataset_dict["sem_seg"] = sem_seg_gt dataset_dict["sem_seg"] = sem_seg_gt
......
from io import BytesIO
import numpy as np
from detectron2.data import detection_utils as utils
from detectron2.utils.file_io import PathManager
from PIL import Image
def read_image_with_prefetch(file_name, format=None, prefetched=None):
if prefetched is None:
return utils.read_image(file_name, format)
image = Image.open(BytesIO(prefetched.numpy().view()))
# work around this bug: https://github.com/python-pillow/Pillow/issues/3973
image = utils._apply_exif_orientation(image)
return utils.convert_PIL_to_numpy(image, format)
def read_sem_seg_file_with_prefetch(file_name: str, prefetched=None):
"""
Segmentation mask annotations can be stored as:
.PNG files
.npy uncompressed numpy files
"""
assert file_name.endswith(".png") or file_name.endswith(".npy")
sem_seg_type = file_name[-len(".---") :]
if sem_seg_type == ".png":
return read_image_with_prefetch(file_name, format="L", prefetched=prefetched)
elif sem_seg_type == ".npy":
if prefetched is None:
with PathManager.open(file_name, "rb") as f:
return np.load(f)
else:
return prefetched.numpy()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment