Commit f23248c0 authored by facebook-github-bot

Initial commit

fbshipit-source-id: f4a8ba78691d8cf46e003ef0bd2e95f170932778
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import os
def _cache_json_file(json_file):
# TODO: entirely rely on PathManager for caching
json_file = os.fspath(json_file)
return json_file
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from d2go.config import CfgNode as CN
def add_d2go_data_default_configs(_C):
_C.D2GO_DATA = CN()
# Config for "detectron2go.data.extended_coco.extended_coco_load"
_C.D2GO_DATA.DATASETS = CN()
# List of class names to use when loading the data; this applies to train
# and test separately. The default value means using all classes; otherwise a
# new json file containing only the given categories will be created.
_C.D2GO_DATA.DATASETS.TRAIN_CATEGORIES = ()
_C.D2GO_DATA.DATASETS.TEST_CATEGORIES = ()
# Register a list of COCO datasets in config
_C.D2GO_DATA.DATASETS.COCO_INJECTION = CN()
_C.D2GO_DATA.DATASETS.COCO_INJECTION.NAMES = []
_C.D2GO_DATA.DATASETS.COCO_INJECTION.IM_DIRS = []
_C.D2GO_DATA.DATASETS.COCO_INJECTION.JSON_FILES = []
# On-the-fly register a list of datasets located under detectron2go/datasets
# by specifying the filename (without .py).
_C.D2GO_DATA.DATASETS.DYNAMIC_DATASETS = []
# TODO: potentially add this config
# # List of extra keys in annotation, the item will be forwarded by
# # extended_coco_load.
# _C.D2GO_DATA.DATASETS.ANNOTATION_FIELDS_TO_FORWARD = ()
# Config for D2GoDatasetMapper
_C.D2GO_DATA.MAPPER = CN()
# dataset mapper name
_C.D2GO_DATA.MAPPER.NAME = "D2GoDatasetMapper"
# When enabled, image items from the json dataset don't need to have width/height;
# they will be backfilled once the image is loaded. This may cause issues when
# width/height is actually used by extended_coco_load, e.g. grouping
# by aspect ratio.
_C.D2GO_DATA.MAPPER.BACKFILL_SIZE = False
_C.D2GO_DATA.MAPPER.RETRY = 3
_C.D2GO_DATA.MAPPER.CATCH_EXCEPTION = True
_C.D2GO_DATA.AUG_OPS = CN()
# List of transforms that are represented by string. Each string starts with
# a registered name in TRANSFORM_OP_REGISTRY, optionally followed by a string
# argument (separated by "::") which can be used for initializing the
# transform object. See build_transform_gen for details.
# Some examples are:
# example 1: RandomFlipOp
# example 2: RandomFlipOp::{}
# example 3: RandomFlipOp::{"prob":0.5}
# example 4: RandomBrightnessOp::{"intensity_min":1.0, "intensity_max":2.0}
# NOTE: search "example repr:" in fbcode for examples.
_C.D2GO_DATA.AUG_OPS.TRAIN = ["ResizeShortestEdgeOp", "RandomFlipOp"]
_C.D2GO_DATA.AUG_OPS.TEST = ["ResizeShortestEdgeOp"]
_C.D2GO_DATA.TEST = CN()
# Evaluate on the first specified number of images for each dataset during
# testing; the default value 0 means using all images.
# NOTE: See maybe_subsample_n_images for details.
_C.D2GO_DATA.TEST.MAX_IMAGES = 0
_C.D2GO_DATA.TEST.SUBSET_SAMPLING = "frontmost" # one of {"frontmost", "random"}
return _C
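# --- Usage sketch (illustrative, not part of the original file) ---
# Attach the D2GO_DATA defaults to a fresh config and override an aug op using
# the "Name::json-args" repr described above.
def _demo_add_d2go_data_default_configs():
    cfg = add_d2go_data_default_configs(CN())
    assert cfg.D2GO_DATA.MAPPER.NAME == "D2GoDatasetMapper"
    cfg.D2GO_DATA.AUG_OPS.TRAIN = [
        "ResizeShortestEdgeOp",
        'RandomFlipOp::{"prob": 0.5}',
    ]
    return cfg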
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import copy
import logging
# TODO: Deprecate this file in favor of the module detectron2go.data.dataset_mappers
from d2go.data.dataset_mappers import D2GoDatasetMapper # noqa
from PIL import Image
logger = logging.getLogger(__name__)
_IMAGE_LOADER_REGISTRY = {}
def register_uri_image_loader(scheme, loader):
"""
An image can be represented as "scheme://path"; the image will be retrieved
by calling Image.open(loader(path)).
"""
logger.info(
"Register image loader for scheme: {} with loader: {}".format(scheme, loader)
)
_IMAGE_LOADER_REGISTRY[scheme] = loader
# TODO (T62922909): remove UniversalResourceLoader and use PathManager
class UniversalResourceLoader(object):
def __init__(self):
self._image_loader_func_map = copy.deepcopy(_IMAGE_LOADER_REGISTRY)
@staticmethod
def parse_path(uri):
SCHEME_SEPARATOR = "://"
if uri.count(SCHEME_SEPARATOR) < 1:
# this should be a normal file name, use full string as path
return "file", uri
scheme, path = uri.split(SCHEME_SEPARATOR, maxsplit=1)
return scheme, path
def get_file(self, uri):
scheme, path = self.parse_path(uri)
if scheme not in self._image_loader_func_map:
raise RuntimeError(
"No loader for scheme {} in UniversalResourceLoader for uri: {}".format(
scheme, uri
)
)
loader = self._image_loader_func_map[scheme]
return loader(path)
def support(self, dataset_dict):
uri = dataset_dict["file_name"]
scheme, _ = self.parse_path(uri)
return scheme in self._image_loader_func_map
def __call__(self, dataset_dict):
uri = dataset_dict["file_name"]
fp = self.get_file(uri)
return Image.open(fp)
def __repr__(self):
return "UniversalResourceLoader(schemes={})".format(
list(self._image_loader_func_map.keys())
)
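# --- Usage sketch (illustrative, not part of the original file) ---
# Register a loader for a scheme and resolve a uri through the loader; the
# "file" loader below is a stand-in for a real fetch function.
def _demo_universal_resource_loader():
    register_uri_image_loader("file", lambda path: open(path, "rb"))
    loader = UniversalResourceLoader()
    assert loader.parse_path("everstore://abc/xyz.jpg") == ("everstore", "abc/xyz.jpg")
    # loader({"file_name": "file:///tmp/img.jpg"}) would return a PIL Image.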
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from .build import D2GO_DATA_MAPPER_REGISTRY, build_dataset_mapper # noqa
from .d2go_dataset_mapper import D2GoDatasetMapper # noqa
from .rotated_dataset_mapper import RotatedDatasetMapper # noqa
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from detectron2.utils.registry import Registry
D2GO_DATA_MAPPER_REGISTRY = Registry("D2GO_DATA_MAPPER")
def build_dataset_mapper(cfg, is_train, *args, **kwargs):
name = cfg.D2GO_DATA.MAPPER.NAME
return D2GO_DATA_MAPPER_REGISTRY.get(name)(cfg, is_train, *args, **kwargs)
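# --- Usage sketch (illustrative, not part of the original file) ---
# Registering a subclass under a new name and pointing the config at it is all
# build_dataset_mapper needs; the import is local only to avoid a circular
# dependency with d2go_dataset_mapper in this sketch.
def _demo_build_dataset_mapper(cfg):
    from d2go.data.dataset_mappers.d2go_dataset_mapper import D2GoDatasetMapper

    @D2GO_DATA_MAPPER_REGISTRY.register()
    class _MyDatasetMapper(D2GoDatasetMapper):
        def _custom_transform(self, image, dataset_dict):
            return image, dataset_dict  # hook for custom per-image logic

    cfg.D2GO_DATA.MAPPER.NAME = "_MyDatasetMapper"
    return build_dataset_mapper(cfg, is_train=True)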
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import copy
import logging
import numpy as np
import torch
from detectron2.data import detection_utils as utils, transforms as T
from detectron2.data.transforms.augmentation import (
AugInput,
AugmentationList,
)
from d2go.utils.helper import retryable
from .build import D2GO_DATA_MAPPER_REGISTRY
logger = logging.getLogger(__name__)
@D2GO_DATA_MAPPER_REGISTRY.register()
class D2GoDatasetMapper(object):
def __init__(self, cfg, is_train=True, image_loader=None, tfm_gens=None):
self.tfm_gens = (
tfm_gens
if tfm_gens is not None
else utils.build_transform_gen(cfg, is_train)
)
if cfg.INPUT.CROP.ENABLED and is_train:
self.crop_gen = T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE)
# D2GO NOTE: when INPUT.CROP.ENABLED, don't allow using RandomCropOp
assert all(not isinstance(gen, T.RandomCrop) for gen in self.tfm_gens)
else:
self.crop_gen = None
# fmt: off
self.img_format = cfg.INPUT.FORMAT # noqa
self.mask_on = cfg.MODEL.MASK_ON # noqa
self.mask_format = cfg.INPUT.MASK_FORMAT # noqa
self.keypoint_on = cfg.MODEL.KEYPOINT_ON # noqa
# fmt: on
if self.keypoint_on and is_train:
# Flip only makes sense in training
self.keypoint_hflip_indices = utils.create_keypoint_hflip_indices(
cfg.DATASETS.TRAIN
)
else:
self.keypoint_hflip_indices = None
self.load_proposals = cfg.MODEL.LOAD_PROPOSALS
if self.load_proposals:
self.proposal_min_box_size = cfg.MODEL.PROPOSAL_GENERATOR.MIN_SIZE
self.proposal_topk = (
cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TRAIN
if is_train
else cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TEST
)
self.is_train = is_train
# Setup image loader:
self.image_loader = image_loader
self.backfill_size = cfg.D2GO_DATA.MAPPER.BACKFILL_SIZE
self.retry = cfg.D2GO_DATA.MAPPER.RETRY
self.catch_exception = cfg.D2GO_DATA.MAPPER.CATCH_EXCEPTION
if self.backfill_size:
if cfg.DATALOADER.ASPECT_RATIO_GROUPING:
logger.warning(
"ASPECT_RATIO_GROUPING may not work if image's width & height"
" are not given in json dataset when calling extended_coco_load,"
" if you encounter issue, consider disable ASPECT_RATIO_GROUPING."
)
self._error_count = 0
self._total_counts = 0
self._error_types = {}
def _original_call(self, dataset_dict):
"""
Modified from detectron2's original __call__ in DatasetMapper
"""
dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
image = self._read_image(dataset_dict, format=self.img_format)
if not self.backfill_size:
utils.check_image_size(dataset_dict, image)
image, dataset_dict = self._custom_transform(image, dataset_dict)
inputs = AugInput(image=image)
if "annotations" not in dataset_dict:
transforms = AugmentationList(
([self.crop_gen] if self.crop_gen else []) + self.tfm_gens
)(inputs)
image = inputs.image
else:
# pass additional arguments, will only be used when the Augmentation
# takes `annotations` as input
inputs.annotations = dataset_dict["annotations"]
# Crop around an instance if there are instances in the image.
if self.crop_gen:
crop_tfm = utils.gen_crop_transform_with_instance(
self.crop_gen.get_crop_size(image.shape[:2]),
image.shape[:2],
np.random.choice(dataset_dict["annotations"]),
)
image = crop_tfm.apply_image(image)
transforms = AugmentationList(self.tfm_gens)(inputs)
image = inputs.image
if self.crop_gen:
transforms = crop_tfm + transforms
image_shape = image.shape[:2] # h, w
if image.ndim == 2:
image = np.expand_dims(image, 2)
dataset_dict["image"] = torch.as_tensor(
image.transpose(2, 0, 1).astype("float32")
)
# Can use uint8 if it turns out to be slow some day
if self.load_proposals:
utils.transform_proposals(
dataset_dict,
image_shape,
transforms,
proposal_topk=self.proposal_topk,
min_box_size=self.proposal_min_box_size,
)
if not self.is_train:
dataset_dict.pop("annotations", None)
dataset_dict.pop("sem_seg_file_name", None)
return dataset_dict
if "annotations" in dataset_dict:
for anno in dataset_dict["annotations"]:
if not self.mask_on:
anno.pop("segmentation", None)
if not self.keypoint_on:
anno.pop("keypoints", None)
annos = [
utils.transform_instance_annotations(
obj,
transforms,
image_shape,
keypoint_hflip_indices=self.keypoint_hflip_indices,
)
for obj in dataset_dict.pop("annotations")
if obj.get("iscrowd", 0) == 0
]
instances = utils.annotations_to_instances(
annos, image_shape, mask_format=self.mask_format
)
# Create a tight bounding box from masks, useful when image is cropped
if self.crop_gen and instances.has("gt_masks"):
instances.gt_boxes = instances.gt_masks.get_bounding_boxes()
dataset_dict["instances"] = utils.filter_empty_instances(instances)
if "sem_seg_file_name" in dataset_dict:
sem_seg_gt = utils.read_image(
dataset_dict.pop("sem_seg_file_name"), "L"
).squeeze(2)
sem_seg_gt = transforms.apply_segmentation(sem_seg_gt)
sem_seg_gt = torch.as_tensor(sem_seg_gt.astype("long"))
dataset_dict["sem_seg"] = sem_seg_gt
# extend standard D2 semantic segmentation to support multiple segmentation
# files, each file can represent a class
if "multi_sem_seg_file_names" in dataset_dict:
raise NotImplementedError()
if "_post_process_" in dataset_dict:
proc_func = dataset_dict.pop("_post_process_")
dataset_dict = proc_func(dataset_dict)
return dataset_dict
def __call__(self, dataset_dict):
self._total_counts += 1
@retryable(num_tries=self.retry, sleep_time=0.1)
def _f():
return self._original_call(dataset_dict)
if not self.catch_exception:
return _f()
try:
return _f()
except Exception as e:
self._error_count += 1
# if self._error_count % 10 == 1:
# # print the stacktrace for easier debugging
# traceback.print_exc()
error_type = type(e).__name__
self._error_types[error_type] = self._error_types.get(error_type, 0) + 1
if self._error_count % 100 == 0:
logger.warning(
"{}Error when applying transform for dataset_dict: {};"
" error rate {}/{} ({:.2f}%), msg: {}".format(
self._get_logging_prefix(),
dataset_dict,
self._error_count,
self._total_counts,
100.0 * self._error_count / self._total_counts,
repr(e),
)
)
self._log_error_type_stats()
# NOTE: the contract with MapDataset allows returning `None`, in which
# case it will randomly use another element from the dataset. We use
# this feature to handle errors.
return None
def _get_logging_prefix(self):
worker_info = torch.utils.data.get_worker_info()
if not worker_info:
return ""
prefix = "[worker: {}/{}] ".format(worker_info.id, worker_info.num_workers)
return prefix
def _log_error_type_stats(self):
error_type_count_msgs = [
"{}: {}/{} ({}%)".format(
k, v, self._total_counts, 100.0 * v / self._total_counts
)
for k, v in self._error_types.items()
]
logger.warning(
"{}Error statistics:\n{}".format(
self._get_logging_prefix(), "\n".join(error_type_count_msgs)
)
)
def _read_image(self, dataset_dict, format=None):
if not (self.image_loader and self.image_loader.support(dataset_dict)):
# fallback to use D2's read_image
image = utils.read_image(dataset_dict["file_name"], format=format)
if self.backfill_size:
h, w, _ = image.shape
dataset_dict["width"] = w
dataset_dict["height"] = h
return image
image = self.image_loader(dataset_dict)
if self.backfill_size:
dataset_dict["width"] = image.width
dataset_dict["height"] = image.height
return utils.convert_PIL_to_numpy(image, format)
def _custom_transform(self, image, dataset_dict):
"""
Override this method to inject custom transform.
"""
return image, dataset_dict
def __repr__(self):
return (
self.__class__.__name__
+ ":\n"
+ "\n".join(
[
" is_train: {}".format(self.is_train),
" image_loader: {}".format(self.image_loader),
" tfm_gens: \n{}".format(
"\n".join([" - {}".format(x) for x in self.tfm_gens])
),
]
)
)
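# --- Usage sketch (illustrative, not part of the original file) ---
# Map one "Detectron2 Dataset" dict with an explicitly provided tfm_gens list,
# assuming `cfg` is a detectron2-style config with the D2GO_DATA defaults
# attached; the file path below is a placeholder.
def _demo_d2go_dataset_mapper(cfg):
    mapper = D2GoDatasetMapper(
        cfg, is_train=True, tfm_gens=[T.ResizeShortestEdge(short_edge_length=224)]
    )
    dataset_dict = {
        "file_name": "/path/to/image.jpg",  # placeholder
        "image_id": 0,
        "height": 480,
        "width": 640,
        "annotations": [],
    }
    # Returns a dict with an "image" tensor, or None when the mapper caught an
    # exception (MapDataset will then resample another element).
    return mapper(dataset_dict)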
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import copy
import logging
import numpy as np
import torch
from detectron2.data import detection_utils as utils, transforms as T
from detectron2.structures import BoxMode, Instances, RotatedBoxes
from d2go.data.dataset_mappers.d2go_dataset_mapper import D2GoDatasetMapper
from .build import D2GO_DATA_MAPPER_REGISTRY
logger = logging.getLogger(__name__)
@D2GO_DATA_MAPPER_REGISTRY.register()
class RotatedDatasetMapper(D2GoDatasetMapper):
def _original_call(self, dataset_dict):
"""
Modified from detectron2's original __call__ in DatasetMapper
"""
dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
image = self._read_image(dataset_dict, format=self.img_format)
if not self.backfill_size:
utils.check_image_size(dataset_dict, image)
if "annotations" not in dataset_dict:
image, transforms = T.apply_transform_gens(
([self.crop_gen] if self.crop_gen else []) + self.tfm_gens, image
)
else:
# Crop around an instance if there are instances in the image.
# USER: Remove if you don't use cropping
if self.crop_gen:
crop_tfm = utils.gen_crop_transform_with_instance(
self.crop_gen.get_crop_size(image.shape[:2]),
image.shape[:2],
np.random.choice(dataset_dict["annotations"]),
)
image = crop_tfm.apply_image(image)
image, transforms = T.apply_transform_gens(self.tfm_gens, image)
if self.crop_gen:
transforms = crop_tfm + transforms
image_shape = image.shape[:2] # h, w
dataset_dict["image"] = torch.as_tensor(
image.transpose(2, 0, 1).astype("float32")
)
# Can use uint8 if it turns out to be slow some day
assert not self.load_proposals, "Not supported!"
if not self.is_train:
dataset_dict.pop("annotations", None)
dataset_dict.pop("sem_seg_file_name", None)
return dataset_dict
if "annotations" in dataset_dict:
for anno in dataset_dict["annotations"]:
if not self.mask_on:
anno.pop("segmentation", None)
if not self.keypoint_on:
anno.pop("keypoints", None)
# Convert dataset_dict["annotations"] to dataset_dict["instances"]
annotations = [
obj
for obj in dataset_dict.pop("annotations")
if obj.get("iscrowd", 0) == 0
]
# Convert either rotated box or horizontal box to XYWHA_ABS format
original_boxes = [
BoxMode.convert(
box=obj["bbox"],
from_mode=obj["bbox_mode"],
to_mode=BoxMode.XYWHA_ABS,
)
for obj in annotations
]
transformed_boxes = transforms.apply_rotated_box(
np.array(original_boxes, dtype=np.float64)
)
instances = Instances(image_shape)
instances.gt_classes = torch.tensor(
[obj["category_id"] for obj in annotations], dtype=torch.int64
)
instances.gt_boxes = RotatedBoxes(transformed_boxes)
instances.gt_boxes.clip(image_shape)
dataset_dict["instances"] = instances[instances.gt_boxes.nonempty()]
return dataset_dict
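# --- Illustrative example (not part of the original file) ---
# An input annotation for RotatedDatasetMapper; "bbox" may be a 5-element
# rotated box (center_x, center_y, width, height, angle in degrees) in
# XYWHA_ABS mode, or a plain horizontal box that gets converted.
_EXAMPLE_ROTATED_ANNOTATION = {
    "bbox": [320.0, 240.0, 100.0, 40.0, 30.0],
    "bbox_mode": BoxMode.XYWHA_ABS,
    "category_id": 0,
    "iscrowd": 0,
}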
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import functools
import importlib
import logging
import os
from detectron2.data import DatasetCatalog, MetadataCatalog
from d2go.utils.helper import get_dir_path
from .extended_coco import coco_text_load, extended_coco_load
from .extended_lvis import extended_lvis_load
logger = logging.getLogger(__name__)
D2GO_DATASETS_BASE_MODULE = "d2go.datasets"
IM_DIR = "image_directory"
ANN_FN = "annotation_file"
def _import_dataset(module_name):
return importlib.import_module(
"{}.{}".format(D2GO_DATASETS_BASE_MODULE, module_name)
)
def _register_extended_coco(dataset_name, split_dict):
json_file = split_dict[ANN_FN]
image_root = split_dict[IM_DIR]
# 1. register a function which returns dicts
load_coco_json_func = functools.partial(
extended_coco_load,
json_file=json_file,
image_root=image_root,
dataset_name=dataset_name,
)
DatasetCatalog.register(dataset_name, load_coco_json_func)
# 2. Optionally, add metadata about this split,
# since they might be useful in evaluation, visualization or logging
evaluator_type = split_dict.get("evaluator_type", "coco")
meta_data = split_dict.get("meta_data", {})
MetadataCatalog.get(dataset_name).set(
evaluator_type=evaluator_type,
json_file=json_file,
image_root=image_root,
**meta_data
)
def _register_extended_lvis(dataset_name, split_dict):
json_file = split_dict[ANN_FN]
image_root = split_dict[IM_DIR]
# 1. register a function which returns dicts
load_lvis_json_func = functools.partial(
extended_lvis_load,
json_file=json_file,
image_root=image_root,
dataset_name=dataset_name,
)
DatasetCatalog.register(dataset_name, load_lvis_json_func)
# 2. Optionally, add metadata about this split,
# since they might be useful in evaluation, visualization or logging
evaluator_type = split_dict.get("evaluator_type", "lvis")
MetadataCatalog.get(dataset_name).set(
evaluator_type=evaluator_type, json_file=json_file, image_root=image_root
)
def _register_coco_text(dataset_name, split_dict):
source_json_file = split_dict[ANN_FN]
coco_json_file = "/tmp/{}.json".format(dataset_name)
ARCHIVE_FN = "archive_file"
# 1. register a function which returns dicts
DatasetCatalog.register(
dataset_name,
functools.partial(
coco_text_load,
coco_json_file=coco_json_file,
image_root=split_dict[IM_DIR],
source_json_file=source_json_file,
dataset_name=dataset_name,
archive_file=split_dict.get(ARCHIVE_FN, None),
),
)
# 2. Optionally, add metadata about this split,
# since they might be useful in evaluation, visualization or logging
evaluator_type = split_dict.get("evaluator_type", "coco")
MetadataCatalog.get(dataset_name).set(
json_file=coco_json_file,
image_root=split_dict[IM_DIR],
evaluator_type=evaluator_type,
)
def inject_coco_datasets(cfg):
names = cfg.D2GO_DATA.DATASETS.COCO_INJECTION.NAMES
im_dirs = cfg.D2GO_DATA.DATASETS.COCO_INJECTION.IM_DIRS
json_files = cfg.D2GO_DATA.DATASETS.COCO_INJECTION.JSON_FILES
assert len(names) == len(im_dirs) == len(json_files)
for name, im_dir, json_file in zip(names, im_dirs, json_files):
split_dict = {IM_DIR: im_dir, ANN_FN: json_file}
logger.info("Inject coco dataset {}: {}".format(name, split_dict))
_register_extended_coco(name, split_dict)
def register_dataset_split(dataset_name, split_dict):
"""
Register a dataset to detectron2's DatasetCatalog and MetadataCatalog.
"""
_DATASET_TYPE_LOAD_FUNC_MAP = {
"COCODataset": _register_extended_coco,
"COCOText": _register_coco_text,
"COCOTextDataset": _register_coco_text,
"LVISDataset": _register_extended_lvis,
}
factory = split_dict.get("DS_TYPE", "COCODataset")
_DATASET_TYPE_LOAD_FUNC_MAP[factory](
dataset_name=dataset_name, split_dict=split_dict
)
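# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal split_dict for a COCO-style split; the name and paths below are
# placeholders.
def _demo_register_dataset_split():
    register_dataset_split(
        "my_dataset_train",
        {
            IM_DIR: "/path/to/images",  # placeholder
            ANN_FN: "/path/to/annotations.json",  # placeholder
            "DS_TYPE": "COCODataset",  # the default when omitted
        },
    )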
def register_json_datasets():
json_dataset_names = [
os.path.splitext(filename)[0]
for filename in os.listdir(
get_dir_path(D2GO_DATASETS_BASE_MODULE.replace(".", "/"))
)
if filename.startswith("json_dataset_")
]
json_dataset_names = [
x
for x in json_dataset_names
if x
not in [
"json_dataset_lvis",
"json_dataset_oculus_external",
"json_dataset_people_ai_foot_tracking",
]
]
# load all splits from json datasets
all_splits = {}
for dataset in json_dataset_names:
module = _import_dataset(dataset)
assert (
len(set(all_splits).intersection(set(module.DATASETS))) == 0
), "Name confliction when loading {}".format(dataset)
all_splits.update(module.DATASETS)
# register all splits
for split_name in all_splits:
split_dict = all_splits[split_name]
register_dataset_split(split_name, split_dict)
def register_builtin_datasets():
builtin_dataset_names = [
os.path.splitext(filename)[0]
for filename in os.listdir(
get_dir_path(D2GO_DATASETS_BASE_MODULE.replace(".", "/"))
)
if filename.startswith("builtin_dataset_")
]
for dataset in builtin_dataset_names:
_import_dataset(dataset)
def register_dynamic_datasets(cfg):
for dataset in cfg.D2GO_DATA.DATASETS.DYNAMIC_DATASETS:
assert dataset.startswith("dynamic_dataset_")
_import_dataset(dataset)
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import json
import logging
import os
import shlex
import subprocess
from collections import defaultdict
import detectron2.utils.comm as comm
from detectron2.data import MetadataCatalog
from detectron2.structures import BoxMode
from pycocotools.coco import COCO
from .cache_util import _cache_json_file
logger = logging.getLogger(__name__)
class InMemoryCOCO(COCO):
def __init__(self, loaded_json):
"""
In this in-memory version of COCO we don't load json from the file,
but directly use a loaded_json instead. This approach improves
both robustness and efficiency, as when we convert from other formats
to COCO format, we don't need to save and re-load the json again.
"""
# load dataset
self.dataset = loaded_json
self.anns = {}
self.cats = {}
self.imgs = {}
self.imgToAnns, self.catToImgs = defaultdict(list), defaultdict(list)
self.createIndex()
def extract_archive_file(archive_fn, im_dir):
if not os.path.exists(im_dir) or not os.listdir(im_dir):
# Dataset is not deployed. Deploy it.
archive_fns = archive_fn
# A dataset may be composed of several tgz files, or only one.
# If one, make it into a list to make the code later more general
if not isinstance(archive_fns, list):
archive_fns = [archive_fns]
logger.info(
"Extracting datasets {} to local machine at {}".format(archive_fns, im_dir)
)
if not os.path.exists(im_dir):
os.makedirs(im_dir)
for archive_fn in archive_fns:
# Extract the tgz file directly into the target directory,
# without precopy.
# Note that the tgz file contains a root directory that
# we do not want, hence the strip-components=1
commandUnpack = (
"tar -mxzf {src_file} -C {tgt_dir} " "--strip-components=1"
).format(src_file=archive_fn, tgt_dir=im_dir)
assert not subprocess.call(shlex.split(commandUnpack)), "Failed to unpack"
logger.info("Extracted {}".format(archive_fn))
def convert_coco_text_to_coco_detection_json(
source_json, target_json, set_type=None, min_img_size=100, text_cat_id=1
):
"""
This function converts a COCOText style JSON to a COCODetection style
JSON.
For COCOText see: https://vision.cornell.edu/se3/coco-text-2/
For COCODetection see: http://cocodataset.org/#overview
"""
with open(source_json) as f:
coco_text_json = json.load(f)
coco_text_json["annotations"] = list(coco_text_json["anns"].values())
coco_text_json["images"] = list(coco_text_json["imgs"].values())
if set_type is not None:
# COCO Text style JSONs often mix test, train, and val sets.
# We need to make sure we only use the data type we want.
coco_text_json["images"] = [
x for x in coco_text_json["images"] if x["set"] == set_type
]
coco_text_json["categories"] = [{"name": "text", "id": text_cat_id}]
del coco_text_json["cats"]
del coco_text_json["imgs"]
del coco_text_json["anns"]
for ann in coco_text_json["annotations"]:
ann["category_id"] = text_cat_id
ann["iscrowd"] = 0
# Don't evaluate the model on illegible words
if set_type == "val" and ann["legibility"] != "legible":
ann["ignore"] = True
# Some datasets seem to have extremely small images which break downstream
# operations. If min_img_size is set, we can remove these.
coco_text_json["images"] = [
x
for x in coco_text_json["images"]
if x["height"] >= min_img_size and x["width"] >= min_img_size
]
os.makedirs(os.path.dirname(target_json), exist_ok=True)
with open(target_json, "w") as f:
json.dump(coco_text_json, f)
return coco_text_json
def convert_to_dict_list(image_root, id_map, imgs, anns, dataset_name=None):
num_instances_without_valid_segmentation = 0
dataset_dicts = []
count_ignore_image_root_warning = 0
for (img_dict, anno_dict_list) in zip(imgs, anns):
record = {}
# NOTE: besides using a (relative) path in the "file_name" field to represent
# the image resource, "extended coco" also supports using a uri which
# represents an image as a single string, e.g. "everstore_handle://xxx",
if "://" not in img_dict["file_name"]:
record["file_name"] = os.path.join(image_root, img_dict["file_name"])
else:
if image_root is not None:
count_ignore_image_root_warning += 1
if count_ignore_image_root_warning == 1:
logger.warning(
(
"Found '://' in file_name: {}, ignore image_root: {}"
"(logged once per dataset)."
).format(img_dict["file_name"], image_root)
)
record["file_name"] = img_dict["file_name"]
if "height" in img_dict or "width" in img_dict:
record["height"] = img_dict["height"]
record["width"] = img_dict["width"]
image_id = record["image_id"] = img_dict["id"]
objs = []
for anno in anno_dict_list:
# Check that the image_id in this annotation is the same. This fails
# only when the data parsing logic or the annotation file is buggy.
assert anno["image_id"] == image_id
assert anno.get("ignore", 0) == 0
obj = {
field: anno[field]
# NOTE: maybe use MetadataCatalog for this
for field in ["iscrowd", "bbox", "keypoints", "category_id", "extras"]
if field in anno
}
if obj.get("category_id", None) not in id_map:
continue
segm = anno.get("segmentation", None)
if segm: # either list[list[float]] or dict(RLE)
if not isinstance(segm, dict):
# filter out invalid polygons (< 3 points)
segm = [
poly for poly in segm if len(poly) % 2 == 0 and len(poly) >= 6
]
if len(segm) == 0:
num_instances_without_valid_segmentation += 1
continue # ignore this instance
obj["segmentation"] = segm
if len(obj["bbox"]) == 5:
obj["bbox_mode"] = BoxMode.XYWHA_ABS
else:
obj["bbox_mode"] = BoxMode.XYWH_ABS
if id_map:
obj["category_id"] = id_map[obj["category_id"]]
objs.append(obj)
record["annotations"] = objs
if dataset_name is not None:
record["dataset_name"] = dataset_name
dataset_dicts.append(record)
if count_ignore_image_root_warning > 0:
logger.warning(
"The 'ignore image_root: {}' warning occurred {} times".format(
image_root, count_ignore_image_root_warning
)
)
if num_instances_without_valid_segmentation > 0:
logger.warning(
"Filtered out {} instances without valid segmentation. "
"There might be issues in your dataset generation process.".format(
num_instances_without_valid_segmentation
)
)
return dataset_dicts
def coco_text_load(
coco_json_file,
image_root,
source_json_file=None,
dataset_name=None,
archive_file=None,
):
if archive_file is not None:
if comm.get_rank() == 0:
extract_archive_file(archive_file, image_root)
comm.synchronize()
if source_json_file is not None:
# Need to convert to coco detection format
loaded_json = convert_coco_text_to_coco_detection_json(
source_json_file, coco_json_file
)
return extended_coco_load(coco_json_file, image_root, dataset_name, loaded_json)
return extended_coco_load(
coco_json_file, image_root, dataset_name, loaded_json=None
)
def extended_coco_load(json_file, image_root, dataset_name=None, loaded_json=None):
"""
Load a json file with COCO's annotation format.
Currently only supports instance segmentation annotations.
Args:
json_file (str): full path to the json file in COCO annotation format.
image_root (str): the directory where the images in this json file exists.
dataset_name (str): the name of the dataset (e.g., "coco", "cityscapes").
If provided, this function will also put "thing_classes" into
the metadata associated with this dataset.
loaded_json (str): optional loaded json content, used in InMemoryCOCO to
avoid loading from json_file again.
Returns:
list[dict]: a list of dicts in "Detectron2 Dataset" format. (See DATASETS.md)
Notes:
1. This function does not read the image files.
The results do not have the "image" field.
2. When `dataset_name=='coco'`,
this function will translate COCO's
incontiguous category ids to contiguous ids in [0, 80).
"""
json_file = _cache_json_file(json_file)
if loaded_json is None:
coco_api = COCO(json_file)
else:
coco_api = InMemoryCOCO(loaded_json)
id_map = None
# Get filtered classes
all_cat_ids = coco_api.getCatIds()
all_cats = coco_api.loadCats(all_cat_ids)
# Setup classes to use for creating id map
classes_to_use = [c["name"] for c in sorted(all_cats, key=lambda x: x["id"])]
# Setup id map
id_map = {}
for cat_id, cat in zip(all_cat_ids, all_cats):
if cat["name"] in classes_to_use:
id_map[cat_id] = classes_to_use.index(cat["name"])
# Register dataset in metadata catalog
if dataset_name is not None:
# overwrite attrs
meta_dict = MetadataCatalog.get(dataset_name).as_dict()
meta_dict['thing_classes'] = classes_to_use
meta_dict['thing_dataset_id_to_contiguous_id'] = id_map
# update MetadataCatalog (cannot change inplace, has to remove)
MetadataCatalog.remove(dataset_name)
MetadataCatalog.get(dataset_name).set(**meta_dict)
# assert the change
assert MetadataCatalog.get(dataset_name).thing_classes == classes_to_use
# sort indices for reproducible results
img_ids = sorted(coco_api.imgs.keys())
imgs = coco_api.loadImgs(img_ids)
anns = [coco_api.imgToAnns[img_id] for img_id in img_ids]
logger.info("Loaded {} images from {}".format(len(imgs), json_file))
# Return the coco converted to record list
return convert_to_dict_list(image_root, id_map, imgs, anns, dataset_name)
if __name__ == "__main__":
"""
Test the COCO json dataset loader.
Usage:
python -m detectron2.data.datasets.coco \
path/to/json path/to/image_root dataset_name
"""
from detectron2.utils.logger import setup_logger
from detectron2.utils.visualizer import Visualizer
import cv2
import sys
logger = setup_logger(name=__name__)
meta = MetadataCatalog.get(sys.argv[3])
dicts = extended_coco_load(sys.argv[1], sys.argv[2], sys.argv[3])
logger.info("Done loading {} samples.".format(len(dicts)))
for d in dicts:
img = cv2.imread(d["file_name"])[:, :, ::-1]
visualizer = Visualizer(img, metadata=meta)
vis = visualizer.draw_dataset_dict(d)
fpath = os.path.join("coco-data-vis", os.path.basename(d["file_name"]))
vis.save(fpath)
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
import os
from detectron2.data import MetadataCatalog
from detectron2.structures import BoxMode
from fvcore.common.timer import Timer
from .extended_coco import _cache_json_file
"""
This file contains functions to parse LVIS-format annotations into dicts in the
"Detectron2 format".
"""
logger = logging.getLogger(__name__)
def extended_lvis_load(json_file, image_root, dataset_name=None):
"""
Load a json file in LVIS's annotation format.
Args:
json_file (str): full path to the LVIS json annotation file.
image_root (str): the directory where the images in this json file exists.
dataset_name (str): the name of the dataset (e.g., "lvis_v0.5_train").
If provided, this function will put "thing_classes" into the metadata
associated with this dataset.
Returns:
list[dict]: a list of dicts in "Detectron2 Dataset" format. (See DATASETS.md)
Notes:
1. This function does not read the image files.
The results do not have the "image" field.
"""
from lvis import LVIS
json_file = _cache_json_file(json_file)
timer = Timer()
lvis_api = LVIS(json_file)
if timer.seconds() > 1:
logger.info(
"Loading {} takes {:.2f} seconds.".format(json_file, timer.seconds())
)
# sort indices for reproducible results
img_ids = sorted(list(lvis_api.imgs.keys()))
# imgs is a list of dicts, each looks something like:
# {'license': 4,
# 'url': 'http://farm6.staticflickr.com/5454/9413846304_881d5e5c3b_z.jpg',
# 'file_name': 'COCO_val2014_000000001268.jpg',
# 'height': 427,
# 'width': 640,
# 'date_captured': '2013-11-17 05:57:24',
# 'id': 1268}
imgs = lvis_api.load_imgs(img_ids)
# anns is a list[list[dict]], where each dict is an annotation
# record for an object. The inner list enumerates the objects in an image
# and the outer list enumerates over images. Example of anns[0]:
# [{'segmentation': [[192.81,
# 247.09,
# ...
# 219.03,
# 249.06]],
# 'area': 1035.749,
# 'image_id': 1268,
# 'bbox': [192.81, 224.8, 74.73, 33.43],
# 'category_id': 16,
# 'id': 42986},
# ...]
anns = [lvis_api.img_ann_map[img_id] for img_id in img_ids]
# Sanity check that each annotation has a unique id
ann_ids = [ann["id"] for anns_per_image in anns for ann in anns_per_image]
assert len(set(ann_ids)) == len(
ann_ids
), "Annotation ids in '{}' are not unique".format(json_file)
imgs_anns = list(zip(imgs, anns))
logger.info(
"Loaded {} images in the LVIS format from {}".format(len(imgs_anns), json_file)
)
dataset_dicts = []
count_ignore_image_root_warning = 0
for (img_dict, anno_dict_list) in imgs_anns:
record = {}
if "://" not in img_dict["file_name"]:
file_name = img_dict["file_name"]
if img_dict["file_name"].startswith("COCO"):
# Convert from the COCO 2014 file naming convention of
# COCO_[train/val/test]2014_000000000000.jpg to the 2017 naming
# convention of 000000000000.jpg (LVIS v1 will fix this naming issue)
file_name = file_name[-16:]
record["file_name"] = os.path.join(image_root, file_name)
else:
if image_root is not None:
count_ignore_image_root_warning += 1
if count_ignore_image_root_warning == 1:
logger.warning(
(
"Found '://' in file_name: {}, ignore image_root: {}"
"(logged once per dataset)."
).format(img_dict["file_name"], image_root)
)
record["file_name"] = img_dict["file_name"]
record["height"] = img_dict["height"]
record["width"] = img_dict["width"]
record["not_exhaustive_category_ids"] = img_dict.get(
"not_exhaustive_category_ids", []
)
record["neg_category_ids"] = img_dict.get("neg_category_ids", [])
image_id = record["image_id"] = img_dict["id"]
objs = []
for anno in anno_dict_list:
# Check that the image_id in this annotation is the same as
# the image_id we're looking at.
# Fails only when the data parsing logic or the annotation file is buggy.
assert anno["image_id"] == image_id
obj = {"bbox": anno["bbox"], "bbox_mode": BoxMode.XYWH_ABS}
obj["category_id"] = (
anno["category_id"] - 1
) # Convert 1-indexed to 0-indexed
segm = anno["segmentation"]
# filter out invalid polygons (< 3 points)
valid_segm = [
poly for poly in segm if len(poly) % 2 == 0 and len(poly) >= 6
]
assert len(segm) == len(
valid_segm
), "Annotation contains an invalid polygon with < 3 points"
assert len(segm) > 0
obj["segmentation"] = segm
objs.append(obj)
record["annotations"] = objs
dataset_dicts.append(record)
if dataset_name:
meta = MetadataCatalog.get(dataset_name)
meta.thing_classes = get_extended_lvis_instances_meta(lvis_api)["thing_classes"]
return dataset_dicts
def get_extended_lvis_instances_meta(lvis_api):
cat_ids = lvis_api.get_cat_ids()
categories = lvis_api.load_cats(cat_ids)
assert min(cat_ids) == 1 and max(cat_ids) == len(
cat_ids
), "Category ids are not in [1, #categories], as expected"
extended_lvis_categories = [k for k in sorted(categories, key=lambda x: x["id"])]
thing_classes = [k["name"] for k in extended_lvis_categories]
meta = {"thing_classes": thing_classes}
return meta
if __name__ == "__main__":
"""
Test the LVIS json dataset loader.
Usage:
python -m detectron2.data.datasets.lvis \
path/to/json path/to/image_root dataset_name vis_limit
"""
import sys
import detectron2.data.datasets # noqa # add pre-defined metadata
import numpy as np
from detectron2.utils.logger import setup_logger
from detectron2.utils.visualizer import Visualizer
from PIL import Image
logger = setup_logger(name=__name__)
meta = MetadataCatalog.get(sys.argv[3])
dicts = extended_lvis_load(sys.argv[1], sys.argv[2], sys.argv[3])
logger.info("Done loading {} samples.".format(len(dicts)))
dirname = "lvis-data-vis"
os.makedirs(dirname, exist_ok=True)
for d in dicts[: int(sys.argv[4])]:
img = np.array(Image.open(d["file_name"]))
visualizer = Visualizer(img, metadata=meta)
vis = visualizer.draw_dataset_dict(d)
fpath = os.path.join(dirname, os.path.basename(d["file_name"]))
vis.save(fpath)
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
This file contains utilities to load GAN datasets.
Similar to how a COCO dataset is represented in Detectron2, a GAN dataset is
represented as a list of dicts, where each dict is in "standard dataset dict"
format, which contains raw data with fields such as:
- input_path (str): filename of the input image
- fg_path (str): filename of the GT image
...
"""
import os
import json
import logging
from fvcore.common.file_io import PathManager
from detectron2.data import DatasetCatalog, MetadataCatalog
logger = logging.getLogger(__name__)
IMG_EXTENSIONS = ['.jpg', '.JPG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP']
def is_image_file(filename):
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
def load_pix2pix_image_folder(image_root, input_folder="input", gt_folder="gt"):
"""
Args:
image_root (str): the directory where the images exist.
input_folder (str): the sub-directory containing the input images.
gt_folder (str): the sub-directory containing the ground truth images.
Returns:
list[dict]: a list of dicts in argos' "standard dataset dict" format
"""
data = []
# gt_postfix = "%s." % (gt_postfix)
input_root = os.path.join(image_root, input_folder)
for root, _, fnames in sorted(os.walk(input_root)):
for fname in sorted(fnames):
if is_image_file(fname):
input_path = os.path.join(root, fname)
# the GT image shares its file name with the input but lives under gt_folder
gt_path = input_path.replace("/%s/" % input_folder, "/%s/" % gt_folder)
if not os.path.isfile(gt_path):
logger.warning("{} is not exist".format(gt_fname))
continue
f = {
"file_name": fname[:-4],
"input_path": input_path,
"gt_path": gt_path
}
data.append(f)
if image_root.rfind("test") != -1 and len(data) == 5000:
logger.info("Reach maxinum of test data: {} ".format(len(data)))
return data
logger.info("Total number of data dicts: {} ".format(len(data)))
return data
def load_pix2pix_json(
json_path, input_folder, gt_folder, mask_folder,
real_json_path=None, real_folder=None, max_num=1e10,
):
"""
Args:
json_path (str): path to the json file which saves the filenames and labels.
input_folder (str): the directory for the input/source images
gt_folder (str): the directory for the ground_truth/target images
mask_folder (str): the directory for the masks
Returns:
list[dict]: a list of dicts
"""
real_filenames = {}
if real_json_path is not None:
with PathManager.open(real_json_path, 'r') as f:
real_filenames = json.load(f)
data = []
with PathManager.open(json_path, 'r') as f:
filenames = json.load(f)
in_len = len(filenames)
real_len = len(real_filenames)
total_len = min(max(in_len, real_len), max_num)
real_keys = [*real_filenames.keys()]
in_keys = [*filenames.keys()]
cnt = 0
# for fname in filenames.keys():
while cnt < total_len:
fname = in_keys[cnt % in_len]
f = {
"file_name": fname,
"input_folder": input_folder,
"gt_folder": gt_folder,
"mask_folder": mask_folder,
"input_label": filenames[fname],
"real_folder": real_folder
}
if real_len > 0:
real_fname = real_keys[cnt % real_len]
f["real_file_name"] = real_fname
data.append(f)
cnt += 1
# 5000 is the general number of images used to calculate FID in GANs
# if max_num > 0 and len(data) == max_num:
# logger.info("Reach maxinum of test data: {} ".format(len(data)))
# return data
logger.info("Total number of data dicts: {} ".format(len(data)))
return data
def register_folder_dataset(
name,
json_path,
input_folder,
gt_folder=None,
mask_folder=None,
input_src_path=None,
gt_src_path=None,
mask_src_path=None,
real_json_path=None,
real_folder=None,
real_src_path=None,
max_num=1e10,
):
DatasetCatalog.register(
name, lambda: load_pix2pix_json(
json_path, input_folder, gt_folder, mask_folder,
real_json_path, real_folder, max_num
)
)
metadata = {
"input_src_path": input_src_path,
"gt_src_path": gt_src_path,
"mask_src_path": mask_src_path,
"real_src_path": real_src_path,
"input_folder": input_folder,
"gt_folder": gt_folder,
"mask_folder": mask_folder,
"real_folder": real_folder,
}
MetadataCatalog.get(name).set(**metadata)
def load_lmdb_keys(max_num):
"""
Args:
max_num (int): the total number of data dicts to generate
Returns:
list[dict]: a list of dicts
"""
data = []
for i in range(max_num):
f = {"index": i}
data.append(f)
logger.info("Total number of data dicts: {} ".format(len(data)))
return data
def register_lmdb_dataset(
name,
data_folder,
src_data_folder,
max_num,
):
DatasetCatalog.register(
name, lambda: load_lmdb_keys(max_num)
)
metadata = {
"data_folder": data_folder,
"src_data_folder": src_data_folder,
"max_num": max_num,
}
MetadataCatalog.get(name).set(**metadata)
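# --- Usage sketch (illustrative, not part of the original file) ---
# Register a pix2pix-style folder dataset backed by a json index; the name,
# paths, and folder names below are placeholders.
def _demo_register_folder_dataset():
    register_folder_dataset(
        "my_pix2pix_train",
        json_path="/path/to/index.json",  # maps file_name -> input_label
        input_folder="input",
        gt_folder="gt",
        mask_folder="mask",
    )
    # The dict list is loaded lazily via DatasetCatalog.get("my_pix2pix_train").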
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# import all modules to make sure Registry works
from . import affine, blur, color_yuv, crop, d2_native # noqa
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import random
import cv2
import json
import numpy as np
from .build import TRANSFORM_OP_REGISTRY
from detectron2.data.transforms import Transform, TransformGen, NoOpTransform
import torchvision.transforms as T
class AffineTransform(Transform):
def __init__(self, M, img_w, img_h, flags=None, border_mode=None, is_inversed_M=False):
"""
Args:
M: affine transform matrix; the image will be warped according to M.
img_w, img_h (int): output image width and height.
flags, border_mode: optional arguments forwarded to cv2.warpAffine.
is_inversed_M (bool): if True, M maps output coords back to input coords.
"""
super().__init__()
self._set_attributes(locals())
self.warp_kwargs = {}
if flags is not None:
self.warp_kwargs["flags"] = flags
if border_mode is not None:
self.warp_kwargs["borderMode"] = border_mode
def apply_image(self, img):
M = self.M
if self.is_inversed_M:
M = M[:2]
img = cv2.warpAffine(
img,
M,
(int(self.img_w), int(self.img_h)),
**self.warp_kwargs,
)
return img
def apply_coords(self, coords):
# Add row of ones to enable matrix multiplication
coords = coords.T
ones = np.ones((1, coords.shape[1]))
coords = np.vstack((coords, ones))
M = self.M
if self.is_inversed_M:
M = np.linalg.inv(M)
coords = (M @ coords)[:2, :].T
return coords
class RandomPivotScaling(TransformGen):
"""
Uniformly picks a random pivot point inside the image frame and scales the
image around it using a scale factor sampled from a list of given scales.
The pivot point's location is unchanged after the transform.
Arguments:
scales: List[float]: each element can be any positive float number;
when larger than 1.0, objects become larger after the transform,
and vice versa.
"""
def __init__(self, scales):
super().__init__()
self._init(locals())
self.scales = scales
def get_transform(self, img):
img_h, img_w, _ = img.shape
img_h = float(img_h)
img_w = float(img_w)
pivot_y = self._rand_range(0.0, img_h)
pivot_x = self._rand_range(0.0, img_w)
def _interp(p1, p2, alpha):
dx = p2[0] - p1[0]
dy = p2[1] - p1[1]
p_x = p1[0] + alpha * dx
p_y = p1[1] + alpha * dy
return (p_x, p_y)
scale = np.random.choice(self.scales)
lt = (0.0, 0.0)
rb = (img_w, img_h)
pivot = (pivot_x, pivot_y)
pts1 = np.float32([lt, pivot, rb])
pts2 = np.float32([
_interp(pivot, lt, scale),
pivot,
_interp(pivot, rb, scale)],
)
M = cv2.getAffineTransform(pts1, pts2)
return AffineTransform(M, img_w, img_h)
class RandomAffine(TransformGen):
"""
Apply a random affine transform to the image given
probabilities and ranges in each dimension.
"""
def __init__(
self,
prob=0.5,
angle_range=(-90, 90),
translation_range=(0, 0),
scale_range=(1.0, 1.0),
shear_range=(0, 0),
):
"""
Args:
prob (float): probability of applying transform.
angle_range (tuple of integers): min/max rotation angle in degrees
between -180 and 180.
translation_range (tuple of integers): min/max translation
(post re-centered rotation).
scale_range (tuple of floats): min/max scale (post re-centered rotation).
shear_range (tuple of integers): min/max shear angle value in degrees
between -180 and 180.
"""
super().__init__()
# Turn all locals into member variables.
self._init(locals())
def get_transform(self, img):
im_h, im_w = img.shape[:2]
max_size = max(im_w, im_h)
center = [im_w / 2, im_h / 2]
angle = random.uniform(self.angle_range[0], self.angle_range[1])
translation = [
random.uniform(self.translation_range[0], self.translation_range[1]),
random.uniform(self.translation_range[0], self.translation_range[1]),
]
scale = random.uniform(self.scale_range[0], self.scale_range[1])
shear = [
random.uniform(self.shear_range[0], self.shear_range[1]),
random.uniform(self.shear_range[0], self.shear_range[1]),
]
dummy_translation = [0.0, 0.0]
dummy_scale = 1.0
M_inv = T.functional._get_inverse_affine_matrix(
center, angle, dummy_translation, dummy_scale, shear
)
M_inv.extend([0.0, 0.0, 1.0])
M_inv = np.array(M_inv).reshape((3, 3))
M = np.linalg.inv(M_inv)
# Center in output patch
img_corners = np.array([
[0, 0, im_w, im_w],
[0, im_h, 0, im_h],
[1, 1, 1, 1],
])
transformed_corners = M @ img_corners
x_min = np.amin(transformed_corners[0])
x_max = np.amax(transformed_corners[0])
x_range = np.ceil(x_max - x_min)
y_min = np.amin(transformed_corners[1])
y_max = np.amax(transformed_corners[1])
y_range = np.ceil(y_max - y_min)
# Apply translation and scale after centering in output patch
translation_adjustment = [(max_size - im_w) / 2, (max_size - im_h) / 2]
translation[0] += translation_adjustment[0]
translation[1] += translation_adjustment[1]
scale_adjustment = min(max_size / x_range, max_size / y_range)
scale *= scale_adjustment
M_inv = T.functional._get_inverse_affine_matrix(
center, angle, translation, scale, shear
)
# Convert to Numpy matrix so it can be inverted
M_inv.extend([0.0, 0.0, 1.0])
M_inv = np.array(M_inv).reshape((3, 3))
M = np.linalg.inv(M_inv)
do = self._rand_range() < self.prob
if do:
return AffineTransform(
M_inv,
max_size,
max_size,
flags=cv2.WARP_INVERSE_MAP + cv2.INTER_LINEAR,
border_mode=cv2.BORDER_REPLICATE,
is_inversed_M=True
)
else:
return NoOpTransform()
# example repr: "RandomPivotScalingOp::[1.0, 0.75, 0.5]"
@TRANSFORM_OP_REGISTRY.register()
def RandomPivotScalingOp(cfg, arg_str, is_train):
assert is_train
scales = json.loads(arg_str)
assert isinstance(scales, list)
assert all(isinstance(scale, (float, int)) for scale in scales)
return [RandomPivotScaling(scales=scales)]
@TRANSFORM_OP_REGISTRY.register()
def RandomAffineOp(cfg, arg_str, is_train):
assert is_train
kwargs = json.loads(arg_str) if arg_str is not None else {}
assert isinstance(kwargs, dict)
return [RandomAffine(**kwargs)]
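# --- Usage sketch (illustrative, not part of the original file) ---
# An identity affine should leave coordinates unchanged; this also shows the
# expected 2x3 matrix layout for AffineTransform.
def _demo_affine_transform_identity():
    M = np.eye(3)[:2]  # 2x3 identity affine matrix
    tfm = AffineTransform(M, img_w=64, img_h=64)
    coords = np.array([[10.0, 20.0]])
    assert np.allclose(tfm.apply_coords(coords), coords)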
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import detectron2.data.transforms.augmentation as aug
from detectron2.data.transforms import NoOpTransform, Transform
import numpy as np
from .build import TRANSFORM_OP_REGISTRY, _json_load
class LocalizedBoxMotionBlurTransform(Transform):
""" Transform to blur provided bounding boxes from an image. """
def __init__(self, bounding_boxes, k=(7, 15), angle=(0, 360), direction=(-1.0, 1.0)):
import imgaug.augmenters as iaa
super().__init__()
self._set_attributes(locals())
self.aug = iaa.MotionBlur(k, angle, direction, 1)
def apply_image(self, img):
bbox_regions = [img[y:y+h, x:x+w] for x, y, w, h in self.bounding_boxes]
blurred_boxes = self.aug.augment_images(bbox_regions)
new_img = np.array(img)
for (x, y, w, h), blurred in zip(self.bounding_boxes, blurred_boxes):
new_img[y:y+h, x:x+w] = blurred
return new_img
def apply_segmentation(self, segmentation):
""" Apply no transform on the full-image segmentation. """
return segmentation
def apply_coords(self, coords):
""" Apply no transform on the coordinates. """
return coords
def inverse(self) -> Transform:
""" The inverse is a No-op, only for geometric transforms. """
return NoOpTransform()
class LocalizedBoxMotionBlur(aug.Augmentation):
"""
Performs fake motion blur on bounding box annotations in an image.
Randomly selects motion blur parameters from the ranges `k`, `angle`, `direction`.
"""
def __init__(self, prob=0.5, k=(7, 15), angle=(0, 360), direction=(-1.0, 1.0)):
super().__init__()
self._init(locals())
def _validate_bbox_xywh_within_bounds(self, bbox, img_h, img_w):
x, y, w, h = bbox
assert x >= 0, f"Invalid x {x}"
assert y >= 0, f"Invalid y {y}"
assert x+w <= img_w, f"Invalid right {x+w} (img width {img_w})"
assert y+h <= img_h, f"Invalid bottom {y+h} (img height {img_h})"
def get_transform(self, image, annotations):
do_tfm = self._rand_range() < self.prob
if do_tfm:
return self._get_blur_transform(image, annotations)
else:
return NoOpTransform()
def _get_blur_transform(self, image, annotations):
"""
Return a `Transform` that simulates motion blur within the image's bounding box regions.
"""
img_h, img_w = image.shape[:2]
bboxes = [ann["bbox"] for ann in annotations]
# Debug
for bbox in bboxes:
self._validate_bbox_xywh_within_bounds(bbox, img_h, img_w)
return LocalizedBoxMotionBlurTransform(
bboxes,
k=self.k,
angle=self.angle,
direction=self.direction,
)
# example repr: "LocalizedBoxMotionBlurOp::{'prob': 0.5, 'k': [3,7], 'angle': [0, 360]}"
@TRANSFORM_OP_REGISTRY.register()
def RandomLocalizedBoxMotionBlurOp(cfg, arg_str, is_train):
assert is_train
kwargs = _json_load(arg_str) if arg_str is not None else {}
assert isinstance(kwargs, dict)
return [LocalizedBoxMotionBlur(**kwargs)]
class MotionBlurTransform(Transform):
def __init__(self, k=(7, 15), angle=(0, 360), direction=(-1.0, 1.0)):
"""
Args:
k, angle, direction: refer to `iaa.MotionBlur`; the specified blur is
applied to the whole image.
"""
import imgaug.augmenters as iaa
super().__init__()
self._set_attributes(locals())
self.aug = iaa.MotionBlur(k, angle, direction, 1)
def apply_image(self, img):
img = self.aug.augment_image(img)
return img
def apply_segmentation(self, segmentation):
return segmentation
def apply_coords(self, coords):
return coords
class RandomMotionBlur(aug.Augmentation):
"""
Apply random motion blur.
"""
def __init__(self, prob=0.5, k=(3, 7), angle=(0, 360), direction=(-1.0, 1.0)):
"""
Args:
prob (float): probability of applying transform
k (tuple): refer to `iaa.MotionBlur`
angle (tuple): refer to `iaa.MotionBlur`
direction (tuple): refer to `iaa.MotionBlur`
"""
super().__init__()
# Turn all locals into member variables.
self._init(locals())
def get_transform(self, img):
do = self._rand_range() < self.prob
if do:
return MotionBlurTransform(self.k, self.angle, self.direction)
else:
return NoOpTransform()
# example repr: "RandomMotionBlurOp::{'prob': 0.5, 'k': [3,7], 'angle': [0, 360]}"
@TRANSFORM_OP_REGISTRY.register()
def RandomMotionBlurOp(cfg, arg_str, is_train):
assert is_train
kwargs = _json_load(arg_str) if arg_str is not None else {}
assert isinstance(kwargs, dict)
return [RandomMotionBlur(**kwargs)]
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import numpy as np
import torch
from detectron2.structures.boxes import Boxes
def get_box_union(boxes: Boxes):
""" Merge all boxes into a single box """
if len(boxes) == 0:
return boxes
bt = boxes.tensor
union_bt = torch.cat(
(torch.min(bt[:, :2], 0).values, torch.max(bt[:, 2:], 0).values)
).reshape(1, -1)
return Boxes(union_bt)
def get_box_from_mask(mask: np.ndarray):
"""Find if there are non-zero elements per row/column first and then find
min/max position of those elements.
Only support 2d image (h x w)
Return (x1, y1, w, h) if bbox found, otherwise None
"""
assert len(mask.shape) == 2, f"Invalid shape {mask.shape}"
rows = np.any(mask, axis=1)
cols = np.any(mask, axis=0)
if not np.any(rows) or not np.any(cols):
return None
rmin, rmax = np.where(rows)[0][[0, -1]]
cmin, cmax = np.where(cols)[0][[0, -1]]
assert cmax >= cmin, f"cmax={cmax}, cmin={cmin}"
assert rmax >= rmin, f"rmax={rmax}, rmin={rmin}"
# x1, y1, w, h
return cmin, rmin, cmax - cmin + 1, rmax - rmin + 1
def get_min_box_aspect_ratio(bbox_xywh, target_aspect_ratio):
"""Get a minimal bbox that matches the target_aspect_ratio
target_aspect_ratio is representation by w/h
bbox are represented by pixel coordinates"""
bbox_xywh = torch.Tensor(bbox_xywh)
box_w, box_h = bbox_xywh[2:]
box_ar = float(box_w) / box_h
if box_ar >= target_aspect_ratio:
new_w = box_w
new_h = float(new_w) / target_aspect_ratio
else:
new_h = box_h
new_w = new_h * target_aspect_ratio
new_wh = torch.Tensor([new_w, new_h])
bbox_center = bbox_xywh[:2] + bbox_xywh[2:] / 2.0
new_xy = bbox_center - new_wh / 2.0
return torch.cat([new_xy, new_wh])
def get_box_center(bbox_xywh):
"""Get the center of the bbox"""
return torch.Tensor(bbox_xywh[:2]) + torch.Tensor(bbox_xywh[2:]) / 2.0
def get_bbox_xywh_from_center_wh(bbox_center, bbox_wh):
"""Get a bbox from bbox center and the width and height"""
bbox_wh = torch.Tensor(bbox_wh)
bbox_xy = torch.Tensor(bbox_center) - bbox_wh / 2.0
return torch.cat([bbox_xy, bbox_wh])
def get_bbox_xyxy_from_xywh(bbox_xywh):
"""Convert the bbox from xywh format to xyxy format
bbox are represented by pixel coordinates,
the center of pixels are (x + 0.5, y + 0.5)
"""
return torch.Tensor(
[
bbox_xywh[0],
bbox_xywh[1],
bbox_xywh[0] + bbox_xywh[2],
bbox_xywh[1] + bbox_xywh[3],
]
)
def get_bbox_xywh_from_xyxy(bbox_xyxy):
"""Convert the bbox from xyxy format to xywh format"""
return torch.Tensor(
[
bbox_xyxy[0],
bbox_xyxy[1],
bbox_xyxy[2] - bbox_xyxy[0],
bbox_xyxy[3] - bbox_xyxy[1],
]
)
def to_boxes_from_xywh(bbox_xywh):
return Boxes(get_bbox_xyxy_from_xywh(bbox_xywh).unsqueeze(0))
def scale_bbox_center(bbox_xywh, target_scale):
"""Scale the bbox around the center of the bbox"""
box_center = get_box_center(bbox_xywh)
box_wh = torch.Tensor(bbox_xywh[2:]) * target_scale
return get_bbox_xywh_from_center_wh(box_center, box_wh)
def offset_bbox(bbox_xywh, target_offset):
"""Offset the bbox based on target_offset"""
box_center = get_box_center(bbox_xywh)
new_center = box_center + torch.Tensor(target_offset)
return get_bbox_xywh_from_center_wh(new_center, bbox_xywh[2:])
def clip_box_xywh(bbox_xywh, image_size_hw):
"""Clip the bbox based on image_size_hw"""
h, w = image_size_hw
bbox_xyxy = get_bbox_xyxy_from_xywh(bbox_xywh)
bbox_xyxy[0] = max(bbox_xyxy[0], 0)
bbox_xyxy[1] = max(bbox_xyxy[1], 0)
bbox_xyxy[2] = min(bbox_xyxy[2], w)
bbox_xyxy[3] = min(bbox_xyxy[3], h)
return get_bbox_xywh_from_xyxy(bbox_xyxy)
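# --- Usage sketch (illustrative, not part of the original file) ---
# Derive an XYWH box from a binary mask, scale it around its center, then
# clip it back to the image bounds.
def _demo_box_utils():
    mask = np.zeros((10, 10), dtype=np.uint8)
    mask[2:5, 3:8] = 1
    bbox_xywh = get_box_from_mask(mask)  # -> (3, 2, 5, 3)
    bbox_xywh = scale_bbox_center(bbox_xywh, target_scale=2.0)
    return clip_box_xywh(bbox_xywh, image_size_hw=(10, 10))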
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import json
import logging
from detectron2.data import transforms as d2T
from detectron2.utils.registry import Registry
logger = logging.getLogger(__name__)
TRANSFORM_OP_REGISTRY = Registry("D2GO_TRANSFORM_REGISTRY")
def _json_load(arg_str):
try:
return json.loads(arg_str)
except json.decoder.JSONDecodeError as e:
logger.warning("Can't load arg_str: {}".format(arg_str))
raise e
# example repr: "ResizeShortestEdgeOp"
@TRANSFORM_OP_REGISTRY.register()
def ResizeShortestEdgeOp(cfg, arg_str, is_train):
if is_train:
min_size = cfg.INPUT.MIN_SIZE_TRAIN
max_size = cfg.INPUT.MAX_SIZE_TRAIN
sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING
else:
min_size = cfg.INPUT.MIN_SIZE_TEST
max_size = cfg.INPUT.MAX_SIZE_TEST
sample_style = "choice"
if sample_style == "range":
assert (
len(min_size) == 2
), "more than 2 ({}) min_size(s) are provided for ranges".format(len(min_size))
tfm_gens = []
if not min_size == 0: # set to zero to disable resize
tfm_gens.append(d2T.ResizeShortestEdge(min_size, max_size, sample_style))
return tfm_gens
# example repr: "ResizeShortestEdgeSquareOp"
@TRANSFORM_OP_REGISTRY.register()
def ResizeShortestEdgeSquareOp(cfg, arg_str, is_train):
""" Resize the input to square using INPUT.MIN_SIZE_TRAIN or INPUT.MIN_SIZE_TEST
without keeping aspect ratio
"""
if is_train:
min_size = cfg.INPUT.MIN_SIZE_TRAIN
assert (
isinstance(min_size, (list, tuple)) and len(min_size) == 1
), "Only a signle size is supported"
min_size = min_size[0]
else:
min_size = cfg.INPUT.MIN_SIZE_TEST
tfm_gens = []
if not min_size == 0: # set to zero to disable resize
tfm_gens.append(d2T.Resize(shape=[min_size, min_size]))
return tfm_gens
@TRANSFORM_OP_REGISTRY.register()
def ResizeOp(cfg, arg_str, is_train):
kwargs = _json_load(arg_str) if arg_str is not None else {}
assert isinstance(kwargs, dict)
return [d2T.Resize(**kwargs)]
_TRANSFORM_REPR_SEPARATOR = "::"
def parse_tfm_gen_repr(tfm_gen_repr):
if tfm_gen_repr.count(_TRANSFORM_REPR_SEPARATOR) == 0:
return tfm_gen_repr, None
elif tfm_gen_repr.count(_TRANSFORM_REPR_SEPARATOR) == 1:
return tfm_gen_repr.split(_TRANSFORM_REPR_SEPARATOR)
else:
raise ValueError(
"Can't to parse transform repr name because of multiple separator found."
" Offending name: {}"
)
def build_transform_gen(cfg, is_train):
"""
This function builds a list of TransformGen or Transform objects using a list of
strings from cfg.D2GO_DATA.AUG_OPS.TRAIN/TEST. Each string (aka `tfm_gen_repr`)
will be split into `name` and `arg_str` (separated by "::"); the `name`
is used to look up the registry while `arg_str` is used as the argument.
Each function in registry needs to take `cfg`, `arg_str` and `is_train` as
input, and return a list of TransformGen or Transform objects.
"""
tfm_gen_repr_list = (
cfg.D2GO_DATA.AUG_OPS.TRAIN if is_train else cfg.D2GO_DATA.AUG_OPS.TEST
)
tfm_gens = [
TRANSFORM_OP_REGISTRY.get(name)(cfg, arg_str, is_train)
for name, arg_str in [
parse_tfm_gen_repr(tfm_gen_repr) for tfm_gen_repr in tfm_gen_repr_list
]
]
assert all(isinstance(gens, list) for gens in tfm_gens)
tfm_gens = [gen for gens in tfm_gens for gen in gens]
assert all(isinstance(gen, (d2T.Transform, d2T.TransformGen)) for gen in tfm_gens)
return tfm_gens
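# Illustrative usage (hedged sketch; assumes a fully-populated d2go cfg):
#   cfg.D2GO_DATA.AUG_OPS.TRAIN = ["ResizeShortestEdgeOp", 'RandomFlipOp::{"prob":0.5}']
#   tfm_gens = build_transform_gen(cfg, is_train=True)
#   # -> [d2T.ResizeShortestEdge(...), d2T.RandomFlip(prob=0.5)]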
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import List
import detectron2.data.transforms.augmentation as aug
import numpy as np
from detectron2.config import CfgNode
from detectron2.data import detection_utils as du
from detectron2.data.transforms.transform import Transform
from fvcore.transforms.transform import BlendTransform
from .build import TRANSFORM_OP_REGISTRY, _json_load
class InvertibleColorTransform(Transform):
"""
Generic wrapper for invertible photometric transforms.
These transformations should only affect the color space and
not the coordinate space of the image (e.g. annotation
coordinates such as bounding boxes should not be changed)
"""
def __init__(self, op, inverse_op):
"""
Args:
op (Callable): operation to be applied to the image,
which takes in an ndarray and returns an ndarray.
"""
if not callable(op):
raise ValueError("op parameter should be callable")
if not callable(inverse_op):
raise ValueError("inverse_op parameter should be callable")
super().__init__()
self._set_attributes(locals())
def apply_image(self, img):
return self.op(img)
def apply_coords(self, coords):
return coords
def inverse(self):
return InvertibleColorTransform(self.inverse_op, self.op)
def apply_segmentation(self, segmentation):
return segmentation
class RandomContrastYUV(aug.Augmentation):
"""
Randomly transforms contrast for images in YUV format.
See similar:
detectron2.data.transforms.RandomContrast,
detectron2.data.transforms.RandomBrightness
"""
def __init__(self, intensity_min: float, intensity_max: float):
super().__init__()
self._init(locals())
def get_transform(self, img):
w = np.random.uniform(self.intensity_min, self.intensity_max)
pure_gray = np.zeros_like(img)
pure_gray[:, :, 0] = 0.5
return BlendTransform(src_image=pure_gray, src_weight=1 - w, dst_weight=w)
class RandomSaturationYUV(aug.Augmentation):
"""
Randomly transforms saturation for images in YUV format.
See similar: detectron2.data.transforms.RandomSaturation
"""
def __init__(self, intensity_min: float, intensity_max: float):
super().__init__()
self._init(locals())
def get_transform(self, img):
assert (
len(img.shape) == 3 and img.shape[-1] == 3
), f"Expected (H, W, 3), image shape {img.shape}"
w = np.random.uniform(self.intensity_min, self.intensity_max)
grayscale = np.zeros_like(img)
grayscale[:, :, 0] = img[:, :, 0]
return BlendTransform(src_image=grayscale, src_weight=1 - w, dst_weight=w)
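# Note on the blend math above: BlendTransform computes
#   out = src_weight * src_image + dst_weight * img
# so with w drawn from [intensity_min, intensity_max], w == 1 leaves the image
# unchanged, w > 1 pushes pixels away from the gray/grayscale reference (more
# contrast / saturation), and w < 1 pulls them toward it.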
def convert_rgb_to_yuv_bt601(image):
"""Convert RGB image in (H, W, C) to YUV format
image: range 0 ~ 255
"""
image = image / 255.0
image = np.dot(image, np.array(du._M_RGB2YUV).T)
return image
def convert_yuv_bt601_to_rgb(image):
    """Convert a YUV (BT.601) image back to an RGB image in range 0 ~ 255"""
    return du.convert_image_to_rgb(image, "YUV-BT.601")
class RGB2YUVBT601(aug.Augmentation):
def __init__(self):
super().__init__()
self.trans = InvertibleColorTransform(
            convert_rgb_to_yuv_bt601, convert_yuv_bt601_to_rgb
)
def get_transform(self, image):
return self.trans
class YUVBT6012RGB(aug.Augmentation):
def __init__(self):
super().__init__()
self.trans = InvertibleColorTransform(
            convert_yuv_bt601_to_rgb, convert_rgb_to_yuv_bt601
)
def get_transform(self, image):
return self.trans
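# Illustrative round trip (comment sketch): RGB2YUVBT601 and YUVBT6012RGB are
# exact inverses up to floating point error.
#   rgb = np.random.randint(0, 256, size=(4, 4, 3)).astype(np.float32)
#   yuv = convert_rgb_to_yuv_bt601(rgb)      # Y channel roughly in [0, 1]
#   back = convert_yuv_bt601_to_rgb(yuv)     # ~rgb, range [0, 255]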
def build_func(cfg: CfgNode, arg_str: str, is_train: bool, obj) -> List[aug.Augmentation]:
assert is_train
kwargs = _json_load(arg_str) if arg_str is not None else {}
assert isinstance(kwargs, dict)
return [obj(**kwargs)]
@TRANSFORM_OP_REGISTRY.register()
def RandomContrastYUVOp(cfg, arg_str, is_train):
return build_func(cfg, arg_str, is_train, obj=RandomContrastYUV)
@TRANSFORM_OP_REGISTRY.register()
def RandomSaturationYUVOp(cfg, arg_str, is_train):
return build_func(cfg, arg_str, is_train, obj=RandomSaturationYUV)
@TRANSFORM_OP_REGISTRY.register()
def RGB2YUVBT601Op(cfg, arg_str, is_train):
return build_func(cfg, arg_str, is_train, obj=RGB2YUVBT601)
@TRANSFORM_OP_REGISTRY.register()
def YUVBT6012RGBOp(cfg, arg_str, is_train):
return build_func(cfg, arg_str, is_train, obj=YUVBT6012RGB)
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import math
import detectron2.data.transforms.augmentation as aug
import numpy as np
from detectron2.data.transforms import ExtentTransform
from detectron2.structures import BoxMode
from . import box_utils as bu
from fvcore.transforms.transform import CropTransform, NoOpTransform, Transform
from .build import TRANSFORM_OP_REGISTRY, _json_load
class CropBoundary(aug.Augmentation):
"""Crop the boundary of the image by `count` pixel on each side"""
def __init__(self, count=3):
super().__init__()
self.count = count
def get_transform(self, image):
img_h, img_w = image.shape[:2]
        assert img_h > self.count * 2 and img_w > self.count * 2
box = [self.count, self.count, img_w - self.count * 2, img_h - self.count * 2]
return CropTransform(*box)
class PadTransform(Transform):
def __init__(self, x0, y0, w, h, org_w, org_h, pad_mode="constant"):
super().__init__()
assert x0 + w <= org_w
assert y0 + h <= org_h
self._set_attributes(locals())
def apply_image(self, img):
"""img: HxWxC or HxW"""
assert len(img.shape) == 2 or len(img.shape) == 3
assert img.shape[0] == self.h and img.shape[1] == self.w
pad_width = [
(self.y0, self.org_h - self.h - self.y0),
(self.x0, self.org_w - self.w - self.x0),
*([(0, 0)] if len(img.shape) == 3 else []),
]
pad_args = {"mode": self.pad_mode}
if self.pad_mode == "constant":
pad_args["constant_values"] = 0
ret = np.pad(img, pad_width=tuple(pad_width), **pad_args)
return ret
def apply_coords(self, coords: np.ndarray) -> np.ndarray:
raise NotImplementedError()
def inverse(self) -> Transform:
return CropTransform(self.x0, self.y0, self.w, self.h, self.org_w, self.org_h)
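# Illustrative round trip (comment sketch): pad a 2x3 image into a 4x5 canvas
# at offset (x0=1, y0=1), then inverse() crops back to the original.
#   img = np.ones((2, 3), dtype=np.uint8)             # h=2, w=3
#   pad = PadTransform(1, 1, 3, 2, org_w=5, org_h=4)
#   padded = pad.apply_image(img)                     # shape (4, 5)
#   restored = pad.inverse().apply_image(padded)      # shape (2, 3)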
InvertibleCropTransform = CropTransform
class PadBorderDivisible(aug.Augmentation):
def __init__(self, size_divisibility, pad_mode="constant"):
super().__init__()
self.size_divisibility = size_divisibility
self.pad_mode = pad_mode
def get_transform(self, image):
""" image: HxWxC """
assert len(image.shape) == 3 and image.shape[2] in [1, 3]
H, W = image.shape[:2]
new_h = int(math.ceil(H / self.size_divisibility) * self.size_divisibility)
new_w = int(math.ceil(W / self.size_divisibility) * self.size_divisibility)
return PadTransform(0, 0, W, H, new_w, new_h, pad_mode=self.pad_mode)
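# Example (comment sketch): with size_divisibility=32, a 100x130 image gets a
# PadTransform to a 128x160 canvas (padding on the bottom/right only, since
# x0 = y0 = 0), and its inverse() crops back to 100x130.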
class RandomCropFixedAspectRatio(aug.Augmentation):
def __init__(
self, crop_aspect_ratios_list, scale_range=None, offset_scale_range=None
):
super().__init__()
assert isinstance(crop_aspect_ratios_list, (list, tuple))
assert (
scale_range is None
or isinstance(scale_range, (list, tuple))
and len(scale_range) == 2
)
assert (
offset_scale_range is None
or isinstance(offset_scale_range, (list, tuple))
and len(offset_scale_range) == 2
)
# [w1/h1, w2/h2, ...]
self.crop_aspect_ratios_list = crop_aspect_ratios_list
# [low, high] or None
self.scale_range = scale_range
# [low, high] or None
self.offset_scale_range = offset_scale_range
self.rng = np.random.default_rng()
def _pick_aspect_ratio(self):
return self.rng.choice(self.crop_aspect_ratios_list)
def _pick_scale(self):
if self.scale_range is None:
return 1.0
return self.rng.uniform(*self.scale_range)
def _pick_offset(self, box_w, box_h):
if self.offset_scale_range is None:
return [0, 0]
offset_scale = self.rng.uniform(*self.offset_scale_range, size=2)
return offset_scale[0] * box_w, offset_scale[1] * box_h
def get_transform(self, image, sem_seg):
# HWC or HW for image, HW for sem_seg
assert len(image.shape) in [2, 3]
assert len(sem_seg.shape) == 2
mask_box_xywh = bu.get_box_from_mask(sem_seg)
# do nothing if the mask is empty (the whole image is background)
if mask_box_xywh is None:
return NoOpTransform()
crop_ar = self._pick_aspect_ratio()
target_scale = self._pick_scale()
target_offset = self._pick_offset(*mask_box_xywh[2:])
mask_box_xywh = bu.offset_bbox(mask_box_xywh, target_offset)
mask_box_xywh = bu.scale_bbox_center(mask_box_xywh, target_scale)
target_box_xywh = bu.get_min_box_aspect_ratio(mask_box_xywh, crop_ar)
target_bbox_xyxy = bu.get_bbox_xyxy_from_xywh(target_box_xywh)
return ExtentTransform(
src_rect=target_bbox_xyxy,
output_size=(
int(target_box_xywh[3].item()),
int(target_box_xywh[2].item()),
),
)
# example repr: "CropBoundaryOp::{'count': 3}"
@TRANSFORM_OP_REGISTRY.register()
def CropBoundaryOp(cfg, arg_str, is_train):
assert is_train
kwargs = _json_load(arg_str) if arg_str is not None else {}
assert isinstance(kwargs, dict)
return [CropBoundary(**kwargs)]
# example repr: "RandomCropFixedAspectRatioOp::{'crop_aspect_ratios_list': [0.5], 'scale_range': [0.8, 1.2], 'offset_scale_range': [-0.3, 0.3]}"
@TRANSFORM_OP_REGISTRY.register()
def RandomCropFixedAspectRatioOp(cfg, arg_str, is_train):
assert is_train
kwargs = _json_load(arg_str) if arg_str is not None else {}
assert isinstance(kwargs, dict)
return [RandomCropFixedAspectRatio(**kwargs)]
class RandomInstanceCrop(aug.Augmentation):
def __init__(self, crop_scale=(0.8, 1.6)):
"""
Generates a CropTransform centered around the instance.
crop_scale: [low, high] relative crop scale around the instance, this
determines how far to zoom in / out around the cropped instance
"""
super().__init__()
self.crop_scale = crop_scale
assert (
isinstance(crop_scale, (list, tuple)) and len(crop_scale) == 2
), crop_scale
def get_transform(self, image, annotations):
"""
        This function will modify the annotations in place, setting the iscrowd
        flag to 1 for all annotations that are not picked. It relies on the
        dataset mapper to filter those items out.
"""
assert isinstance(annotations, (list, tuple)), annotations
assert all("bbox" in x for x in annotations), annotations
assert all("bbox_mode" in x for x in annotations), annotations
image_size = image.shape[:2]
# filter out iscrowd
annotations = [x for x in annotations if x.get("iscrowd", 0) == 0]
if len(annotations) == 0:
return NoOpTransform()
sel_index = np.random.randint(len(annotations))
        # set the iscrowd flag of the other annotations to 1 so that they will
        # be filtered out by the dataset mapper (https://fburl.com/diffusion/fg64cb4h)
for idx, instance in enumerate(annotations):
if idx != sel_index:
instance["iscrowd"] = 1
instance = annotations[sel_index]
bbox_xywh = BoxMode.convert(
instance["bbox"], instance["bbox_mode"], BoxMode.XYWH_ABS
)
scale = np.random.uniform(*self.crop_scale)
bbox_xywh = bu.scale_bbox_center(bbox_xywh, scale)
bbox_xywh = bu.clip_box_xywh(bbox_xywh, image_size).int()
return CropTransform(*bbox_xywh.tolist())
# example repr: "RandomInstanceCropOp::{'crop_scale': [0.8, 1.6]}"
@TRANSFORM_OP_REGISTRY.register()
def RandomInstanceCropOp(cfg, arg_str, is_train):
kwargs = _json_load(arg_str) if arg_str is not None else {}
assert isinstance(kwargs, dict)
return [RandomInstanceCrop(**kwargs)]
class CropBoxAug(aug.Augmentation):
""" Augmentation to crop the image based on boxes
Scale the box with `box_scale_factor` around the center before cropping
"""
def __init__(self, box_scale_factor=1.0):
super().__init__()
self.box_scale_factor = box_scale_factor
def get_transform(self, image: np.ndarray, boxes: np.ndarray):
# boxes: 1 x 4 in xyxy format
assert boxes.shape[0] == 1
assert isinstance(image, np.ndarray)
assert isinstance(boxes, np.ndarray)
img_h, img_w = image.shape[0:2]
box_xywh = bu.get_bbox_xywh_from_xyxy(boxes[0])
if self.box_scale_factor != 1.0:
box_xywh = bu.scale_bbox_center(box_xywh, self.box_scale_factor)
box_xywh = bu.clip_box_xywh(box_xywh, [img_h, img_w])
box_xywh = box_xywh.int().tolist()
return CropTransform(*box_xywh, orig_w=img_w, orig_h=img_h)
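# Illustrative usage (hedged sketch): crop around a single xyxy box, zooming
# out 20% around its center; arguments mirror get_transform above.
#   image = np.zeros((100, 100, 3), dtype=np.uint8)
#   boxes = np.array([[10.0, 10.0, 50.0, 50.0]])      # 1 x 4, xyxy
#   tfm = CropBoxAug(box_scale_factor=1.2).get_transform(image, boxes)
#   cropped = tfm.apply_image(image)                  # (48, 48, 3)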
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
from .build import TRANSFORM_OP_REGISTRY, _json_load
from detectron2.data import transforms as d2T
from detectron2.projects.point_rend import ColorAugSSDTransform
logger = logging.getLogger(__name__)
D2_RANDOM_TRANSFORMS = {
"RandomBrightness": d2T.RandomBrightness,
"RandomContrast": d2T.RandomContrast,
"RandomCrop": d2T.RandomCrop,
"RandomRotation": d2T.RandomRotation,
"RandomExtent": d2T.RandomExtent,
"RandomFlip": d2T.RandomFlip,
"RandomSaturation": d2T.RandomSaturation,
"RandomLighting": d2T.RandomLighting,
}
def build_func(cfg, arg_str, is_train, name):
assert is_train
kwargs = _json_load(arg_str) if arg_str is not None else {}
assert isinstance(kwargs, dict)
return [D2_RANDOM_TRANSFORMS[name](**kwargs)]
# example 1: RandomFlipOp
# example 2: RandomFlipOp::{}
# example 3: RandomFlipOp::{"prob":0.5}
# example 4: RandomBrightnessOp::{"intensity_min":1.0, "intensity_max":2.0}
@TRANSFORM_OP_REGISTRY.register()
def RandomBrightnessOp(cfg, arg_str, is_train):
return build_func(cfg, arg_str, is_train, name="RandomBrightness")
@TRANSFORM_OP_REGISTRY.register()
def RandomContrastOp(cfg, arg_str, is_train):
return build_func(cfg, arg_str, is_train, name="RandomContrast")
@TRANSFORM_OP_REGISTRY.register()
def RandomCropOp(cfg, arg_str, is_train):
return build_func(cfg, arg_str, is_train, name="RandomCrop")
@TRANSFORM_OP_REGISTRY.register()
def RandomRotationOp(cfg, arg_str, is_train):
return build_func(cfg, arg_str, is_train, name="RandomRotation")
@TRANSFORM_OP_REGISTRY.register()
def RandomExtentOp(cfg, arg_str, is_train):
return build_func(cfg, arg_str, is_train, name="RandomExtent")
@TRANSFORM_OP_REGISTRY.register()
def RandomFlipOp(cfg, arg_str, is_train):
return build_func(cfg, arg_str, is_train, name="RandomFlip")
@TRANSFORM_OP_REGISTRY.register()
def RandomSaturationOp(cfg, arg_str, is_train):
return build_func(cfg, arg_str, is_train, name="RandomSaturation")
@TRANSFORM_OP_REGISTRY.register()
def RandomLightingOp(cfg, arg_str, is_train):
return build_func(cfg, arg_str, is_train, name="RandomLighting")
@TRANSFORM_OP_REGISTRY.register()
def RandomSSDColorAugOp(cfg, arg_str, is_train):
assert is_train
kwargs = _json_load(arg_str) if arg_str is not None else {}
assert isinstance(kwargs, dict)
assert "img_format" not in kwargs
return [ColorAugSSDTransform(img_format=cfg.INPUT.FORMAT, **kwargs)]
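# Illustrative config usage (comment sketch): these ops are referenced by name
# from cfg.D2GO_DATA.AUG_OPS.TRAIN/TEST, with optional JSON args after "::".
#   cfg.D2GO_DATA.AUG_OPS.TRAIN = [
#       "ResizeShortestEdgeOp",
#       'RandomFlipOp::{"prob":0.5}',
#       'RandomBrightnessOp::{"intensity_min":0.8, "intensity_max":1.2}',
#   ]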
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import List, Optional, Union
import numpy as np
import torch
from detectron2.data.transforms.augmentation import AugmentationList, Augmentation
from detectron2.structures import Boxes
from fvcore.transforms.transform import Transform, TransformList
class AugInput:
"""
Same as AugInput in vision/fair/detectron2/detectron2/data/transforms/augmentation.py
but allows torch.Tensor as input
"""
def __init__(
self,
image: Union[np.ndarray, torch.Tensor],
*,
boxes: Optional[Union[np.ndarray, torch.Tensor, Boxes]] = None,
sem_seg: Optional[Union[np.ndarray, torch.Tensor]] = None,
):
"""
Args:
            image (ndarray/torch.Tensor): (H, W) or (H, W, C) ndarray of type uint8
                in range [0, 255], or floating point in range [0, 1] or [0, 255];
                torch.Tensor inputs are expected in (C, H, W) layout.
boxes (ndarray or None): Nx4 float32 boxes in XYXY_ABS mode
sem_seg (ndarray or None): HxW uint8 semantic segmentation mask. Each element
is an integer label of pixel.
"""
self.image = image
self.boxes = boxes
self.sem_seg = sem_seg
def transform(self, tfm: Transform) -> None:
"""
In-place transform all attributes of this class.
By "in-place", it means after calling this method, accessing an attribute such
as ``self.image`` will return transformed data.
"""
self.image = tfm.apply_image(self.image)
if self.boxes is not None:
self.boxes = tfm.apply_box(self.boxes)
if self.sem_seg is not None:
self.sem_seg = tfm.apply_segmentation(self.sem_seg)
def apply_augmentations(
self, augmentations: List[Union[Augmentation, Transform]]
) -> TransformList:
"""
Equivalent of ``AugmentationList(augmentations)(self)``
"""
return AugmentationList(augmentations)(self)
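# Illustrative usage (hedged sketch): apply augmentations to an image plus
# boxes; transform() mutates the fields in place.
#   inputs = AugInput(image=img_hwc, boxes=np.array([[0.0, 0.0, 10.0, 10.0]]))
#   tfl = inputs.apply_augmentations(tfm_gens)   # e.g. from build_transform_gen
#   new_img, new_boxes = inputs.image, inputs.boxes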
class Tensor2Array(Transform):
""" Convert image tensor (CHW) to np array (HWC) """
def __init__(self):
super().__init__()
def apply_image(self, img: torch.Tensor) -> np.ndarray:
# CHW -> HWC
assert isinstance(img, torch.Tensor)
assert len(img.shape) == 3, img.shape
return img.cpu().numpy().transpose(1, 2, 0)
def apply_coords(self, coords):
return coords
def apply_segmentation(self, segmentation: torch.Tensor) -> np.ndarray:
assert len(segmentation.shape) == 2, segmentation.shape
return segmentation.cpu().numpy()
def inverse(self):
return Array2Tensor()
class Array2Tensor(Transform):
""" Convert image np array (HWC) to torch tensor (CHW) """
def __init__(self):
super().__init__()
def apply_image(self, img: np.ndarray) -> torch.Tensor:
# HWC -> CHW
assert isinstance(img, np.ndarray)
assert len(img.shape) == 3, img.shape
return torch.from_numpy(img.transpose(2, 0, 1).astype("float32"))
def apply_coords(self, coords):
return coords
def apply_segmentation(self, segmentation: np.ndarray) -> torch.Tensor:
assert len(segmentation.shape) == 2, segmentation.shape
return torch.from_numpy(segmentation.astype("long"))
def inverse(self):
return Tensor2Array()
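# Illustrative round trip (comment sketch): wrap numpy-only augmentations for
# tensor inputs by converting a CHW tensor to an HWC array and back.
#   t = torch.rand(3, 8, 8)                  # CHW float tensor
#   arr = Tensor2Array().apply_image(t)      # (8, 8, 3) ndarray
#   t2 = Array2Tensor().apply_image(arr)     # CHW float32 tensor again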