Commit 1345fab2 authored by luopl's avatar luopl
Browse files

Initial commit

parents
Pipeline #1263 canceled with stages
from typing import Any
import torch
from util.datapoints import Datapoint
def is_simple_tensor(inpt: Any) -> bool:
    """Tell whether ``inpt`` is a plain tensor rather than a ``Datapoint``.

    Args:
        inpt: Arbitrary object to inspect.

    Returns:
        ``True`` only when ``inpt`` is a ``torch.Tensor`` that is not an
        instance of ``Datapoint``; ``False`` for everything else.
    """
    if not isinstance(inpt, torch.Tensor):
        return False
    return not isinstance(inpt, Datapoint)
from __future__ import annotations
from typing import Any, Callable, List, Tuple, Type, Union, Sequence
import PIL.Image
from util import datapoints
from transforms.v2.functional import get_dimensions, get_spatial_size, is_simple_tensor
def sequence_to_str(seq: Sequence, separate_last: str = "") -> str:
    """Render the items of ``seq`` as a single-quoted, comma-separated string.

    Args:
        seq: Items to render; each one is converted to its string form.
        separate_last: Text inserted right before the last item (for example
            ``"or "``). With exactly two items and a non-empty value, the
            comma in front of the last item is omitted.

    Returns:
        The formatted string, or ``""`` when ``seq`` is empty.
    """
    if not seq:
        return ""
    if len(seq) == 1:
        return f"'{seq[0]}'"

    quoted_head = ", ".join(f"'{item!s}'" for item in seq[:-1])
    # "'a' or 'b'" reads better than "'a', or 'b'", hence no comma for pairs.
    comma = "" if separate_last and len(seq) == 2 else ","
    return f"{quoted_head}{comma} {separate_last}'{seq[-1]}'"
def query_bounding_box(flat_inputs: List[Any]) -> datapoints.BoundingBox:
    """Return the one bounding box contained in a flattened sample.

    Args:
        flat_inputs: Flat list of sample entries of arbitrary types.

    Returns:
        The only ``datapoints.BoundingBox`` instance in ``flat_inputs``.

    Raises:
        TypeError: If no bounding box is present.
        ValueError: If more than one bounding box is present.
    """
    found = [inpt for inpt in flat_inputs if isinstance(inpt, datapoints.BoundingBox)]
    count = len(found)
    if count == 0:
        raise TypeError("No bounding box was found in the sample")
    if count > 1:
        raise ValueError("Found multiple bounding boxes in the sample")
    return found[0]
def query_chw(flat_inputs: List[Any]) -> Tuple[int, int, int]:
    """Return the (channels, height, width) shared by all images/videos in a sample.

    Only ``datapoints.Image``, ``PIL.Image.Image``, ``datapoints.Video`` and
    plain tensors (as decided by ``is_simple_tensor``) are inspected; their
    dimensions come from ``get_dimensions``.

    Args:
        flat_inputs: Flat list of sample entries of arbitrary types.

    Returns:
        The single ``(c, h, w)`` triple found in the sample.

    Raises:
        TypeError: If the sample contains no image or video.
        ValueError: If the inspected entries disagree on their dimensions.
    """
    unique_dims = set()
    for inpt in flat_inputs:
        if isinstance(inpt, (datapoints.Image, PIL.Image.Image, datapoints.Video)) or is_simple_tensor(inpt):
            unique_dims.add(tuple(get_dimensions(inpt)))

    if not unique_dims:
        raise TypeError("No image or video was found in the sample")
    if len(unique_dims) > 1:
        raise ValueError(f"Found multiple CxHxW dimensions in the sample: {sequence_to_str(sorted(unique_dims))}")

    channels, height, width = next(iter(unique_dims))
    return channels, height, width
def query_spatial_size(flat_inputs: List[Any]) -> Tuple[int, int]:
    """Return the (height, width) shared by all spatial entries of a sample.

    Images, PIL images, videos, masks, bounding boxes and plain tensors (as
    decided by ``is_simple_tensor``) are inspected; their size comes from
    ``get_spatial_size``.

    Args:
        flat_inputs: Flat list of sample entries of arbitrary types.

    Returns:
        The single ``(h, w)`` pair found in the sample.

    Raises:
        TypeError: If the sample contains none of the inspected types.
        ValueError: If the inspected entries disagree on their spatial size.
    """
    unique_sizes = set()
    for inpt in flat_inputs:
        is_spatial = isinstance(
            inpt, (datapoints.Image, PIL.Image.Image, datapoints.Video, datapoints.Mask, datapoints.BoundingBox)
        )
        if is_spatial or is_simple_tensor(inpt):
            unique_sizes.add(tuple(get_spatial_size(inpt)))

    if not unique_sizes:
        raise TypeError("No image, video, mask or bounding box was found in the sample")
    if len(unique_sizes) > 1:
        raise ValueError(f"Found multiple HxW dimensions in the sample: {sequence_to_str(sorted(unique_sizes))}")

    height, width = next(iter(unique_sizes))
    return height, width
def check_type(obj: Any, types_or_checks: Tuple[Union[Type, Callable[[Any], bool]], ...]) -> bool:
    """Tell whether ``obj`` satisfies at least one of the given types or predicates.

    Args:
        obj: Object to test.
        types_or_checks: Each entry is either a class, tested with
            ``isinstance``, or a callable predicate that is called with ``obj``.

    Returns:
        ``True`` as soon as one entry matches, ``False`` if none does
        (including when ``types_or_checks`` is empty).
    """

    def _matches(candidate: Any) -> Any:
        if isinstance(candidate, type):
            return isinstance(obj, candidate)
        return candidate(obj)

    return any(_matches(candidate) for candidate in types_or_checks)
def has_any(flat_inputs: List[Any], *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool:
    """Tell whether any entry of ``flat_inputs`` matches any given type or predicate.

    Args:
        flat_inputs: Flat list of sample entries of arbitrary types.
        *types_or_checks: Classes and/or predicates, forwarded to ``check_type``.

    Returns:
        ``True`` if at least one entry passes ``check_type``, else ``False``.
    """
    return any(check_type(inpt, types_or_checks) for inpt in flat_inputs)
def has_all(flat_inputs: List[Any], *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool:
    """Tell whether every given type or predicate is matched by some entry.

    Args:
        flat_inputs: Flat list of sample entries of arbitrary types.
        *types_or_checks: Each one is either a class, tested with
            ``isinstance``, or a callable predicate called with the entry.

    Returns:
        ``True`` when each type/predicate is satisfied by at least one entry
        (trivially ``True`` when none are given), otherwise ``False``.
    """

    def _satisfies(inpt: Any, type_or_check: Any) -> Any:
        if isinstance(type_or_check, type):
            return isinstance(inpt, type_or_check)
        return type_or_check(inpt)

    return all(
        any(_satisfies(inpt, type_or_check) for inpt in flat_inputs)
        for type_or_check in types_or_checks
    )
import copy
import json
from collections import ChainMap, OrderedDict, defaultdict
import numpy as np
import pycocotools.mask as mask_util
import torch
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
import util.utils as utils
class CocoEvaluator:
    """Collects predictions and scores them with one ``COCOeval`` per IoU type.

    Predictions are fed in with ``update``; ``synchronize_between_processes``
    merges the per-image results through ``utils.all_gather``; ``accumulate``
    and ``summarize`` then delegate to the underlying ``COCOeval`` objects.
    """

    def __init__(self, coco_gt, iou_types):
        """Create one ``COCOeval`` per entry of ``iou_types``.

        Args:
            coco_gt: Ground-truth COCO API object; a deep copy is stored.
            iou_types: List or tuple of IoU type names. ``prepare`` understands
                ``"bbox"``, ``"segm"`` and ``"keypoints"``.
        """
        assert isinstance(iou_types, (list, tuple))
        # Work on a private copy so that the caller's object is left alone.
        ground_truth = copy.deepcopy(coco_gt)
        self.coco_gt = ground_truth
        self.iou_types = iou_types
        self.coco_eval = {iou_type: COCOeval(ground_truth, iouType=iou_type) for iou_type in iou_types}
        self.img_ids = []
        self.eval_imgs = {iou_type: [] for iou_type in iou_types}
        self.predictions = {iou_type: {} for iou_type in iou_types}

    def update(self, predictions):
        """Evaluate one batch of predictions and store the per-image results.

        Args:
            predictions: Mapping from image id to that image's prediction dict.
        """
        batch_img_ids = list(np.unique(list(predictions.keys())))
        self.img_ids.extend(batch_img_ids)

        for iou_type in self.iou_types:
            results = self.prepare(predictions, iou_type)
            self.predictions[iou_type].update(predictions)

            evaluator = self.coco_eval[iou_type]
            # An empty result list cannot be loaded, so use an empty COCO object.
            evaluator.cocoDt = loadRes(self.coco_gt, results) if results else COCO()
            evaluator.params.imgIds = list(batch_img_ids)
            # The ids returned by `evaluate` are reused for the next IoU type.
            batch_img_ids, per_image_results = evaluate(evaluator)
            self.eval_imgs[iou_type].append(per_image_results)

    def synchronize_between_processes(self):
        """Merge per-image results and raw predictions through ``utils.all_gather``."""
        for iou_type in self.iou_types:
            stacked = np.concatenate(self.eval_imgs[iou_type], 2)
            self.eval_imgs[iou_type] = stacked
            create_common_coco_eval(self.coco_eval[iou_type], self.img_ids, stacked)

            gathered = utils.all_gather(self.predictions[iou_type])
            self.predictions[iou_type] = gathered
            self.predictions[iou_type] = dict(ChainMap(*gathered))

    def accumulate(self):
        """Accumulate every ``COCOeval`` and convert stored predictions to COCO results."""
        for evaluator in self.coco_eval.values():
            evaluator.accumulate()

        for iou_type in self.iou_types:
            stored = self.predictions[iou_type]
            # Order by image id before converting to the COCO result format.
            self.predictions[iou_type] = OrderedDict((key, stored[key]) for key in sorted(stored))
            self.predictions[iou_type] = self.prepare(self.predictions[iou_type], iou_type)

    def summarize(self):
        """Print the summary of every ``COCOeval``, preceded by its IoU type."""
        for iou_type, evaluator in self.coco_eval.items():
            print("IoU metric: {}".format(iou_type))
            evaluator.summarize()

    def prepare(self, predictions, iou_type):
        """Convert ``predictions`` into a list of COCO result dicts.

        Args:
            predictions: Mapping from image id to prediction dict.
            iou_type: One of ``"bbox"``, ``"segm"`` or ``"keypoints"``.

        Raises:
            ValueError: If ``iou_type`` is none of the supported names.
        """
        if iou_type == "bbox":
            return self.prepare_for_coco_detection(predictions)
        if iou_type == "segm":
            return self.prepare_for_coco_segmentation(predictions)
        if iou_type == "keypoints":
            return self.prepare_for_coco_keypoint(predictions)
        raise ValueError("Unknown iou type {}".format(iou_type))

    def prepare_for_coco_detection(self, predictions):
        """Build bbox results from the ``boxes``, ``scores`` and ``labels`` entries."""
        coco_results = []
        for original_id, prediction in predictions.items():
            if len(prediction) == 0:
                continue

            xywh_boxes = convert_to_xywh(prediction["boxes"]).tolist()
            scores = prediction["scores"].tolist()
            labels = prediction["labels"].tolist()

            for k, box in enumerate(xywh_boxes):
                coco_results.append({
                    "image_id": original_id,
                    "category_id": labels[k],
                    "bbox": box,
                    "score": scores[k],
                })
        return coco_results

    def prepare_for_coco_segmentation(self, predictions):
        """Build segmentation results; masks are thresholded at 0.5 and RLE-encoded."""
        coco_results = []
        for original_id, prediction in predictions.items():
            if len(prediction) == 0:
                continue

            scores = prediction["scores"].tolist()
            labels = prediction["labels"].tolist()
            binary_masks = prediction["masks"] > 0.5

            rles = []
            for mask in binary_masks:
                # Channel 0 of each mask, passed as a Fortran-ordered uint8 HxWx1 array.
                encoded = mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0]
                encoded["counts"] = encoded["counts"].decode("utf-8")
                rles.append(encoded)

            for k, rle in enumerate(rles):
                coco_results.append({
                    "image_id": original_id,
                    "category_id": labels[k],
                    "segmentation": rle,
                    "score": scores[k],
                })
        return coco_results

    def prepare_for_coco_keypoint(self, predictions):
        """Build keypoint results; each keypoint tensor is flattened per detection."""
        coco_results = []
        for original_id, prediction in predictions.items():
            if len(prediction) == 0:
                continue

            # The converted boxes are not part of the output, but the conversion
            # is kept so a missing or malformed "boxes" entry still fails here.
            convert_to_xywh(prediction["boxes"]).tolist()
            scores = prediction["scores"].tolist()
            labels = prediction["labels"].tolist()
            flat_keypoints = prediction["keypoints"].flatten(start_dim=1).tolist()

            for k, keypoint in enumerate(flat_keypoints):
                coco_results.append({
                    "image_id": original_id,
                    "category_id": labels[k],
                    "keypoints": keypoint,
                    "score": scores[k],
                })
        return coco_results
def convert_to_xywh(boxes):
    """Convert boxes from ``(xmin, ymin, xmax, ymax)`` to ``(x, y, width, height)``.

    Args:
        boxes: Tensor whose dimension 1 holds the four corner coordinates.

    Returns:
        A new tensor with the same leading layout, holding the top-left corner
        followed by the box width and height.
    """
    left, top, right, bottom = boxes.unbind(1)
    width = right - left
    height = bottom - top
    return torch.stack((left, top, width, height), dim=1)
def merge(img_ids, eval_imgs):
    """Gather image ids and per-image results and drop duplicate images.

    Both arguments are passed through ``utils.all_gather``; the gathered id
    lists are flattened and the gathered result arrays are concatenated along
    axis 2.

    Args:
        img_ids: Image ids held locally.
        eval_imgs: Array of per-image results whose axis 2 indexes images.

    Returns:
        ``(merged_img_ids, merged_eval_imgs)`` where the ids are unique and
        sorted, and the last axis of the results is reordered to match.
    """
    gathered_ids = utils.all_gather(img_ids)
    gathered_evals = utils.all_gather(eval_imgs)

    flat_ids = np.array([img_id for chunk in gathered_ids for img_id in chunk])
    stacked_evals = np.concatenate(list(gathered_evals), 2)

    # np.unique sorts the ids and reports the first position of each one,
    # which is then used to pick the matching result slice.
    unique_ids, first_positions = np.unique(flat_ids, return_index=True)
    return unique_ids, stacked_evals[..., first_positions]
def create_common_coco_eval(coco_eval, img_ids, eval_imgs):
    """Install merged per-image results on ``coco_eval``.

    Args:
        coco_eval: Evaluator object whose ``evalImgs``, ``params.imgIds`` and
            ``_paramsEval`` attributes are overwritten.
        img_ids: Image ids forwarded to ``merge``.
        eval_imgs: Per-image results forwarded to ``merge``.
    """
    merged_ids, merged_evals = merge(img_ids, eval_imgs)
    id_list = list(merged_ids)
    flat_results = list(merged_evals.flatten())

    coco_eval.evalImgs = flat_results
    coco_eval.params.imgIds = id_list
    # Snapshot of the parameters these results belong to.
    coco_eval._paramsEval = copy.deepcopy(coco_eval.params)
#################################################################
# From pycocotools, just removed the prints and fixed
# a Python3 bug about unicode not defined
#################################################################
# Ideally, pycocotools wouldn't have hard-coded prints
# so that we could avoid copy-pasting those two functions
def createIndex(self):
    """Build the lookup tables of a COCO-style object from ``self.dataset``.

    Sets ``anns``, ``imgs`` and ``cats`` (id -> record), ``imgToAnns``
    (image id -> list of annotations) and ``catToImgs`` (category id -> list
    of image ids, one entry per annotation) on ``self``. Sections missing from
    ``self.dataset`` simply leave the corresponding tables empty.
    """
    dataset = self.dataset
    has_annotations = "annotations" in dataset
    has_categories = "categories" in dataset

    ann_by_id, cat_by_id, img_by_id = {}, {}, {}
    anns_per_image = defaultdict(list)
    images_per_category = defaultdict(list)

    if has_annotations:
        for annotation in dataset["annotations"]:
            anns_per_image[annotation["image_id"]].append(annotation)
            ann_by_id[annotation["id"]] = annotation

    if "images" in dataset:
        for image in dataset["images"]:
            img_by_id[image["id"]] = image

    if has_categories:
        for category in dataset["categories"]:
            cat_by_id[category["id"]] = category

    if has_annotations and has_categories:
        for annotation in dataset["annotations"]:
            images_per_category[annotation["category_id"]].append(annotation["image_id"])

    # Publish the finished tables on the object.
    self.anns = ann_by_id
    self.imgToAnns = anns_per_image
    self.catToImgs = images_per_category
    self.imgs = img_by_id
    self.cats = cat_by_id
# Alias so the pycocotools-derived `loadRes` below can keep using the
# `maskUtils` name for the module imported here as `mask_util`.
maskUtils = mask_util
def loadRes(self, resFile):
    """Load detection results and return them as a new COCO API object.

    Print-free variant of the pycocotools routine. The result entries are
    completed in place with ``id`` and, depending on the kind of result,
    ``area``, ``bbox``, ``segmentation`` and ``iscrowd``.

    Args:
        self: Ground-truth COCO object; supplies the images, the categories
            and the set of valid image ids.
        resFile: Either the path of a JSON result file (``str``), a numpy
            array handed to ``self.loadNumpyAnnotations``, or an already
            loaded list of result dicts. A list is modified in place.

    Returns:
        A ``COCO`` object whose ``dataset["annotations"]`` is the completed
        result list, indexed with ``createIndex``.

    Raises:
        AssertionError: If the results are not a list, or reference image ids
            that are unknown to ``self``.
        IndexError: If the result list is empty (the first entry decides
            which kind of result is being loaded).
    """
    res = COCO()
    res.dataset["images"] = list(self.dataset["images"])

    if isinstance(resFile, str):
        # Close the file deterministically instead of leaking the handle.
        with open(resFile) as result_file:
            anns = json.load(result_file)
    elif isinstance(resFile, np.ndarray):
        anns = self.loadNumpyAnnotations(resFile)
    else:
        anns = resFile

    assert isinstance(anns, list), "results in not an array of objects"
    annsImgIds = [ann["image_id"] for ann in anns]
    assert set(annsImgIds) == (
        set(annsImgIds) & set(self.getImgIds())
    ), "Results do not correspond to current coco set"

    # The first entry decides how every entry is completed. Ids are 1-based.
    if "caption" in anns[0]:
        imgIds = {img["id"] for img in res.dataset["images"]} & {ann["image_id"] for ann in anns}
        res.dataset["images"] = [img for img in res.dataset["images"] if img["id"] in imgIds]
        for ann_id, ann in enumerate(anns, start=1):
            ann["id"] = ann_id
    elif "bbox" in anns[0] and not anns[0]["bbox"] == []:
        res.dataset["categories"] = copy.deepcopy(self.dataset["categories"])
        for ann_id, ann in enumerate(anns, start=1):
            bb = ann["bbox"]
            x1, x2, y1, y2 = [bb[0], bb[0] + bb[2], bb[1], bb[1] + bb[3]]
            if "segmentation" not in ann:
                # Fall back to the box outline as a polygon.
                ann["segmentation"] = [[x1, y1, x1, y2, x2, y2, x2, y1]]
            ann["area"] = bb[2] * bb[3]
            ann["id"] = ann_id
            ann["iscrowd"] = 0
    elif "segmentation" in anns[0]:
        res.dataset["categories"] = copy.deepcopy(self.dataset["categories"])
        for ann_id, ann in enumerate(anns, start=1):
            # now only support compressed RLE format as segmentation results
            ann["area"] = maskUtils.area(ann["segmentation"])
            if "bbox" not in ann:
                ann["bbox"] = maskUtils.toBbox(ann["segmentation"])
            ann["id"] = ann_id
            ann["iscrowd"] = 0
    elif "keypoints" in anns[0]:
        res.dataset["categories"] = copy.deepcopy(self.dataset["categories"])
        for ann_id, ann in enumerate(anns, start=1):
            s = ann["keypoints"]
            # Keypoints are stored as flat triples; every third value starting
            # at 0 is read as x and starting at 1 as y.
            x = s[0::3]
            y = s[1::3]
            x1, x2, y1, y2 = np.min(x), np.max(x), np.min(y), np.max(y)
            ann["area"] = (x2 - x1) * (y2 - y1)
            ann["id"] = ann_id
            ann["bbox"] = [x1, y1, x2 - x1, y2 - y1]

    res.dataset["annotations"] = anns
    createIndex(res)
    return res
def evaluate(self):
    """Run the per-image evaluation of a ``COCOeval``-like object.

    Print-free variant of the pycocotools routine that additionally returns
    its results instead of only storing them.

    Side effects on ``self``: ``params`` is normalised (unique image and
    category ids, sorted ``maxDets``), ``_prepare()`` is called, and ``ious``
    and ``_paramsEval`` are set.

    Returns:
        ``(img_ids, eval_imgs)`` where ``img_ids`` is the normalised list of
        image ids and ``eval_imgs`` is an array of shape
        ``(len(catIds), len(areaRng), len(imgIds))`` holding the value
        returned by ``self.evaluateImg`` for each combination.

    Raises:
        ValueError: If ``params.iouType`` is not ``"segm"``, ``"bbox"`` or
            ``"keypoints"``.
    """
    p = self.params
    # add backward compatibility if useSegm is specified in params
    if p.useSegm is not None:
        p.iouType = "segm" if p.useSegm == 1 else "bbox"
        print("useSegm (deprecated) is not None. Running {} evaluation".format(p.iouType))

    p.imgIds = list(np.unique(p.imgIds))
    if p.useCats:
        p.catIds = list(np.unique(p.catIds))
    p.maxDets = sorted(p.maxDets)
    self.params = p

    self._prepare()

    # Without category awareness everything is evaluated under category -1.
    catIds = p.catIds if p.useCats else [-1]

    if p.iouType in ("segm", "bbox"):
        computeIoU = self.computeIoU
    elif p.iouType == "keypoints":
        computeIoU = self.computeOks
    else:
        # Fail with a clear message instead of an unbound local further down.
        raise ValueError(
            f"Unsupported iouType {p.iouType!r}; expected 'segm', 'bbox' or 'keypoints'"
        )
    self.ious = {(imgId, catId): computeIoU(imgId, catId) for imgId in p.imgIds for catId in catIds}

    evaluateImg = self.evaluateImg
    maxDet = p.maxDets[-1]
    # Category-major, then area range, then image: this order is what makes
    # the reshape below valid.
    evalImgs = [
        evaluateImg(imgId, catId, areaRng, maxDet)
        for catId in catIds
        for areaRng in p.areaRng
        for imgId in p.imgIds
    ]
    # this is NOT in the pycocotools code, but could be done outside
    evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds))
    self._paramsEval = copy.deepcopy(self.params)
    return p.imgIds, evalImgs
#################################################################
# end of straight copy from pycocotools, just removing the prints
#################################################################
import copy
import torch
import torch.utils.data
import torchvision
from pycocotools import mask as coco_mask
from pycocotools.coco import COCO
from tqdm import tqdm
class FilterAndRemapCocoCategories:
    """Transform that keeps only selected categories and optionally renumbers them.

    Args:
        categories: Sequence of category ids to keep. With remapping enabled,
            a kept id is replaced by its position in this sequence.
        remap: Whether to renumber the kept category ids.
    """

    def __init__(self, categories, remap=True):
        self.categories = categories
        self.remap = remap

    def __call__(self, image, target):
        """Filter (and possibly remap) ``target["annotations"]``.

        ``target`` is updated in place and returned together with the
        untouched ``image``. When remapping, the kept annotations are deep
        copied first so the original annotation dicts keep their ids.
        """
        kept = [obj for obj in target["annotations"] if obj["category_id"] in self.categories]
        if self.remap:
            kept = copy.deepcopy(kept)
            for obj in kept:
                obj["category_id"] = self.categories.index(obj["category_id"])
        target["annotations"] = kept
        return image, target
def convert_to_coco_api(ds):
    """Build a ``COCO`` API object from an indexable detection dataset.

    Every item of ``ds`` is read as ``(img, targets)``. ``targets`` must
    provide ``image_id``, ``boxes``, ``labels``, ``area`` and ``iscrowd``;
    ``masks`` and ``keypoints`` are used when present. The image height and
    width are taken from the last two entries of ``img.shape``.

    Args:
        ds: Dataset supporting ``len()`` and integer indexing.

    Returns:
        A ``COCO`` object holding one image record per dataset item, one
        annotation per box, and one category per distinct label.
    """
    coco_ds = COCO()
    # Annotation ids start at 1: COCOeval stores the matched annotation id and
    # treats 0 as "no match", so an annotation with id 0 could never count as
    # matched.
    ann_id = 1
    dataset = {"images": [], "categories": [], "annotations": []}
    categories = set()
    for img_idx in tqdm(range(len(ds))):
        # find better way to get target
        # targets = ds.get_annotations(img_idx)
        img, targets = ds[img_idx]
        image_id = targets["image_id"].item()
        dataset["images"].append({
            "id": image_id,
            "height": img.shape[-2],
            "width": img.shape[-1],
        })

        # Clone before converting: the subtraction below is in place and must
        # not alter the tensor owned by the dataset.
        bboxes = targets["boxes"].clone()
        # (xmin, ymin, xmax, ymax) -> (x, y, width, height)
        bboxes[:, 2:] -= bboxes[:, :2]
        bboxes = bboxes.tolist()
        labels = targets["labels"].tolist()
        areas = targets["area"].tolist()
        iscrowd = targets["iscrowd"].tolist()

        has_masks = "masks" in targets
        has_keypoints = "keypoints" in targets
        if has_masks:
            masks = targets["masks"]
            # make masks Fortran contiguous for coco_mask
            masks = masks.permute(0, 2, 1).contiguous().permute(0, 2, 1)
        if has_keypoints:
            keypoints = targets["keypoints"]
            keypoints = keypoints.reshape(keypoints.shape[0], -1).tolist()

        for i, bbox in enumerate(bboxes):
            ann = {
                "image_id": image_id,
                "bbox": bbox,
                "category_id": labels[i],
                "area": areas[i],
                "iscrowd": iscrowd[i],
                "id": ann_id,
            }
            categories.add(labels[i])
            if has_masks:
                ann["segmentation"] = coco_mask.encode(masks[i].numpy())
            if has_keypoints:
                ann["keypoints"] = keypoints[i]
                # Every third value (starting at index 2) is the flag that is
                # counted when non-zero.
                ann["num_keypoints"] = sum(k != 0 for k in keypoints[i][2::3])
            dataset["annotations"].append(ann)
            ann_id += 1

    dataset["categories"] = [{"id": i} for i in sorted(categories)]
    coco_ds.dataset = dataset
    coco_ds.createIndex()
    return coco_ds
def get_coco_api_from_dataset(dataset):
    """Return a COCO API object for ``dataset``.

    Up to ten nested ``torch.utils.data.Subset`` wrappers are peeled off. If
    the result is a ``torchvision.datasets.CocoDetection`` its ``coco``
    attribute is returned; otherwise the dataset is converted with
    ``convert_to_coco_api``.
    """
    max_unwraps = 10
    for _ in range(max_unwraps):
        is_coco = isinstance(dataset, torchvision.datasets.CocoDetection)
        if is_coco or not isinstance(dataset, torch.utils.data.Subset):
            # Either the target was reached or there is nothing left to unwrap.
            break
        dataset = dataset.dataset

    if not isinstance(dataset, torchvision.datasets.CocoDetection):
        return convert_to_coco_api(dataset)
    return dataset.coco
import torch
from transforms import InterpolationMode
from transforms.simple_copy_paste import SimpleCopyPaste
from util.misc import to_device
def collate_fn(batch):
    """Transpose a batch of samples into one tuple per sample field.

    Args:
        batch: Iterable of samples, each itself an iterable of fields.

    Returns:
        A tuple whose i-th entry is the tuple of the i-th field of every
        sample; an empty tuple for an empty batch.
    """
    per_field = zip(*batch)
    return tuple(per_field)
def copypaste_collate_fn(batch):
    """Collate ``batch`` with ``collate_fn`` and pass the fields to ``SimpleCopyPaste``.

    A fresh ``SimpleCopyPaste`` with blending enabled and bilinear resize
    interpolation is created on every call.

    Returns:
        Whatever the ``SimpleCopyPaste`` instance returns for the collated fields.
    """
    paste_transform = SimpleCopyPaste(blending=True, resize_interpolation=InterpolationMode.BILINEAR)
    fields = collate_fn(batch)
    return paste_transform(*fields)
class DataPrefetcher:
    """Wraps a loader and moves the upcoming batch to ``device`` ahead of time.

    When CUDA is available the move is issued inside a dedicated
    ``torch.cuda.Stream``, and ``next`` makes the current stream wait for that
    stream before handing the batch out. Without CUDA the move happens
    directly.

    Args:
        loader: Iterable producing batches.
        device: Target passed to ``to_device``.
    """

    def __init__(self, loader, device):
        self.loader = iter(loader)
        self.device = device
        if torch.cuda.is_available():
            self.stream = torch.cuda.Stream()
        # Start fetching the first batch right away.
        self.preload()

    def preload(self):
        """Fetch the next batch into ``self.next_batch`` (``None`` once exhausted)."""
        try:
            self.next_batch = next(self.loader)
        except StopIteration:
            self.next_batch = None
            return

        if not torch.cuda.is_available():
            self.next_batch = to_device(self.next_batch, self.device)
            return

        with torch.cuda.stream(self.stream):
            self.next_batch = to_device(self.next_batch, self.device)
        # No manual half/float conversion is done here; with Amp it isn't necessary.

    def next(self):
        """Return the prefetched batch (``None`` at the end) and start the next fetch."""
        if torch.cuda.is_available():
            torch.cuda.current_stream().wait_stream(self.stream)
        ready_batch = self.next_batch
        self.preload()
        return ready_batch
import importlib
import os
import re
import subprocess
import sys
from collections import defaultdict
import numpy as np
import PIL
import torch
import torchvision
from tabulate import tabulate
def collect_torch_env():
    """Return a text description of the PyTorch build/environment.

    Uses ``torch.__config__.show()`` when that module can be imported and
    falls back to ``torch.utils.collect_env.get_pretty_env_info()`` otherwise.
    """
    try:
        from torch import __config__ as torch_config

        return torch_config.show()
    except ImportError:
        # compatible with older versions of pytorch
        from torch.utils import collect_env

        return collect_env.get_pretty_env_info()
def detect_compute_compatibility(CUDA_HOME, so_file):
    """Report the GPU architectures a compiled library was built for.

    Runs ``<CUDA_HOME>/bin/cuobjdump --list-elf <so_file>`` and extracts the
    ``sm_<digits>`` tag of every output line.

    Args:
        CUDA_HOME: CUDA installation directory (may be ``None``).
        so_file: Path of the shared library to inspect.

    Returns:
        A sorted, comma-separated list of the architectures found, with a dot
        placed between the digits of each tag (``sm_75`` -> ``"7.5"``). If
        ``cuobjdump`` does not exist, ``so_file + "; cannot find cuobjdump"``.
        On any other failure (including ``CUDA_HOME`` being ``None`` or an
        output line without an ``sm_`` tag) just ``so_file``.
    """
    try:
        cuobjdump = os.path.join(CUDA_HOME, "bin", "cuobjdump")
        if not os.path.isfile(cuobjdump):
            return so_file + "; cannot find cuobjdump"

        # Pass an argument list rather than a shell string so that paths
        # containing quotes or other shell metacharacters work and cannot be
        # interpreted by a shell.
        output = subprocess.check_output([cuobjdump, "--list-elf", so_file])
        lines = output.decode("utf-8").strip().split("\n")

        archs = set()
        for line in lines:
            # A line without a match raises IndexError, which deliberately
            # falls through to the generic fallback below.
            sm_digits = re.findall(r"\.sm_([0-9]*)\.", line)[0]
            archs.add(".".join(sm_digits))
        return ", ".join(sorted(archs))
    except Exception:
        # Best effort only: this feeds an environment report, so any failure
        # degrades to reporting the library path.
        return so_file
def collect_env_info():
    """Collect a text report about the Python / PyTorch / GPU environment.

    The report is a ``tabulate`` table of (name, value) rows covering the
    platform, Python, numpy, PyTorch, the visible GPUs, CUDA/ROCm locations,
    Pillow, torchvision and a few optional packages, followed by the output
    of ``collect_torch_env()``.

    Returns:
        The report as a single string.
    """
    # `import importlib` alone does not guarantee that the `util` submodule is
    # loaded; import it explicitly so the torchvision._C lookup below cannot
    # fail with an AttributeError that would be misreported as "Not found".
    import importlib.util

    has_gpu = torch.cuda.is_available()  # true for both CUDA & ROCM
    torch_version = torch.__version__

    # NOTE that CUDA_HOME/ROCM_HOME could be None even when CUDA runtime libs are functional
    from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME

    has_rocm = False
    if (getattr(torch.version, "hip", None) is not None) and (ROCM_HOME is not None):
        has_rocm = True
    has_cuda = has_gpu and (not has_rocm)

    data = []
    data.append(("sys.platform", sys.platform))  # check-template.yml depends on it
    data.append(("Python", sys.version.replace("\n", "")))
    data.append(("numpy", np.__version__))
    data.append(("PyTorch", torch_version + " @" + os.path.dirname(torch.__file__)))
    data.append(("PyTorch debug build", torch.version.debug))
    try:
        data.append(("torch._C._GLIBCXX_USE_CXX11_ABI", torch._C._GLIBCXX_USE_CXX11_ABI))
    except Exception:
        # Optional row: skipped when the attribute cannot be read.
        pass

    if not has_gpu:
        has_gpu_text = "No: torch.cuda.is_available() == False"
    else:
        has_gpu_text = "Yes"
    data.append(("GPU available", has_gpu_text))
    if has_gpu:
        # Group device indices by "<name> (arch=<capability>)" so identical
        # GPUs share one row.
        devices = defaultdict(list)
        for k in range(torch.cuda.device_count()):
            cap = ".".join((str(x) for x in torch.cuda.get_device_capability(k)))
            name = torch.cuda.get_device_name(k) + f" (arch={cap})"
            devices[name].append(str(k))
        for name, devids in devices.items():
            data.append(("GPU " + ",".join(devids), name))

        if has_rocm:
            msg = " - invalid!" if not (ROCM_HOME and os.path.isdir(ROCM_HOME)) else ""
            data.append(("ROCM_HOME", str(ROCM_HOME) + msg))
        else:
            try:
                from torch.utils.collect_env import get_nvidia_driver_version
                from torch.utils.collect_env import run as _run

                data.append(("Driver version", get_nvidia_driver_version(_run)))
            except Exception:
                # Optional row: the driver version is best effort.
                pass
            msg = " - invalid!" if not (CUDA_HOME and os.path.isdir(CUDA_HOME)) else ""
            data.append(("CUDA_HOME", str(CUDA_HOME) + msg))

            cuda_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None)
            if cuda_arch_list:
                data.append(("TORCH_CUDA_ARCH_LIST", cuda_arch_list))
    data.append(("Pillow", PIL.__version__))

    try:
        data.append((
            "torchvision",
            str(torchvision.__version__) + " @" + os.path.dirname(torchvision.__file__),
        ))
        if has_cuda:
            try:
                # find_spec returns None when the extension is missing, which
                # surfaces as AttributeError on `.origin`.
                torchvision_C = importlib.util.find_spec("torchvision._C").origin
                msg = detect_compute_compatibility(CUDA_HOME, torchvision_C)
                data.append(("torchvision arch flags", msg))
            except (ImportError, AttributeError):
                data.append(("torchvision._C", "Not found"))
    except AttributeError:
        data.append(("torchvision", "unknown"))

    # Optional packages: reported only when importable and versioned.
    try:
        import fvcore

        data.append(("fvcore", fvcore.__version__))
    except (ImportError, AttributeError):
        pass

    try:
        import iopath

        data.append(("iopath", iopath.__version__))
    except (ImportError, AttributeError):
        pass

    try:
        import cv2

        data.append(("cv2", cv2.__version__))
    except (ImportError, AttributeError):
        data.append(("cv2", "Not found"))

    env_str = tabulate(data) + "\n"
    env_str += collect_torch_env()
    return env_str
from __future__ import annotations
from enum import Enum
from types import ModuleType
from typing import Any, Callable, List, Mapping, Optional, Sequence, Tuple, Type, Union
import PIL.Image
import torch
from torch._C import DisableTorchFunction
from transforms import InterpolationMode
class Datapoint(torch.Tensor):
    """Base ``torch.Tensor`` subclass for typed data such as images.

    The class does two things:

    * it customises ``__torch_function__`` so that the results of tensor
      operations are returned as plain ``torch.Tensor`` objects, except for
      the operations listed in ``_NO_WRAPPING_EXCEPTIONS``;
    * it declares a set of transform methods (``horizontal_flip``,
      ``resize``, ``crop``, ...) whose implementation here is a no-op that
      returns ``self``. Subclasses override the ones that apply to them.
    """

    # Cache for the lazily imported functional module, shared by all
    # instances; filled on first access of the `_F` property.
    __F: Optional[ModuleType] = None

    @staticmethod
    def _to_tensor(
        data: Any,
        dtype: Optional[torch.dtype] = None,
        device: Optional[Union[torch.device, str, int]] = None,
        requires_grad: Optional[bool] = None,
    ) -> torch.Tensor:
        """Convert ``data`` to a tensor via ``torch.as_tensor``.

        When ``requires_grad`` is ``None`` it is inherited from ``data`` if
        that is a tensor, and is ``False`` otherwise.
        """
        if requires_grad is None:
            requires_grad = data.requires_grad if isinstance(data, torch.Tensor) else False
        return torch.as_tensor(data, dtype=dtype, device=device).requires_grad_(requires_grad)

    # NOTE(review): `D` is neither defined nor imported in the visible part of
    # this module; the annotation is only harmless because annotations are
    # lazy under `from __future__ import annotations` — confirm.
    @classmethod
    def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D:
        """Wrap ``tensor`` as an instance of ``cls`` modelled on ``other``.

        Not implemented on the base class; subclasses must provide it.
        """
        raise NotImplementedError

    # Operations whose output IS re-wrapped. Each value is called as
    # `wrapper(cls, input, output)` from `__torch_function__`.
    _NO_WRAPPING_EXCEPTIONS = {
        torch.Tensor.clone: lambda cls, input, output: cls.wrap_like(input, output),
        torch.Tensor.to: lambda cls, input, output: cls.wrap_like(input, output),
        # We don't need to wrap the output of `Tensor.requires_grad_`, since it is an inplace operation and thus
        # retains the type automatically
        torch.Tensor.requires_grad_: lambda cls, input, output: output,
    }

    @classmethod
    def __torch_function__(
        cls,
        func: Callable[..., torch.Tensor],
        types: Tuple[Type[torch.Tensor], ...],
        args: Sequence[Any] = (),
        kwargs: Optional[Mapping[str, Any]] = None,
    ) -> torch.Tensor:
        """For general information about how the __torch_function__ protocol works,
        see https://pytorch.org/docs/stable/notes/extending.html#extending-torch

        TL;DR: Every time a PyTorch operator is called, it goes through the inputs and looks for the
        ``__torch_function__`` method. If one is found, it is invoked with the operator as ``func`` as well as the
        ``args`` and ``kwargs`` of the original call.

        The default behavior of :class:`~torch.Tensor`'s is to retain a custom tensor type. For the :class:`Datapoint`
        use case, this has two downsides:

        1. Since some :class:`Datapoint`'s require metadata to be constructed, the default wrapping, i.e.
           ``return cls(func(*args, **kwargs))``, will fail for them.
        2. For most operations, there is no way of knowing if the input type is still valid for the output.

        For these reasons, the automatic output wrapping is turned off for most operators. The only exceptions are
        listed in :attr:`Datapoint._NO_WRAPPING_EXCEPTIONS`
        """
        # Since super().__torch_function__ has no hook to prevent the coercing of the output into the input type, we
        # need to reimplement the functionality.

        # Decline (so another type's handler can be tried) unless `cls` is a
        # subclass of every type taking part in the call.
        if not all(issubclass(cls, t) for t in types):
            return NotImplemented

        # Run the real operation with the __torch_function__ machinery turned off.
        with DisableTorchFunction():
            output = func(*args, **kwargs or dict())

            wrapper = cls._NO_WRAPPING_EXCEPTIONS.get(func)
            # Apart from `func` needing to be an exception, we also require the primary operand, i.e. `args[0]`, to be
            # an instance of the class that `__torch_function__` was invoked on. The __torch_function__ protocol will
            # invoke this method on *all* types involved in the computation by walking the MRO upwards. For example,
            # `torch.Tensor(...).to(datapoints.Image(...))` will invoke `datapoints.Image.__torch_function__` with
            # `args = (torch.Tensor(), datapoints.Image())` first. Without this guard, the original `torch.Tensor` would
            # be wrapped into a `datapoints.Image`.
            if wrapper and isinstance(args[0], cls):
                return wrapper(cls, args[0], output)

            # Inplace `func`'s, canonically identified with a trailing underscore in their name like `.add_(...)`,
            # will retain the input type. Thus, we need to unwrap here.
            if isinstance(output, cls):
                return output.as_subclass(torch.Tensor)

            return output

    def _make_repr(self, **kwargs: Any) -> str:
        """Return the tensor repr with ``key=value`` pairs from ``kwargs`` appended.

        The last character of the parent repr is dropped and
        ``", <extras>)"`` is appended in its place.
        """
        # This is a poor man's implementation of the proposal in https://github.com/pytorch/pytorch/issues/76532.
        # If that ever gets implemented, remove this in favor of the solution on the `torch.Tensor` class.
        extra_repr = ", ".join(f"{key}={value}" for key, value in kwargs.items())
        return f"{super().__repr__()[:-1]}, {extra_repr})"

    @property
    def _F(self) -> ModuleType:
        """The ``transforms.v2.functional`` module, imported on first use."""
        # This implements a lazy import of the functional to get around the cyclic import. This import is deferred
        # until the first time we need reference to the functional module and it's shared across all instances of
        # the class. This approach avoids the DataLoader issue described at
        # https://github.com/pytorch/vision/pull/6476#discussion_r953588621
        if Datapoint.__F is None:
            from transforms.v2 import functional

            Datapoint.__F = functional
        return Datapoint.__F

    @property
    def data(self) -> torch.Tensor:
        """This object viewed as a plain ``torch.Tensor``."""
        return self.as_subclass(torch.Tensor)

    # ------------------------------------------------------------------
    # Transform interface. Every method below is a no-op default that
    # ignores its arguments and returns `self` unchanged; subclasses
    # override the transforms that are meaningful for their data.
    # ------------------------------------------------------------------

    def horizontal_flip(self) -> Datapoint:
        return self

    def vertical_flip(self) -> Datapoint:
        return self

    # TODO: We have to ignore override mypy error as there is torch.Tensor built-in deprecated op: Tensor.resize
    # https://github.com/pytorch/pytorch/blob/e8727994eb7cdb2ab642749d6549bc497563aa06/torch/_tensor.py#L588-L593
    def resize(  # type: ignore[override]
        self,
        size: List[int],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        max_size: Optional[int] = None,
        antialias: Optional[Union[str, bool]] = "warn",
    ) -> Datapoint:
        return self

    def crop(self, top: int, left: int, height: int, width: int) -> Datapoint:
        return self

    def center_crop(self, output_size: List[int]) -> Datapoint:
        return self

    def resized_crop(
        self,
        top: int,
        left: int,
        height: int,
        width: int,
        size: List[int],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        antialias: Optional[Union[str, bool]] = "warn",
    ) -> Datapoint:
        return self

    def pad(
        self,
        padding: List[int],
        fill: Optional[Union[int, float, List[float]]] = None,
        padding_mode: str = "constant",
    ) -> Datapoint:
        return self

    def rotate(
        self,
        angle: float,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        expand: bool = False,
        center: Optional[List[float]] = None,
        fill: Optional[List[float]] = None,
    ) -> Datapoint:
        return self

    def affine(
        self,
        angle: Union[int, float],
        translate: List[float],
        scale: float,
        shear: List[float],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: Optional[List[float]] = None,
        center: Optional[List[float]] = None,
    ) -> Datapoint:
        return self

    def perspective(
        self,
        startpoints: Optional[List[List[int]]],
        endpoints: Optional[List[List[int]]],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        fill: Optional[List[float]] = None,
        coefficients: Optional[List[float]] = None,
    ) -> Datapoint:
        return self

    def elastic(
        self,
        displacement: torch.Tensor,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        fill: Optional[List[float]] = None,
    ) -> Datapoint:
        return self

    def rgb_to_grayscale(self, num_output_channels: int = 1) -> Datapoint:
        return self

    def adjust_brightness(self, brightness_factor: float) -> Datapoint:
        return self

    def adjust_saturation(self, saturation_factor: float) -> Datapoint:
        return self

    def adjust_contrast(self, contrast_factor: float) -> Datapoint:
        return self

    def adjust_sharpness(self, sharpness_factor: float) -> Datapoint:
        return self

    def adjust_hue(self, hue_factor: float) -> Datapoint:
        return self

    def adjust_gamma(self, gamma: float, gain: float = 1) -> Datapoint:
        return self

    def posterize(self, bits: int) -> Datapoint:
        return self

    def solarize(self, threshold: float) -> Datapoint:
        return self

    def autocontrast(self) -> Datapoint:
        return self

    def equalize(self) -> Datapoint:
        return self

    def invert(self) -> Datapoint:
        return self

    def gaussian_blur(
        self, kernel_size: List[int], sigma: Optional[List[float]] = None
    ) -> Datapoint:
        return self
class Image(Datapoint):
@classmethod
def _wrap(cls, tensor: torch.Tensor) -> Image:
image = tensor.as_subclass(cls)
return image
def __new__(
cls,
data: Any,
*,
dtype: Optional[torch.dtype] = None,
device: Optional[Union[torch.device, str, int]] = None,
requires_grad: Optional[bool] = None,
) -> Image:
if isinstance(data, PIL.Image.Image):
from transforms import functional as F
data = F.pil_to_tensor(data)
tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
if tensor.ndim < 2:
raise ValueError
elif tensor.ndim == 2:
tensor = tensor.unsqueeze(0)
return cls._wrap(tensor)
@classmethod
def wrap_like(cls, other: Image, tensor: torch.Tensor) -> Image:
return cls._wrap(tensor)
def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override]
return self._make_repr()
@property
def spatial_size(self) -> Tuple[int, int]:
return tuple(self.shape[-2:]) # type: ignore[return-value]
@property
def num_channels(self) -> int:
return self.shape[-3]
def horizontal_flip(self) -> Image:
output = self._F.horizontal_flip_image_tensor(self.as_subclass(torch.Tensor))
return Image.wrap_like(self, output)
def vertical_flip(self) -> Image:
output = self._F.vertical_flip_image_tensor(self.as_subclass(torch.Tensor))
return Image.wrap_like(self, output)
def resize( # type: ignore[override]
self,
size: List[int],
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
antialias: Optional[Union[str, bool]] = "warn",
) -> Image:
output = self._F.resize_image_tensor(
self.as_subclass(torch.Tensor),
size,
interpolation=interpolation,
max_size=max_size,
antialias=antialias,
)
return Image.wrap_like(self, output)
def crop(self, top: int, left: int, height: int, width: int) -> Image:
output = self._F.crop_image_tensor(self.as_subclass(torch.Tensor), top, left, height, width)
return Image.wrap_like(self, output)
def center_crop(self, output_size: List[int]) -> Image:
output = self._F.center_crop_image_tensor(
self.as_subclass(torch.Tensor), output_size=output_size
)
return Image.wrap_like(self, output)
def resized_crop(
self,
top: int,
left: int,
height: int,
width: int,
size: List[int],
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn",
) -> Image:
output = self._F.resized_crop_image_tensor(
self.as_subclass(torch.Tensor),
top,
left,
height,
width,
size=list(size),
interpolation=interpolation,
antialias=antialias,
)
return Image.wrap_like(self, output)
def pad(
self,
padding: List[int],
fill: Optional[Union[int, float, List[float]]] = None,
padding_mode: str = "constant",
) -> Image:
output = self._F.pad_image_tensor(
self.as_subclass(torch.Tensor),
padding,
fill=fill,
padding_mode=padding_mode,
)
return Image.wrap_like(self, output)
def rotate(
self,
angle: float,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
expand: bool = False,
center: Optional[List[float]] = None,
fill: Optional[List[float]] = None,
) -> Image:
output = self._F.rotate_image_tensor(
self.as_subclass(torch.Tensor),
angle,
interpolation=interpolation,
expand=expand,
fill=fill,
center=center,
)
return Image.wrap_like(self, output)
def affine(
self,
angle: Union[int, float],
translate: List[float],
scale: float,
shear: List[float],
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: Optional[List[float]] = None,
center: Optional[List[float]] = None,
) -> Image:
output = self._F.affine_image_tensor(
self.as_subclass(torch.Tensor),
angle,
translate=translate,
scale=scale,
shear=shear,
interpolation=interpolation,
fill=fill,
center=center,
)
return Image.wrap_like(self, output)
def perspective(
self,
startpoints: Optional[List[List[int]]],
endpoints: Optional[List[List[int]]],
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: Optional[List[float]] = None,
coefficients: Optional[List[float]] = None,
) -> Image:
output = self._F.perspective_image_tensor(
self.as_subclass(torch.Tensor),
startpoints,
endpoints,
interpolation=interpolation,
fill=fill,
coefficients=coefficients,
)
return Image.wrap_like(self, output)
def elastic(
    self,
    displacement: torch.Tensor,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: Optional[List[float]] = None,
) -> Image:
    """Apply ``elastic_image_tensor`` with ``displacement`` and re-wrap the output as ``Image``."""
    # The kernel is handed a plain tensor view, not the Image subclass.
    plain = self.as_subclass(torch.Tensor)
    deformed = self._F.elastic_image_tensor(plain, displacement, interpolation=interpolation, fill=fill)
    return Image.wrap_like(self, deformed)
def rgb_to_grayscale(self, num_output_channels: int = 1) -> Image:
    """Run ``rgb_to_grayscale_image_tensor`` on the image and re-wrap the output as ``Image``."""
    plain = self.as_subclass(torch.Tensor)
    gray = self._F.rgb_to_grayscale_image_tensor(plain, num_output_channels=num_output_channels)
    return Image.wrap_like(self, gray)
def adjust_brightness(self, brightness_factor: float) -> Image:
    """Run ``adjust_brightness_image_tensor`` on the image and re-wrap the output as ``Image``."""
    plain = self.as_subclass(torch.Tensor)
    adjusted = self._F.adjust_brightness_image_tensor(plain, brightness_factor=brightness_factor)
    return Image.wrap_like(self, adjusted)
def adjust_saturation(self, saturation_factor: float) -> Image:
    """Run ``adjust_saturation_image_tensor`` on the image and re-wrap the output as ``Image``."""
    plain = self.as_subclass(torch.Tensor)
    adjusted = self._F.adjust_saturation_image_tensor(plain, saturation_factor=saturation_factor)
    return Image.wrap_like(self, adjusted)
def adjust_contrast(self, contrast_factor: float) -> Image:
    """Run ``adjust_contrast_image_tensor`` on the image and re-wrap the output as ``Image``."""
    plain = self.as_subclass(torch.Tensor)
    adjusted = self._F.adjust_contrast_image_tensor(plain, contrast_factor=contrast_factor)
    return Image.wrap_like(self, adjusted)
def adjust_sharpness(self, sharpness_factor: float) -> Image:
    """Run ``adjust_sharpness_image_tensor`` on the image and re-wrap the output as ``Image``."""
    plain = self.as_subclass(torch.Tensor)
    adjusted = self._F.adjust_sharpness_image_tensor(plain, sharpness_factor=sharpness_factor)
    return Image.wrap_like(self, adjusted)
def adjust_hue(self, hue_factor: float) -> Image:
    """Run ``adjust_hue_image_tensor`` on the image and re-wrap the output as ``Image``."""
    plain = self.as_subclass(torch.Tensor)
    adjusted = self._F.adjust_hue_image_tensor(plain, hue_factor=hue_factor)
    return Image.wrap_like(self, adjusted)
def adjust_gamma(self, gamma: float, gain: float = 1) -> Image:
    """Run ``adjust_gamma_image_tensor`` on the image and re-wrap the output as ``Image``."""
    plain = self.as_subclass(torch.Tensor)
    adjusted = self._F.adjust_gamma_image_tensor(plain, gamma=gamma, gain=gain)
    return Image.wrap_like(self, adjusted)
def posterize(self, bits: int) -> Image:
    """Run ``posterize_image_tensor`` on the image and re-wrap the output as ``Image``."""
    plain = self.as_subclass(torch.Tensor)
    posterized = self._F.posterize_image_tensor(plain, bits=bits)
    return Image.wrap_like(self, posterized)
def solarize(self, threshold: float) -> Image:
    """Run ``solarize_image_tensor`` on the image and re-wrap the output as ``Image``."""
    plain = self.as_subclass(torch.Tensor)
    solarized = self._F.solarize_image_tensor(plain, threshold=threshold)
    return Image.wrap_like(self, solarized)
def autocontrast(self) -> Image:
    """Run ``autocontrast_image_tensor`` on the image and re-wrap the output as ``Image``."""
    plain = self.as_subclass(torch.Tensor)
    return Image.wrap_like(self, self._F.autocontrast_image_tensor(plain))
def equalize(self) -> Image:
    """Run ``equalize_image_tensor`` on the image and re-wrap the output as ``Image``."""
    plain = self.as_subclass(torch.Tensor)
    return Image.wrap_like(self, self._F.equalize_image_tensor(plain))
def invert(self) -> Image:
    """Run ``invert_image_tensor`` on the image and re-wrap the output as ``Image``."""
    plain = self.as_subclass(torch.Tensor)
    return Image.wrap_like(self, self._F.invert_image_tensor(plain))
def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Image:
    """Run ``gaussian_blur_image_tensor`` on the image and re-wrap the output as ``Image``."""
    plain = self.as_subclass(torch.Tensor)
    blurred = self._F.gaussian_blur_image_tensor(plain, kernel_size=kernel_size, sigma=sigma)
    return Image.wrap_like(self, blurred)
def normalize(self, mean: List[float], std: List[float], inplace: bool = False) -> Image:
    """Run ``normalize_image_tensor`` on the image and re-wrap the output as ``Image``.

    ``mean``, ``std`` and ``inplace`` are forwarded to the kernel by keyword.
    """
    plain = self.as_subclass(torch.Tensor)
    normalized = self._F.normalize_image_tensor(plain, mean=mean, std=std, inplace=inplace)
    return Image.wrap_like(self, normalized)
class BoundingBoxFormat(Enum):
    """Closed set of bounding-box coordinate layouts.

    Each member's value equals its name, so a member can be looked up from an
    upper-cased string via ``BoundingBoxFormat[name]``.
    """

    # NOTE(review): member meanings below follow the usual convention — confirm against the kernels.
    XYXY = "XYXY"  # presumably (x1, y1, x2, y2) corner coordinates
    XYWH = "XYWH"  # presumably (x, y, width, height)
    CXCYWH = "CXCYWH"  # presumably (center_x, center_y, width, height)
class BoundingBox(Datapoint):
    """Tensor subclass for bounding boxes that carries ``format`` and ``spatial_size``.

    Every geometric method forwards to the matching ``*_bounding_box`` kernel on
    ``self._F`` using a plain tensor view of ``self`` and re-wraps the result.
    Image-only parameters (``interpolation``, ``antialias``, ``fill``) are accepted
    for signature parity with the other datapoints but are not forwarded.
    """

    format: BoundingBoxFormat
    spatial_size: Tuple[int, int]

    @classmethod
    def _wrap(
        cls,
        tensor: torch.Tensor,
        *,
        format: BoundingBoxFormat,
        spatial_size: Tuple[int, int],
    ) -> BoundingBox:
        """Reinterpret ``tensor`` as ``cls`` and attach both metadata attributes."""
        wrapped = tensor.as_subclass(cls)
        wrapped.format = format
        wrapped.spatial_size = spatial_size
        return wrapped

    def __new__(
        cls,
        data: Any,
        *,
        format: Union[BoundingBoxFormat, str],
        spatial_size: Tuple[int, int],
        dtype: Optional[torch.dtype] = None,
        device: Optional[Union[torch.device, str, int]] = None,
        requires_grad: Optional[bool] = None,
    ) -> BoundingBox:
        """Build a bounding box from ``data``; a string ``format`` is resolved case-insensitively."""
        tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
        resolved = BoundingBoxFormat[format.upper()] if isinstance(format, str) else format
        return cls._wrap(tensor, format=resolved, spatial_size=spatial_size)

    @classmethod
    def wrap_like(
        cls,
        other: BoundingBox,
        tensor: torch.Tensor,
        *,
        format: Optional[BoundingBoxFormat] = None,
        spatial_size: Optional[Tuple[int, int]] = None,
    ) -> BoundingBox:
        """Wrap a :class:`torch.Tensor` as :class:`BoundingBox` from a reference.

        Args:
            other (BoundingBox): Reference bounding box.
            tensor (Tensor): Tensor to be wrapped as :class:`BoundingBox`
            format (BoundingBoxFormat, str, optional): Format of the bounding box. If omitted, it is taken from the
                reference.
            spatial_size (two-tuple of ints, optional): Height and width of the corresponding image or video. If
                omitted, it is taken from the reference.
        """
        if format is None:
            resolved_format = other.format
        elif isinstance(format, str):
            resolved_format = BoundingBoxFormat[format.upper()]
        else:
            resolved_format = format
        resolved_size = other.spatial_size if spatial_size is None else spatial_size
        return cls._wrap(tensor, format=resolved_format, spatial_size=resolved_size)

    def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
        """Return the ``_make_repr`` representation including ``format`` and ``spatial_size``."""
        return self._make_repr(format=self.format, spatial_size=self.spatial_size)

    def horizontal_flip(self) -> BoundingBox:
        """Flip the boxes horizontally; metadata is taken over from ``self``."""
        plain = self.as_subclass(torch.Tensor)
        flipped = self._F.horizontal_flip_bounding_box(
            plain, format=self.format, spatial_size=self.spatial_size
        )
        return BoundingBox.wrap_like(self, flipped)

    def vertical_flip(self) -> BoundingBox:
        """Flip the boxes vertically; metadata is taken over from ``self``."""
        plain = self.as_subclass(torch.Tensor)
        flipped = self._F.vertical_flip_bounding_box(
            plain, format=self.format, spatial_size=self.spatial_size
        )
        return BoundingBox.wrap_like(self, flipped)

    def resize(  # type: ignore[override]
        self,
        size: List[int],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        max_size: Optional[int] = None,
        antialias: Optional[Union[str, bool]] = "warn",
    ) -> BoundingBox:
        """Resize the boxes; the kernel also returns the new spatial size, which is stored."""
        plain = self.as_subclass(torch.Tensor)
        boxes, new_size = self._F.resize_bounding_box(
            plain, spatial_size=self.spatial_size, size=size, max_size=max_size
        )
        return BoundingBox.wrap_like(self, boxes, spatial_size=new_size)

    def crop(self, top: int, left: int, height: int, width: int) -> BoundingBox:
        """Crop the boxes; the kernel also returns the new spatial size, which is stored."""
        plain = self.as_subclass(torch.Tensor)
        boxes, new_size = self._F.crop_bounding_box(
            plain, self.format, top=top, left=left, height=height, width=width
        )
        return BoundingBox.wrap_like(self, boxes, spatial_size=new_size)

    def center_crop(self, output_size: List[int]) -> BoundingBox:
        """Center-crop the boxes; the kernel also returns the new spatial size, which is stored."""
        plain = self.as_subclass(torch.Tensor)
        boxes, new_size = self._F.center_crop_bounding_box(
            plain, format=self.format, spatial_size=self.spatial_size, output_size=output_size
        )
        return BoundingBox.wrap_like(self, boxes, spatial_size=new_size)

    def resized_crop(
        self,
        top: int,
        left: int,
        height: int,
        width: int,
        size: List[int],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        antialias: Optional[Union[str, bool]] = "warn",
    ) -> BoundingBox:
        """Crop then resize the boxes; the kernel's returned spatial size is stored."""
        plain = self.as_subclass(torch.Tensor)
        boxes, new_size = self._F.resized_crop_bounding_box(
            plain, self.format, top, left, height, width, size=size
        )
        return BoundingBox.wrap_like(self, boxes, spatial_size=new_size)

    def pad(
        self,
        padding: Union[int, Sequence[int]],
        fill: Optional[Union[int, float, List[float]]] = None,
        padding_mode: str = "constant",
    ) -> BoundingBox:
        """Pad the boxes' canvas; the kernel's returned spatial size is stored."""
        plain = self.as_subclass(torch.Tensor)
        boxes, new_size = self._F.pad_bounding_box(
            plain,
            format=self.format,
            spatial_size=self.spatial_size,
            padding=padding,
            padding_mode=padding_mode,
        )
        return BoundingBox.wrap_like(self, boxes, spatial_size=new_size)

    def rotate(
        self,
        angle: float,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        expand: bool = False,
        center: Optional[List[float]] = None,
        fill: Optional[List[float]] = None,
    ) -> BoundingBox:
        """Rotate the boxes; the kernel's returned spatial size is stored."""
        plain = self.as_subclass(torch.Tensor)
        boxes, new_size = self._F.rotate_bounding_box(
            plain,
            format=self.format,
            spatial_size=self.spatial_size,
            angle=angle,
            expand=expand,
            center=center,
        )
        return BoundingBox.wrap_like(self, boxes, spatial_size=new_size)

    def affine(
        self,
        angle: Union[int, float],
        translate: List[float],
        scale: float,
        shear: List[float],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: Optional[List[float]] = None,
        center: Optional[List[float]] = None,
    ) -> BoundingBox:
        """Apply an affine transform to the boxes; metadata is taken over from ``self``."""
        plain = self.as_subclass(torch.Tensor)
        boxes = self._F.affine_bounding_box(
            plain,
            self.format,
            self.spatial_size,
            angle,
            translate=translate,
            scale=scale,
            shear=shear,
            center=center,
        )
        return BoundingBox.wrap_like(self, boxes)

    def perspective(
        self,
        startpoints: Optional[List[List[int]]],
        endpoints: Optional[List[List[int]]],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        fill: Optional[List[float]] = None,
        coefficients: Optional[List[float]] = None,
    ) -> BoundingBox:
        """Apply a perspective transform to the boxes; metadata is taken over from ``self``."""
        plain = self.as_subclass(torch.Tensor)
        boxes = self._F.perspective_bounding_box(
            plain,
            format=self.format,
            spatial_size=self.spatial_size,
            startpoints=startpoints,
            endpoints=endpoints,
            coefficients=coefficients,
        )
        return BoundingBox.wrap_like(self, boxes)

    def elastic(
        self,
        displacement: torch.Tensor,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        fill: Optional[List[float]] = None,
    ) -> BoundingBox:
        """Apply an elastic transform to the boxes; metadata is taken over from ``self``."""
        plain = self.as_subclass(torch.Tensor)
        boxes = self._F.elastic_bounding_box(
            plain, self.format, self.spatial_size, displacement=displacement
        )
        return BoundingBox.wrap_like(self, boxes)
class Mask(Datapoint):
    """[BETA] :class:`torch.Tensor` subclass for segmentation and detection masks.

    Args:
        data (tensor-like, PIL.Image.Image): Any data that can be turned into a tensor with :func:`torch.as_tensor` as
            well as PIL images.
        dtype (torch.dtype, optional): Desired data type of the mask. If omitted, will be inferred from
            ``data``.
        device (torch.device, optional): Desired device of the mask. If omitted and ``data`` is a
            :class:`torch.Tensor`, the device is taken from it. Otherwise, the mask is constructed on the CPU.
        requires_grad (bool, optional): Whether autograd should record operations on the mask. If omitted and
            ``data`` is a :class:`torch.Tensor`, the value is taken from it. Otherwise, defaults to ``False``.

    The geometric methods accept ``interpolation`` / ``antialias`` for signature parity
    with the other datapoints but do not forward them to the ``*_mask`` kernels.
    """

    @classmethod
    def _wrap(cls, tensor: torch.Tensor) -> Mask:
        """Reinterpret ``tensor`` as ``cls``."""
        return tensor.as_subclass(cls)

    def __new__(
        cls,
        data: Any,
        *,
        dtype: Optional[torch.dtype] = None,
        device: Optional[Union[torch.device, str, int]] = None,
        requires_grad: Optional[bool] = None,
    ) -> Mask:
        """Build a mask from ``data``; PIL images are first converted with ``pil_to_tensor``."""
        if isinstance(data, PIL.Image.Image):
            # Imported here rather than at module level, as in the original code.
            from transforms.v2 import functional as F

            data = F.pil_to_tensor(data)
        return cls._wrap(cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad))

    @classmethod
    def wrap_like(
        cls,
        other: Mask,
        tensor: torch.Tensor,
    ) -> Mask:
        """Wrap ``tensor`` as a mask; ``other`` is accepted but not read."""
        return cls._wrap(tensor)

    @property
    def spatial_size(self) -> Tuple[int, int]:
        """The last two entries of ``shape`` as a tuple."""
        return tuple(self.shape[-2:])  # type: ignore[return-value]

    def horizontal_flip(self) -> Mask:
        """Flip the mask horizontally via ``horizontal_flip_mask``."""
        plain = self.as_subclass(torch.Tensor)
        return Mask.wrap_like(self, self._F.horizontal_flip_mask(plain))

    def vertical_flip(self) -> Mask:
        """Flip the mask vertically via ``vertical_flip_mask``."""
        plain = self.as_subclass(torch.Tensor)
        return Mask.wrap_like(self, self._F.vertical_flip_mask(plain))

    def resize(  # type: ignore[override]
        self,
        size: List[int],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        max_size: Optional[int] = None,
        antialias: Optional[Union[str, bool]] = "warn",
    ) -> Mask:
        """Resize the mask via ``resize_mask``."""
        plain = self.as_subclass(torch.Tensor)
        resized = self._F.resize_mask(plain, size, max_size=max_size)
        return Mask.wrap_like(self, resized)

    def crop(self, top: int, left: int, height: int, width: int) -> Mask:
        """Crop the mask via ``crop_mask``."""
        plain = self.as_subclass(torch.Tensor)
        cropped = self._F.crop_mask(plain, top, left, height, width)
        return Mask.wrap_like(self, cropped)

    def center_crop(self, output_size: List[int]) -> Mask:
        """Center-crop the mask via ``center_crop_mask``."""
        plain = self.as_subclass(torch.Tensor)
        cropped = self._F.center_crop_mask(plain, output_size=output_size)
        return Mask.wrap_like(self, cropped)

    def resized_crop(
        self,
        top: int,
        left: int,
        height: int,
        width: int,
        size: List[int],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        antialias: Optional[Union[str, bool]] = "warn",
    ) -> Mask:
        """Crop then resize the mask via ``resized_crop_mask``."""
        plain = self.as_subclass(torch.Tensor)
        result = self._F.resized_crop_mask(plain, top, left, height, width, size=size)
        return Mask.wrap_like(self, result)

    def pad(
        self,
        padding: List[int],
        fill: Optional[Union[int, float, List[float]]] = None,
        padding_mode: str = "constant",
    ) -> Mask:
        """Pad the mask via ``pad_mask``."""
        plain = self.as_subclass(torch.Tensor)
        padded = self._F.pad_mask(plain, padding, padding_mode=padding_mode, fill=fill)
        return Mask.wrap_like(self, padded)

    def rotate(
        self,
        angle: float,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        expand: bool = False,
        center: Optional[List[float]] = None,
        fill: Optional[List[float]] = None,
    ) -> Mask:
        """Rotate the mask via ``rotate_mask``."""
        plain = self.as_subclass(torch.Tensor)
        rotated = self._F.rotate_mask(plain, angle, expand=expand, center=center, fill=fill)
        return Mask.wrap_like(self, rotated)

    def affine(
        self,
        angle: Union[int, float],
        translate: List[float],
        scale: float,
        shear: List[float],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: Optional[List[float]] = None,
        center: Optional[List[float]] = None,
    ) -> Mask:
        """Apply an affine transform to the mask via ``affine_mask``."""
        plain = self.as_subclass(torch.Tensor)
        transformed = self._F.affine_mask(
            plain, angle, translate=translate, scale=scale, shear=shear, fill=fill, center=center
        )
        return Mask.wrap_like(self, transformed)

    def perspective(
        self,
        startpoints: Optional[List[List[int]]],
        endpoints: Optional[List[List[int]]],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: Optional[List[float]] = None,
        coefficients: Optional[List[float]] = None,
    ) -> Mask:
        """Apply a perspective transform to the mask via ``perspective_mask``."""
        plain = self.as_subclass(torch.Tensor)
        warped = self._F.perspective_mask(
            plain, startpoints, endpoints, fill=fill, coefficients=coefficients
        )
        return Mask.wrap_like(self, warped)

    def elastic(
        self,
        displacement: torch.Tensor,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: Optional[List[float]] = None,
    ) -> Mask:
        """Apply an elastic transform to the mask via ``elastic_mask``."""
        plain = self.as_subclass(torch.Tensor)
        deformed = self._F.elastic_mask(plain, displacement, fill=fill)
        return Mask.wrap_like(self, deformed)
class Video(Datapoint):
    """[BETA] :class:`torch.Tensor` subclass for videos.

    Args:
        data (tensor-like): Any data that can be turned into a tensor with :func:`torch.as_tensor`.
        dtype (torch.dtype, optional): Desired data type of the video. If omitted, will be inferred from
            ``data``.
        device (torch.device, optional): Desired device of the video. If omitted and ``data`` is a
            :class:`torch.Tensor`, the device is taken from it. Otherwise, the video is constructed on the CPU.
        requires_grad (bool, optional): Whether autograd should record operations on the video. If omitted and
            ``data`` is a :class:`torch.Tensor`, the value is taken from it. Otherwise, defaults to ``False``.

    Raises:
        ValueError: If the converted tensor has fewer than 4 dimensions.
    """

    @classmethod
    def _wrap(cls, tensor: torch.Tensor) -> Video:
        """Reinterpret ``tensor`` as ``cls``."""
        return tensor.as_subclass(cls)

    def __new__(
        cls,
        data: Any,
        *,
        dtype: Optional[torch.dtype] = None,
        device: Optional[Union[torch.device, str, int]] = None,
        requires_grad: Optional[bool] = None,
    ) -> Video:
        tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
        # Validate the converted tensor, not ``data``: ``data`` may be a nested
        # list or another tensor-like object that has no ``ndim`` attribute.
        if tensor.ndim < 4:
            raise ValueError(
                f"Video data must have at least 4 dimensions, but got {tensor.ndim}"
            )
        return cls._wrap(tensor)

    @classmethod
    def wrap_like(cls, other: Video, tensor: torch.Tensor) -> Video:
        """Wrap ``tensor`` as a video; ``other`` is accepted but not read."""
        return cls._wrap(tensor)

    def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
        return self._make_repr()

    @property
    def spatial_size(self) -> Tuple[int, int]:
        """The last two entries of ``shape`` as a tuple."""
        return tuple(self.shape[-2:])  # type: ignore[return-value]

    @property
    def num_channels(self) -> int:
        """Size of dimension ``-3``."""
        return self.shape[-3]

    @property
    def num_frames(self) -> int:
        """Size of dimension ``-4``."""
        return self.shape[-4]
# Type aliases for annotations. Every alias has a ``*JIT`` twin that is a plain
# ``torch.Tensor`` — presumably for TorchScript, which cannot express these
# unions (TODO confirm against the call sites).

# Image input: plain tensor, PIL image, or the ``Image`` datapoint.
_ImageType = Union[torch.Tensor, PIL.Image.Image, Image]
_ImageTypeJIT = torch.Tensor
# Tensor-backed image input only (no PIL).
_TensorImageType = Union[torch.Tensor, Image]
_TensorImageTypeJIT = torch.Tensor
# Any transform input: plain tensor, PIL image, or any ``Datapoint``.
_InputType = Union[torch.Tensor, PIL.Image.Image, Datapoint]
_InputTypeJIT = torch.Tensor
# Video input: plain tensor or the ``Video`` datapoint.
_VideoType = Union[torch.Tensor, Video]
_VideoTypeJIT = torch.Tensor
# Same union as ``_VideoType``; kept as a separate name.
_TensorVideoType = Union[torch.Tensor, Video]
_TensorVideoTypeJIT = torch.Tensor
# Fill value accepted by padding/warping transforms: scalar, sequence, or None.
_FillType = Union[int, float, Sequence[int], Sequence[float], None]
_FillTypeJIT = Optional[List[float]]
import contextlib
import datetime
import io
import logging
import math
import os
import sys
import time
import torch
from terminaltables import AsciiTable
import util.utils as utils
from util.coco_eval import CocoEvaluator
from util.coco_utils import get_coco_api_from_dataset
from util.collate_fn import DataPrefetcher
def train_one_epoch_acc(
    model, optimizer, data_loader, epoch, print_freq=50, max_grad_norm=-1, accelerator=None
):
    """Train ``model`` for one epoch using an Accelerate-style ``accelerator``.

    Args:
        model: Called as ``model(images, targets)``; must return a dict of loss tensors.
        optimizer: Optimizer that is stepped once per iteration.
        data_loader: Sized loader; batches are pulled through a ``DataPrefetcher``.
        epoch: Epoch index. A linear LR warm-up is applied only when it is ``0``.
        print_freq: A log line and a tracker record are emitted every ``print_freq`` iterations.
        max_grad_norm: Gradients are clipped to this norm when it is positive.
        accelerator: Provides ``device``, ``accumulate``, ``backward``, ``reduce`` and
            ``log``. Despite the ``None`` default it is dereferenced unconditionally.

    Returns:
        The ``utils.MetricLogger`` holding the epoch's statistics after
        synchronisation between processes.

    Exits the process with status 1 if the reduced loss is not finite.
    """
    logger = logging.getLogger(os.path.basename(os.getcwd()) + "." + __name__)
    model.train()
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
    metric_logger.add_meter("data_time", utils.SmoothedValue(fmt="{avg:.4f}"))
    metric_logger.add_meter("iter_time", utils.SmoothedValue(fmt="{avg:.4f}"))

    # Warm the learning rate up linearly during the first epoch only.
    lr_scheduler = None
    if epoch == 0:
        warmup_factor = 1.0 / 1000
        warmup_iters = min(1000, len(data_loader) - 1)
        lr_scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer=optimizer, start_factor=warmup_factor, total_iters=warmup_iters
        )

    # Fetch the first batch up front; later batches are fetched right after the
    # forward pass of the previous one.
    prefetcher = DataPrefetcher(data_loader, accelerator.device)
    next_data_time = None
    data_start_time = time.perf_counter()
    images, targets = prefetcher.next()
    data_time = time.perf_counter() - data_start_time
    iter_start_time = time.perf_counter()

    for i in range(len(data_loader)):
        with accelerator.accumulate(model):
            # model forward
            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())

            # prefetch next batch data
            data_time = next_data_time if i > 0 else data_time
            if i < len(data_loader) - 1:
                data_start_time = time.perf_counter()
                images, targets = prefetcher.next()
                next_data_time = time.perf_counter() - data_start_time

            # backward propagation
            accelerator.backward(losses)
            if accelerator.sync_gradients and max_grad_norm > 0:
                accelerator.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            # Clear gradients only AFTER the step. Zeroing before ``backward``
            # would, on a gradient-sync step, throw away what earlier
            # micro-batches accumulated (assuming an Accelerate-prepared
            # optimizer, whose zero_grad is a no-op on non-sync steps).
            optimizer.zero_grad()

        if epoch == 0:
            lr_scheduler.step()

        # reduce losses over all GPUs for logging purposes
        with torch.no_grad():
            loss_dict_reduced = accelerator.reduce(loss_dict, reduction="mean")
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())

        loss_value = losses_reduced.item()
        if not math.isfinite(loss_value):
            logger.warning(f"Loss is {loss_value}, stopping training")
            logger.warning(loss_dict_reduced)
            sys.exit(1)

        # collect logging messages
        training_logs = {"loss": loss_value, **loss_dict_reduced}
        training_logs.update({"lr": optimizer.param_groups[0]["lr"]})
        metric_logger.update(**training_logs)

        # update iter_time and data_time
        iter_time = time.perf_counter() - iter_start_time
        iter_start_time = time.perf_counter()
        metric_logger.update(**{"iter_time": iter_time, "data_time": data_time})

        # logging track
        if i % print_freq == 0:
            logger.info(get_logging_string(metric_logger, data_loader, i, epoch))
            # Tracker keys use "loss/<name>" instead of "loss_<name>".
            training_logs = {k.replace("loss_", "loss/"): v for k, v in training_logs.items()}
            accelerator.log(training_logs, step=i + len(data_loader) * epoch)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    logger.info(f"Averaged stats: {metric_logger}")
    return metric_logger
@torch.no_grad()
def evaluate_acc(model, data_loader, epoch, accelerator=None):
    """Evaluate ``model`` on ``data_loader`` with the COCO bbox protocol.

    Logs the COCO summary and a per-category table (images, ground truths,
    detections, recall, AP), then records the twelve summary statistics on
    ``accelerator`` under ``val/<name>`` when an accelerator is given.

    Args:
        model: Called as ``model(images)``; must return one dict of tensors per image.
        data_loader: Loader whose ``dataset`` can be converted by ``get_coco_api_from_dataset``.
        epoch: Used as the ``step`` of the tracker record.
        accelerator: Optional object exposing ``log``; skipped when ``None``.

    Returns:
        The ``CocoEvaluator`` after accumulation and summarisation.
    """
    logger = logging.getLogger(os.path.basename(os.getcwd()) + "." + __name__)
    # evaluation uses single thread
    model.eval()
    metric_logger = utils.MetricLogger(delimiter=" ")
    header = "Test:"

    coco = get_coco_api_from_dataset(data_loader.dataset)
    coco_evaluator = CocoEvaluator(coco, ["bbox"])

    # The category list does not change during evaluation, so load it once.
    # Ids are taken from the category records instead of being looked up by name.
    # NOTE(review): pycocotools appears to substring-match when ``catNms`` is a
    # bare string (e.g. "dog" vs "hot dog"), which would pick the wrong id — confirm.
    categories = coco.loadCats(coco.getCatIds())

    # for collect detection numbers
    category_det_nums = [0] * (max(coco.getCatIds()) + 1)
    for images, targets in metric_logger.log_every(data_loader, 10, header):
        # get model predictions
        model_time = time.time()
        outputs = model(images)
        # non_blocking=True here causes incorrect performance
        outputs = [{k: v.to("cpu") for k, v in t.items()} for t in outputs]
        model_time = time.time() - model_time

        # perform evaluation through COCO API
        res = {target["image_id"]: output for target, output in zip(targets, outputs)}
        evaluator_time = time.time()
        coco_evaluator.update(res)
        evaluator_time = time.time() - evaluator_time
        metric_logger.update(model_time=model_time, evaluator_time=evaluator_time)

        # update detection number from the detections currently held by the evaluator
        detections = coco_evaluator.coco_eval["bbox"].cocoDt
        for category in categories:
            cat_id = category["id"]
            category_det_nums[cat_id] += len(detections.getAnnIds(catIds=[cat_id]))

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    logger.info(f"Averaged stats: {metric_logger}")
    coco_evaluator.synchronize_between_processes()

    # accumulate predictions from all images; the COCO API prints its summary,
    # so capture stdout and route it through the logger instead
    redirect_string = io.StringIO()
    with contextlib.redirect_stdout(redirect_string):
        coco_evaluator.accumulate()
        coco_evaluator.summarize()
    logger.info(redirect_string.getvalue())

    # print category-wise evaluation results
    table_data = [["class", "imgs", "gts", "dets", "recall", "ap"]]
    # table data for show, each line has the number of image, annotations, detections and metrics
    bbox_coco_eval = coco_evaluator.coco_eval["bbox"]
    for cat_idx, category in enumerate(categories):
        cat_id = [category["id"]]
        num_img_id = len(coco.getImgIds(catIds=cat_id))
        num_ann_id = len(coco.getAnnIds(catIds=cat_id))
        row_data = [category["name"], num_img_id, num_ann_id, category_det_nums[cat_id[0]]]
        row_data += [f"{bbox_coco_eval.eval['recall'][0, cat_idx, 0, 2].item():.3f}"]
        row_data += [f"{bbox_coco_eval.eval['precision'][0, :, cat_idx, 0, 2].mean().item():.3f}"]
        table_data.append(row_data)

    # get the final line of mean results; negative entries are excluded from the means
    cat_recall = coco_evaluator.coco_eval["bbox"].eval["recall"][0, :, 0, 2]
    valid_cat_recall = cat_recall[cat_recall >= 0]
    mean_recall = valid_cat_recall.sum() / max(len(valid_cat_recall), 1)
    cat_ap = coco_evaluator.coco_eval["bbox"].eval["precision"][0, :, :, 0, 2]
    valid_cat_ap = cat_ap[cat_ap >= 0]
    mean_ap50 = valid_cat_ap.sum() / max(len(valid_cat_ap), 1)
    mean_data = ["mean results", "", "", "", f"{mean_recall:.3f}", f"{mean_ap50:.3f}"]
    table_data.append(mean_data)

    # show results
    table = AsciiTable(table_data)
    table.inner_footing_row_border = True
    logger.info("\n" + table.table)

    metric_names = ["mAP", "AP@50", "AP@75", "AP-s", "AP-m", "AP-l"]
    metric_names += ["AR_1", "AR_10", "AR_100", "AR-s", "AR-m", "AR-l"]
    metric_dict = dict(zip(metric_names, coco_evaluator.coco_eval["bbox"].stats))
    # ``accelerator`` defaults to None; do not fail after a finished evaluation.
    if accelerator is not None:
        accelerator.log({f"val/{k}": v for k, v in metric_dict.items()}, step=epoch)
    return coco_evaluator
def get_logging_string(metric_logger, data_loader, i, epoch):
    """Build the progress line for iteration ``i`` of ``epoch``.

    The ETA is the global average of the ``iter_time`` meter multiplied by the
    number of iterations left; memory figures come from ``torch.cuda`` and are
    reported in MiB.
    """
    bytes_per_mb = 1024 * 1024
    total_iters = len(data_loader)
    remaining_seconds = metric_logger.meters["iter_time"].global_avg * (total_iters - i)
    eta = str(datetime.timedelta(seconds=int(remaining_seconds)))
    allocated_mb = torch.cuda.memory_allocated() / bytes_per_mb
    peak_mb = torch.cuda.max_memory_allocated() / bytes_per_mb
    return (
        f"Epoch: [{epoch}] [{i}/{total_iters}] eta: {eta} "
        f"{metric_logger!s} mem: {allocated_mb:.0f} max mem: {peak_mb:.0f}"
    )
import bisect
import copy
from collections import defaultdict
import numpy as np
import torch
import torch.utils.data
import torchvision
from PIL import Image
from torch.utils.data.sampler import BatchSampler, Sampler
from torch.utils.model_zoo import tqdm
class GroupedBatchSampler(BatchSampler):
    """
    Wraps another sampler to yield a mini-batch of indices.
    It enforces that the batch only contain elements from the same group.
    It also tries to provide mini-batches which follows an ordering which is
    as close as possible to the ordering from the original sampler.

    Incomplete batches left over at the end are topped up with samples already
    seen from the same group (repeated as often as needed), so the number of
    batches always equals ``len(self)``.

    Arguments:
        sampler (Sampler): Base sampler.
        group_ids (list[int]): If the sampler produces indices in range [0, N),
            `group_ids` must be a list of `N` ints which contains the group id of each sample.
            The group ids must be a continuous set of integers starting from
            0, i.e. they must be in the range [0, num_groups).
        batch_size (int): Size of mini-batch.
    """

    def __init__(self, sampler, group_ids, batch_size):
        if not isinstance(sampler, Sampler):
            raise ValueError(
                "sampler should be an instance of "
                "torch.utils.data.Sampler, but got sampler={}".format(sampler)
            )
        self.sampler = sampler
        self.group_ids = group_ids
        self.batch_size = batch_size

    def __iter__(self):
        buffer_per_group = defaultdict(list)
        samples_per_group = defaultdict(list)

        num_batches = 0
        for idx in self.sampler:
            group_id = self.group_ids[idx]
            buffer_per_group[group_id].append(idx)
            samples_per_group[group_id].append(idx)
            if len(buffer_per_group[group_id]) == self.batch_size:
                yield buffer_per_group[group_id]
                num_batches += 1
                del buffer_per_group[group_id]
            assert len(buffer_per_group[group_id]) < self.batch_size

        # now we have run out of elements that satisfy
        # the group criteria, let's return the remaining
        # elements so that the size of the sampler is
        # deterministic
        expected_num_batches = len(self)
        num_remaining = expected_num_batches - num_batches
        if num_remaining > 0:
            # for the remaining batches, take first the buffers with largest number
            # of elements
            for group_id, buffer in sorted(
                buffer_per_group.items(), key=lambda x: len(x[1]), reverse=True
            ):
                remaining = self.batch_size - len(buffer)
                # A group may have seen fewer samples than are needed to fill the
                # batch, so repeat its samples until at least ``remaining`` are available.
                pool = samples_per_group[group_id]
                repeated = pool * (remaining // len(pool) + 1)
                buffer.extend(repeated[:remaining])
                assert len(buffer) == self.batch_size
                yield buffer
                num_remaining -= 1
                if num_remaining == 0:
                    break
        assert num_remaining == 0

    def __len__(self):
        # Number of full batches; any tail is padded in ``__iter__`` rather than counted extra.
        return len(self.sampler) // self.batch_size
def _compute_aspect_ratios_slow(dataset, indices=None):
    """Compute width/height ratios by loading every selected sample through a ``DataLoader``.

    Each sample is expected to unpack into ``(img, _)`` where ``img.shape[-2:]``
    is ``(height, width)``. ``indices`` defaults to the whole dataset.
    """
    print(
        "Your dataset doesn't support the fast path for "
        "computing the aspect ratios, so will iterate over "
        "the full dataset and load every image instead. "
        "This might take some time..."
    )
    if indices is None:
        indices = range(len(dataset))

    class SubsetSampler(Sampler):
        """Sampler that yields exactly the given indices, in order."""

        def __init__(self, indices):
            self.indices = indices

        def __iter__(self):
            return iter(self.indices)

        def __len__(self):
            return len(self.indices)

    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        sampler=SubsetSampler(indices),
        num_workers=14,  # you might want to increase it for faster processing
        # batch_size is 1, so unwrap the single sample instead of collating
        collate_fn=lambda batch: batch[0],
    )

    ratios = []
    with tqdm(total=len(dataset)) as pbar:
        for img, _ in loader:
            pbar.update(1)
            height, width = img.shape[-2:]
            ratios.append(float(width) / float(height))
    return ratios
def _compute_aspect_ratios_custom_dataset(dataset, indices=None):
if indices is None:
indices = range(len(dataset))
aspect_ratios = []
for i in indices:
height, width = dataset.get_height_and_width(i)
aspect_ratio = float(width) / float(height)
aspect_ratios.append(aspect_ratio)
return aspect_ratios
def _compute_aspect_ratios_coco_dataset(dataset, indices=None):
if indices is None:
indices = range(len(dataset))
aspect_ratios = []
for i in indices:
img_info = dataset.coco.imgs[dataset.ids[i]]
aspect_ratio = float(img_info["width"]) / float(img_info["height"])
aspect_ratios.append(aspect_ratio)
return aspect_ratios
def _compute_aspect_ratios_voc_dataset(dataset, indices=None):
    """Compute width/height ratios by opening each file in ``dataset.images``.

    ``indices`` defaults to every index of ``dataset``.
    """
    if indices is None:
        indices = range(len(dataset))
    aspect_ratios = []
    for i in indices:
        # this doesn't load the data into memory, because PIL loads it lazily;
        # the context manager closes the file handle, which was previously leaked
        with Image.open(dataset.images[i]) as img:
            width, height = img.size
        aspect_ratios.append(float(width) / float(height))
    return aspect_ratios
def _compute_aspect_ratios_subset_dataset(dataset, indices=None):
    """Translate subset indices via ``dataset.indices`` and delegate to ``compute_aspect_ratios``.

    The delegation targets the wrapped ``dataset.dataset``; ``indices`` defaults
    to every index of the subset.
    """
    selected = range(len(dataset)) if indices is None else indices
    parent_indices = [dataset.indices[position] for position in selected]
    return compute_aspect_ratios(dataset.dataset, parent_indices)
def compute_aspect_ratios(dataset, indices=None):
    """Pick the cheapest available way to compute aspect ratios for ``dataset``.

    Checks are tried in order: a ``get_height_and_width`` method, then
    ``CocoDetection``, ``VOCDetection`` and ``Subset`` instances; anything else
    falls back to loading every sample.
    """
    if hasattr(dataset, "get_height_and_width"):
        handler = _compute_aspect_ratios_custom_dataset
    elif isinstance(dataset, torchvision.datasets.CocoDetection):
        handler = _compute_aspect_ratios_coco_dataset
    elif isinstance(dataset, torchvision.datasets.VOCDetection):
        handler = _compute_aspect_ratios_voc_dataset
    elif isinstance(dataset, torch.utils.data.Subset):
        handler = _compute_aspect_ratios_subset_dataset
    else:
        # slow path
        handler = _compute_aspect_ratios_slow
    return handler(dataset, indices)
def _quantize(x, bins):
bins = copy.deepcopy(bins)
bins = sorted(bins)
quantized = list(map(lambda y: bisect.bisect_right(bins, y), x))
return quantized
def create_aspect_ratio_groups(dataset, k=0):
    """Assign every sample of ``dataset`` to an aspect-ratio bin and return the bin ids.

    For ``k > 0`` the bin edges are ``2 ** linspace(-1, 1, 2 * k + 1)``;
    otherwise the single edge ``1.0`` is used. The edges and the number of
    samples per bin are printed.
    """
    ratios = compute_aspect_ratios(dataset)
    if k > 0:
        bins = (2 ** np.linspace(-1, 1, 2 * k + 1)).tolist()
    else:
        bins = [1.0]
    groups = _quantize(ratios, bins)
    # count number of elements per group
    _, counts = np.unique(groups, return_counts=True)
    fbins = [0] + bins + [np.inf]
    print("Using {} as bins for aspect ratio quantization".format(fbins))
    print("Count of instances per bin: {}".format(counts))
    return groups
import ast
import dataclasses
import inspect
import logging
import pydoc
from collections import abc
from typing import Any, List
from omegaconf import DictConfig
try:
from ast import unparse
except ImportError:
from astunparse import unparse
class Config:
    """Load a Python config file by executing it and exposing its globals as attributes.

    Args:
        file_path: Path of the Python file to execute.
        name_space: Optional dict used as the globals for ``exec``; it is mutated
            in place. A fresh dict is created when omitted.
        partials: Names of assignment targets whose call expression should be
            turned into a ``functools.partial`` instead of being called.

    SECURITY: the file is run with ``exec`` and therefore executes arbitrary
    code — only load trusted config files.
    """

    def __init__(self, file_path, name_space=None, partials=()):
        # A mutable default dict would be shared between instances, leaking the
        # variables of earlier configs into later ones.
        if name_space is None:
            name_space = {}
        self.partials = partials
        with open(file_path, "r") as f:
            code = f.read()
        if len(partials) != 0:
            code = self.partial_optim(code)
        exec(code, name_space)
        # Replaces all instance attributes (including ``partials``) with the
        # executed namespace, minus the injected ``__builtins__``.
        self.__dict__ = {k: v for k, v in name_space.items() if k != "__builtins__"}

    def partial_optim(self, code):
        """Rewrite ``target = f(args)`` into ``target = partial(f, args)`` for selected targets.

        A target is selected when its unparsed form, or that form with single
        quotes replaced by double quotes, is contained in ``self.partials``.

        Returns:
            The rewritten source, prefixed with the ``functools.partial`` import.
        """
        tree = ast.parse(code)
        for node in ast.walk(tree):
            if not isinstance(node, ast.Assign) or not isinstance(node.value, ast.Call):
                continue
            assign_target = unparse(node.targets[0]).rstrip("\n")
            variant = assign_target.replace("'", '"')
            if assign_target in self.partials or variant in self.partials:
                node.value = ast.Call(
                    func=ast.Name(id="partial", ctx=ast.Load()),
                    args=[node.value.func] + node.value.args,
                    keywords=[] + node.value.keywords,
                )
        return "from functools import partial\n" + unparse(tree)
class LazyConfig:
    """Load a Python config file, optionally turning selected calls into ``LazyCall`` objects.

    Args:
        file_path: Path of the Python file to execute.
        name_space: Optional dict used as the globals for ``exec``; it is mutated
            in place. A fresh dict is created when omitted.
        lazy: Container of assignment-target names whose call expression should
            be wrapped as ``L(func)(args)``. Empty when omitted.

    SECURITY: the file is run with ``exec`` and therefore executes arbitrary
    code — only load trusted config files.
    """

    def __init__(self, file_path, name_space=None, lazy=None):
        # Mutable default arguments would be shared between instances, leaking
        # the variables of earlier configs into later ones.
        if name_space is None:
            name_space = {}
        self.lazy = {} if lazy is None else lazy
        with open(file_path, "r") as f:
            code = f.read()
        if len(self.lazy) != 0:
            code = self.replace_call_with_lazy_call(code)
        exec(code, name_space)
        # Replaces all instance attributes (including ``lazy``) with the
        # executed namespace, minus the injected ``__builtins__``.
        self.__dict__ = {k: v for k, v in name_space.items() if k != "__builtins__"}

    def replace_call_with_lazy_call(self, code):
        """Rewrite ``target = f(args)`` into ``target = L(f)(args)`` for selected targets.

        A target is selected when its unparsed form, or that form with single
        quotes replaced by double quotes, is contained in ``self.lazy``.

        Returns:
            The rewritten source, prefixed with the import that binds ``L``.
        """
        tree = ast.parse(code)
        for node in ast.walk(tree):
            if not isinstance(node, ast.Assign) or not isinstance(node.value, ast.Call):
                continue
            assign_target = unparse(node.targets[0]).rstrip("\n")
            variant = assign_target.replace("'", '"')
            if assign_target in self.lazy or variant in self.lazy:
                node.value = ast.Call(
                    func=ast.Call(
                        func=ast.Name(id="L", ctx=ast.Load()),
                        args=[node.value.func],
                        keywords=[],
                    ),
                    args=node.value.args,
                    keywords=node.value.keywords,
                )
        return "from util.lazy_load import LazyCall as L\n" + unparse(tree)
def is_dataclass(obj):
    """Return True when ``obj`` is a dataclass type or an instance of one.

    Objects of the same type as ``List[int]`` are treated as instances, never
    as classes, before the ``__dataclass_fields__`` lookup.
    """
    generic_alias_type = type(List[int])
    if isinstance(obj, type) and not isinstance(obj, generic_alias_type):
        candidate = obj
    else:
        candidate = type(obj)
    return hasattr(candidate, "__dataclass_fields__")
def locate(name: str) -> Any:
    """
    Resolve a dotted path such as "module.submodule.class_name" to the object it names.

    ``pydoc.locate`` is tried first. If it returns ``None``, hydra's private
    ``_locate`` is used instead, which raises when the lookup fails.

    Raises:
        ImportError: if the fallback is needed but hydra cannot be imported.
    """
    found = pydoc.locate(name)
    if found is not None:
        return found

    # Some cases (e.g. torch.optim.sgd.SGD) not handled correctly
    # by pydoc.locate. Try a private function from hydra.
    try:
        # from hydra.utils import get_method - will print many errors
        from hydra.utils import _locate
    except ImportError as e:
        raise ImportError(f"Cannot dynamically locate object {name}!") from e
    return _locate(name)  # it raises if fails
def _convert_target_to_string(t: Any) -> str:
    """
    Inverse of ``locate()``: produce a dotted path that resolves back to ``t``.

    Args:
        t: any object with ``__module__`` and ``__qualname__``

    Returns:
        The shortest ``<module prefix>.<qualname>`` that ``locate`` resolves to
        ``t`` itself, or ``{module}.{qualname}`` if no shorter prefix works.
    """
    module, qualname = t.__module__, t.__qualname__

    # Try ever longer proper prefixes of the module path, e.g.
    # ``module.submodule._impl.class`` may become ``module.submodule.class``.
    # A shorter path is simpler and survives moving the implementation.
    parts = module.split(".")
    for depth in range(1, len(parts)):
        shortened = ".".join(parts[:depth]) + "." + qualname
        try:
            resolved = locate(shortened)
        except ImportError:
            continue
        if resolved is t:
            return shortened
    return f"{module}.{qualname}"
class LazyCall:
    """
    Wrap a callable so that when it's called, the call will not be executed,
    but returns a dict that describes the call.

    Keyword arguments are stored under their own names. Positional arguments
    are matched against the target's signature and stored under the matching
    parameter names; anything captured by a ``*args`` parameter is stored under
    the ``_variable_args_`` key. Positional arguments require a target whose
    signature can be inspected, so they cannot be used with a string or mapping
    target.

    Examples:
    ::
        from detectron2.config import instantiate, LazyCall
        layer_cfg = LazyCall(nn.Conv2d)(in_channels=32, out_channels=32)
        layer_cfg.out_channels = 64 # can edit it afterwards
        layer = instantiate(layer_cfg)
    """

    def __init__(self, target):
        """
        Args:
            target: a callable, or a string / mapping that defines one.

        Raises:
            TypeError: if ``target`` is none of the accepted kinds.
        """
        if not (callable(target) or isinstance(target, (str, abc.Mapping))):
            raise TypeError(f"target of LazyCall must be a callable or defines a callable! Got {target}")
        self._target = target

    def __call__(self, *args, **kwargs):
        """Return a ``DictConfig`` describing the call instead of performing it."""
        if is_dataclass(self._target):
            # omegaconf object cannot hold dataclass type
            # https://github.com/omry/omegaconf/issues/784
            target = _convert_target_to_string(self._target)
        else:
            target = self._target
        variable_args, arg_kwargs = self.transfer_args_into_kwargs(args)
        kwargs.update(arg_kwargs)
        kwargs["_target_"] = target
        kwargs["_variable_args_"] = variable_args
        return DictConfig(content=kwargs, flags={"allow_objects": True})

    def transfer_args_into_kwargs(self, args):
        """Map positional ``args`` onto the target's parameter names.

        Args:
            args: the positional arguments given to ``__call__``.

        Returns:
            ``(variable_args, kwargs)``: ``kwargs`` maps parameter names to the
            positional values that fill them; ``variable_args`` holds the
            values swallowed by a ``*args`` parameter, or ``None`` if there
            were none.
        """
        # Nothing to map: skip signature inspection entirely. This keeps
        # string/mapping targets (accepted by __init__ but not inspectable)
        # and callables without an introspectable signature usable.
        if not args:
            return None, {}

        kwargs = {}
        variable_args = None
        params = inspect.signature(self._target).parameters
        for arg_ind, (name, param) in enumerate(params.items()):
            if arg_ind >= len(args):
                break
            if param.kind == inspect.Parameter.VAR_POSITIONAL:
                # everything from here on belongs to *args
                variable_args = args[arg_ind:]
                break
            kwargs[name] = args[arg_ind]
        return variable_args, kwargs
def instantiate(cfg):
    """
    Recursively instantiate objects defined in dictionaries by
    "_target_" and arguments.

    Args:
        cfg: a dict-like object with "_target_" that defines the caller, and
            other keys that define the arguments. An optional
            "_variable_args_" entry holds the values to pass as ``*args``.

    Returns:
        object instantiated by cfg. Lists are instantiated element-wise, and
        anything that is not recognised is returned unchanged.

    Raises:
        TypeError: re-raised (after logging the target name) when calling the
            target fails with a TypeError.
    """
    from omegaconf import DictConfig, ListConfig, OmegaConf

    if isinstance(cfg, ListConfig):
        lst = [instantiate(x) for x in cfg]
        return ListConfig(lst, flags={"allow_objects": True})
    if isinstance(cfg, list):
        # Specialize for list, because many classes take
        # list[objects] as arguments, such as ResNet, DatasetMapper
        return [instantiate(x) for x in cfg]

    # If input is a DictConfig backed by dataclasses (i.e. omegaconf's structured config),
    # instantiate it to the actual dataclass.
    if isinstance(cfg, DictConfig) and dataclasses.is_dataclass(cfg._metadata.object_type):
        return OmegaConf.to_object(cfg)

    if isinstance(cfg, abc.Mapping) and "_target_" in cfg:
        # conceptually equivalent to hydra.utils.instantiate(cfg) with _convert_=all,
        # but faster: https://github.com/facebookresearch/hydra/issues/1200
        cfg = {k: instantiate(v) for k, v in cfg.items()}
        cls = cfg.pop("_target_")
        # Configs that were not produced by LazyCall (hand-written dicts, YAML)
        # carry no "_variable_args_" entry; treat that as "no *args".
        variable_args = cfg.pop("_variable_args_", None)
        cls = instantiate(cls)

        if isinstance(cls, str):
            cls_name = cls
            cls = locate(cls_name)
            assert cls is not None, cls_name
        else:
            try:
                cls_name = cls.__module__ + "." + cls.__qualname__
            except Exception:
                # target could be anything, so the above could fail
                cls_name = str(cls)
        assert callable(cls), f"_target_ {cls} does not define a callable object"

        try:
            if variable_args is None:
                return cls(**cfg)

            # split args from kwargs and instantiate cls with normal sequence:
            # args, variable_args, kwargs
            params = inspect.signature(cls).parameters
            try:
                p_kind_list = [p.kind for p in params.values()]
                i = p_kind_list.index(inspect.Parameter.VAR_POSITIONAL)
            except ValueError:
                i = None
            # parameters declared before *args must be passed positionally,
            # otherwise the variable arguments would land in the wrong slots
            args = [cfg.pop(key) for key in list(params.keys())[:i]]
            args.extend(variable_args)
            return cls(*args, **cfg)
        except TypeError:
            import os

            logger = logging.getLogger(os.path.basename(os.getcwd()) + "." + __name__)
            logger.error(f"Error when instantiating {cls_name}!")
            raise
    return cfg  # return as-is if don't know what to do
import atexit
import functools
import logging
import os
import re
import sys
from accelerate.logging import get_logger
from fvcore.common.file_io import PathManager
from termcolor import colored
def create_logger(output_dir=None, dist_rank=0):
    """Configure and return the root logger.

    :param output_dir: directory that receives ``training.log`` (opened in
        append mode); no file handler is added when falsy, defaults to None
    :param dist_rank: process rank; only rank 0 gets a colored stdout handler,
        defaults to 0
    :return: the root logger, set to level INFO
    """
    date_format = "%Y-%m-%d %H:%M:%S"
    plain_layout = "[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s"
    colored_layout = (
        colored("[%(asctime)s %(name)s]", "green")
        + colored("(%(filename)s %(lineno)d)", "yellow")
        + ": %(levelname)s %(message)s"
    )

    root = logging.getLogger()
    root.setLevel(logging.INFO)
    root.propagate = False

    # console output for the master process only
    if dist_rank == 0:
        stdout_handler = logging.StreamHandler(sys.stdout)
        stdout_handler.setLevel(logging.DEBUG)
        stdout_handler.setFormatter(logging.Formatter(fmt=colored_layout, datefmt=date_format))
        root.addHandler(stdout_handler)

    # uncolored copy of the log on disk
    if output_dir:
        log_path = os.path.join(output_dir, "training.log")
        disk_handler = logging.FileHandler(log_path, mode="a")
        disk_handler.setFormatter(logging.Formatter(fmt=plain_layout, datefmt=date_format))
        root.addHandler(disk_handler)

    return root
class _ColorfulFormatter(logging.Formatter):
    """Formatter that abbreviates the root logger name and tags warnings/errors.

    Takes two extra keyword arguments on top of ``logging.Formatter``:
    ``root_name`` (required) and ``abbrev_name`` (optional, defaults to "").
    """

    def __init__(self, *args, **kwargs):
        self._root_name = kwargs.pop("root_name") + "."
        abbreviation = kwargs.pop("abbrev_name", "")
        # a non-empty abbreviation keeps the dot separator, an empty one drops it
        self._abbrev_name = abbreviation + "." if len(abbreviation) else abbreviation
        super().__init__(*args, **kwargs)

    def formatMessage(self, record):
        """Format ``record``; WARNING and ERROR/CRITICAL get a colored prefix."""
        record.name = record.name.replace(self._root_name, self._abbrev_name)
        text = super().formatMessage(record)
        level = record.levelno
        if level == logging.WARNING:
            tag = colored("WARNING", "red", attrs=["blink"])
        elif level in (logging.ERROR, logging.CRITICAL):
            tag = colored("ERROR", "red", attrs=["blink", "underline"])
        else:
            return text
        return tag + " " + text
class ColorFilter(logging.Filter):
    """Logging filter that strips ANSI color escape sequences from a record.

    The filter never drops a record; it only rewrites colored ones in place.
    """

    # matching colored patterns; compiled once instead of once per record
    _COLOR_PATTERN = re.compile(r'\x1b\[[0-9;]*m')

    def filter(self, record):
        """Remove color codes from ``record`` and always return True."""
        message = record.getMessage()
        if self._COLOR_PATTERN.search(message):
            record.msg = self._COLOR_PATTERN.sub('', message)
            # ``message`` already has the %-arguments merged in. Leaving
            # ``record.args`` set would make the next getMessage() call apply
            # them a second time and fail (or garble the text).
            record.args = None
        return True
@functools.lru_cache()  # so that calling setup_logger multiple times won't add many handlers
def setup_logger(
    output=None,
    distributed_rank=0,
    *,
    color=True,
    name="detection",
    abbrev_name=None,
    enable_propagation: bool = False,
    configure_stdout: bool = True,
):
    """Initialize the detection logger and set its verbosity level to "DEBUG"

    :param output: a file name or a directory to save log. If None, will not save log file.
        If ends with ".txt" or ".log", assumed to be a file name, otherwise "log.log" inside
        that directory is used, defaults to None
    :param distributed_rank: rank number id in distributed training. Ranks above 0 write to a
        separate file whose name gets a "_rank<N>" suffix before the extension, defaults to 0
    :param color: whether to show colored logging information, defaults to True
    :param name: the root module name of this logger, defaults to "detection"
    :param abbrev_name: an abbreviation of the module, to avoid long names in logs.
        Set to "" to not log the root module in logs. By default, the basename of the current
        working directory is replaced by "det" wherever it occurs in ``name``, defaults to None
    :param enable_propagation: whether to propagate logs to the parent logger, defaults to False
    :param configure_stdout: whether to configure logging to stdout, defaults to True
    :return: the logger adapter returned by ``accelerate.logging.get_logger``
    """
    logger_adapter = get_logger(name, "DEBUG")
    logger = logger_adapter.logger
    logger.propagate = enable_propagation

    if abbrev_name is None:
        abbrev_name = name.replace(os.path.basename(os.getcwd()), "det")

    plain_formatter = logging.Formatter(
        "[%(asctime)s %(name)s] %(levelname)s: %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
    )

    # stdout logging: master only
    if configure_stdout and distributed_rank == 0:
        ch = logging.StreamHandler(stream=sys.stdout)
        ch.setLevel(logging.DEBUG)
        if color:
            formatter = _ColorfulFormatter(
                colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
                datefmt="%Y-%m-%d %H:%M:%S",
                root_name=name,
                abbrev_name=str(abbrev_name),
            )
        else:
            formatter = plain_formatter
        ch.setFormatter(formatter)
        logger.addHandler(ch)

    # file logging: all workers
    if output is not None:
        if output.endswith((".txt", ".log")):
            filename = output
        else:
            filename = os.path.join(output, "log.log")
        if distributed_rank > 0:
            # Insert the rank before the extension only. Replacing every "."
            # would also rewrite dots in directory names such as "./output".
            stem, extension = os.path.splitext(filename)
            filename = "{}_rank{}{}".format(stem, distributed_rank, extension)
        log_dir = os.path.dirname(filename)
        if log_dir:  # a bare file name has no directory to create
            os.makedirs(log_dir, exist_ok=True)

        fh = logging.StreamHandler(_cached_log_stream(filename))
        # the file gets the same records as stdout, minus the color codes
        fh.addFilter(ColorFilter())
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(plain_formatter)
        logger.addHandler(fh)

    return logger_adapter
# Cache the opened file object, so that different calls to `setup_logger`
# with the same file name can safely write to the same file.
@functools.lru_cache(maxsize=None)
def _cached_log_stream(filename):
    """Open ``filename`` for appending and register it to be closed at exit."""
    # the buffer size depends on whether the target is local or remote storage
    buffer_size = _get_log_stream_buffer_size(filename)
    stream = PathManager.open(filename, "a", buffering=buffer_size)
    atexit.register(stream.close)
    return stream
def _get_log_stream_buffer_size(filename: str) -> int:
    """Pick the ``buffering`` argument used when opening a log file.

    Paths containing "://" are treated as remote and get a 1 MiB buffer to
    avoid many small writes; anything else is a local file and gets -1.
    """
    is_remote = "://" in filename
    return 1024 * 1024 if is_remote else -1
import copy
import functools
import logging
import math
import os
import random
from datetime import datetime
from functools import partial
from typing import List
import accelerate
import numpy as np
import torch
import torchvision
from accelerate.logging import get_logger
from fvcore.common.file_io import PathManager
from torch import Tensor
from torchvision.models.detection.image_list import ImageList
from util import utils
from util.collect_env import collect_env_info
from util.logger import setup_logger
def replace_prefix(string, prefix, replacement):
    """Swap a leading ``prefix`` of ``string`` for ``replacement``.

    :param string: the text to inspect
    :param prefix: the leading text to look for
    :param replacement: the text put in place of ``prefix``
    :return: the rewritten string, or ``string`` unchanged when it does not
        start with ``prefix``
    """
    if not string.startswith(prefix):
        return string
    remainder = string[len(prefix):]
    return replacement + remainder
def inverse_sigmoid(x, eps: float = 1e-3):
    """Compute the logit ``log(x / (1 - x))`` with clamping for stability.

    :param x: input tensor; values are first clamped to [0, 1]
    :param eps: lower bound applied to both ``x`` and ``1 - x`` before the
        division, defaults to 1e-3
    :return: tensor of logits
    """
    probability = torch.clamp(x, min=0, max=1)
    numerator = torch.clamp(probability, min=eps)
    denominator = torch.clamp(1 - probability, min=eps)
    return torch.log(numerator / denominator)
def get_total_grad_norm(parameters, norm_type=2):
    """Compute the norm of all gradients, as if they were one flat vector.

    :param parameters: iterable of tensors; entries whose ``grad`` is None are ignored
    :param norm_type: order of the norm, defaults to 2
    :return: scalar tensor with the total norm, placed on the device of the
        first parameter that has a gradient; a zero tensor when no parameter
        has a gradient
    """
    with_grad = [p for p in parameters if p.grad is not None]
    if not with_grad:
        # nothing to measure, e.g. before the first backward pass
        return torch.tensor(0.0)

    norm_type = float(norm_type)
    device = with_grad[0].device
    # norm of the per-parameter norms equals the norm of the concatenation
    per_parameter = [torch.norm(p.grad.detach(), norm_type).to(device) for p in with_grad]
    return torch.norm(torch.stack(per_parameter), norm_type)
# _onnx_batch_images() is an implementation of
# batch_images() that is supported by ONNX tracing.
@torch.jit.unused
def _onnx_batch_images(images: List[Tensor], size_divisible: int = 32) -> Tensor:
    """Pad every image to a common size and stack them into one batch tensor.

    The target size is the per-dimension maximum over ``images``; dimensions 1
    and 2 are then rounded up to a multiple of ``size_divisible``. Padding is
    added at the end of each of the three dimensions.

    NOTE(review): ``torch.stack`` over ``img.shape[i]`` needs tensor-valued
    shape entries, which presumably only holds while tracing -- confirm before
    calling this in eager mode.

    Args:
        images: tensors to batch; the indexing below assumes three dimensions
            each (presumably C, H, W -- verify against callers)
        size_divisible: multiple that dimensions 1 and 2 are rounded up to

    Returns:
        the stacked, padded images
    """
    # per-dimension maximum over all images
    max_size = []
    for i in range(images[0].dim()):
        max_size_i = torch.max(torch.stack([img.shape[i]
                                            for img in images]).to(torch.float32)).to(torch.int64)
        max_size.append(max_size_i)
    stride = size_divisible
    # round dimensions 1 and 2 up to the next multiple of the stride
    max_size[1] = (torch.ceil((max_size[1].to(torch.float32)) / stride) * stride).to(torch.int64)
    max_size[2] = (torch.ceil((max_size[2].to(torch.float32)) / stride) * stride).to(torch.int64)
    max_size = tuple(max_size)
    # work around for
    # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
    # which is not yet supported in onnx
    padded_imgs = []
    for img in images:
        padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
        # F.pad takes (before, after) pairs starting from the last dimension
        padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
        padded_imgs.append(padded_img)
    return torch.stack(padded_imgs)
def image_list_from_tensors(images: List[Tensor], size_divisible=32, fill_value=0):
    """Batch images of different sizes into one padded ``ImageList``.

    :param images: image tensors that all share the same size in dimension 0
    :param size_divisible: the batched height and width are rounded up to a
        multiple of this value, defaults to 32
    :param fill_value: value written into the padded area, defaults to 0
    :return: ``ImageList`` holding the batched tensor and each image's original
        last-two-dimension shape
    """
    # every image must agree on the channel count
    channels = images[0].shape[0]
    assert all(
        channels == image.shape[0] for image in images
    ), f"all images must have the same channel but got {list(map(lambda x: x.shape[0], images))}"

    original_shapes = [image.shape[-2:] for image in images]

    if torchvision._is_tracing():
        # batch_images() does not export well to ONNX
        # call _onnx_batch_images() instead
        traced_batch = _onnx_batch_images(images, size_divisible)
        return ImageList(traced_batch, original_shapes)

    # largest height/width, rounded up to a multiple of size_divisible
    heights, widths = list(zip(*original_shapes))
    padded_h = int(math.ceil(float(max(heights)) / size_divisible) * size_divisible)
    padded_w = int(math.ceil(float(max(widths)) / size_divisible) * size_divisible)

    # canvas pre-filled with fill_value, each image copied into its top-left corner
    canvas = images[0].new_full((len(images), channels, padded_h, padded_w), fill_value)
    for index, image in enumerate(images):
        canvas[index, :, :image.shape[1], :image.shape[2]].copy_(image)

    return ImageList(canvas, original_shapes)
def _highlight(code, filename):
    """Syntax-highlight ``code`` for a 256-color terminal when pygments is installed.

    :param code: the source text to highlight
    :param filename: names ending in ".py" use the Python 3 lexer, anything
        else the YAML lexer
    :return: the highlighted text, or ``code`` unchanged if pygments cannot be imported
    """
    try:
        import pygments
    except ImportError:
        return code

    from pygments.formatters import Terminal256Formatter
    from pygments.lexers import Python3Lexer, YamlLexer

    lexer_class = Python3Lexer if filename.endswith(".py") else YamlLexer
    return pygments.highlight(code, lexer_class(), Terminal256Formatter(style="monokai"))
def default_setup(args, cfg, accelerator):
    """Prepare the output directory, loggers and random seed for a run.

    :param args: command line arguments; ``config_file`` and ``seed`` are read
        when present
    :param cfg: config object; ``output_dir`` is read when present. Without it
        nothing is written to disk and logging goes to stdout only
    :param accelerator: accelerator object providing ``local_process_index``
        and ``is_main_process``
    """
    output_dir = getattr(cfg, "output_dir", None)
    rank = accelerator.local_process_index
    if accelerator.is_main_process and output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # capture warning.warns information into logging
    logging.captureWarnings(True)

    # without an output directory there is no log file; setup_logger accepts None
    train_log_file = os.path.join(output_dir, "training.log") if output_dir else None
    set_logger = partial(setup_logger, output=train_log_file, distributed_rank=rank)

    # setup loggers from warnings, accelerate, detection frameworks
    root_logger_name = os.path.basename(os.getcwd())
    for logger_name in ("py.warnings", "accelerate", root_logger_name):
        set_logger(name=logger_name)
    logger = get_logger(root_logger_name + "." + __name__)

    logger.info("Rank of current process: {}, World size: {}".format(rank, utils.get_world_size()))
    logger.info("Environment info: \n" + collect_env_info())
    logger.info("Command line arguments: " + str(args))
    if hasattr(args, "config_file") and args.config_file != "":
        # read through a context manager so the handle is closed right away
        with PathManager.open(args.config_file, "r") as config_stream:
            config_text = config_stream.read()
        logger.info(
            "Contents of args.config_file={}:\n{}".format(
                args.config_file,
                _highlight(config_text, args.config_file),
            )
        )

    # make sure each worker has a different, yet deterministic seed if specified
    if hasattr(args, "seed") and args.seed and args.seed > 0:
        seed = args.seed
    else:
        seed = (os.getpid() + int(datetime.now().strftime("%S%f")) + int.from_bytes(os.urandom(2), "big"))
    logger.info("Using the random seed: {}".format(seed))
    accelerate.utils.set_seed(seed, device_specific=True)
def seed_worker(worker_id):
    """Seed NumPy and ``random`` from torch's current initial seed.

    :param worker_id: unused; present so the function can serve as a
        worker-init callback
    """
    derived_seed = torch.initial_seed() % (2**32)  # reduced modulo 2**32 before seeding
    np.random.seed(derived_seed)
    random.seed(derived_seed)
def fixed_generator():
    """Create a fresh ``torch.Generator`` seeded with 0.

    :return: the seeded generator
    """
    generator = torch.Generator()
    # manual_seed hands back the generator it was called on
    return generator.manual_seed(0)
def to_device(inputs, device):
    """Recursively move every tensor inside ``inputs`` to ``device``.

    :param inputs: a tensor, or arbitrarily nested tuples/lists/dicts of them;
        dict keys are processed as well as values
    :param device: target device handed to ``Tensor.to``
    :return: a structure of the same container types with tensors moved
        (non-blocking); other objects are returned as they are
    """
    if isinstance(inputs, torch.Tensor):
        return inputs.to(device, non_blocking=True)
    if isinstance(inputs, dict):
        moved = {to_device(key, device): to_device(value, device) for key, value in inputs.items()}
        return type(inputs)(moved)
    if isinstance(inputs, (tuple, list)):
        # rebuild with the original container type
        return type(inputs)([to_device(item, device) for item in inputs])
    return inputs
def deepcopy(inputs):
    """Recursively copy ``inputs``, detaching every tensor it contains.

    :param inputs: a tensor, or arbitrarily nested tuples/lists/dicts; dict keys
        are copied as well as values
    :return: a structure of the same container types. Tensors become detached
        clones, everything else goes through ``copy.deepcopy``
    """
    if isinstance(inputs, torch.Tensor):
        return inputs.clone().detach()
    if isinstance(inputs, dict):
        copied = {deepcopy(key): deepcopy(value) for key, value in inputs.items()}
        return type(inputs)(copied)
    if isinstance(inputs, (tuple, list)):
        # rebuild with the original container type
        return type(inputs)([deepcopy(item) for item in inputs])
    return copy.deepcopy(inputs)
@functools.lru_cache(maxsize=128)
def _encode_label_tuple(labels: tuple) -> tuple:
    """Cached worker for ``encode_labels``; takes and returns hashable tuples."""
    encoded = []
    for label in labels:
        encoded.extend(ord(character) for character in label)
        encoded.append(-1)
    return tuple(encoded)


def encode_labels(labels: List[str]):
    """Encode a list of string to a list of int, for example: ["l1", "Label2", "n"]
    will be encoded as: [108, 49, -1, 76, 97, 98, 101, 108, 50, -1, 110, -1].
    Each letter will be converted using ord() function in Python.

    :param labels: A list (or any other iterable) of str to be encoded.
    :return: A new list of int, in which -1 is used as delimiters to split strings.
    :raises TypeError: if any element is not a string.
    """
    # lru_cache needs hashable arguments, so the cache is keyed on a tuple;
    # decorating a function that receives a list fails on every call.
    labels = tuple(labels)
    if not all(isinstance(label, str) for label in labels):
        raise TypeError("All elements must be strings")
    # hand out a fresh list so callers cannot corrupt the cached value
    return list(_encode_label_tuple(labels))
@functools.lru_cache(maxsize=128)
def _decode_int_tuple(ints: tuple) -> tuple:
    """Cached worker for ``decode_labels``; takes and returns hashable tuples."""
    decoded = []
    characters = []
    for number in ints:
        if number != -1:
            characters.append(chr(number))
        else:
            decoded.append("".join(characters))
            characters = []
    return tuple(decoded)


def decode_labels(ints: List[int]):
    """Decode a list of int to a list of string, for example: [108, 49, -1, 76, 50, -1, 110, -1]
    will be decoded as: ["l1", "L2", "n"]. Each number will be converted to a letter using chr()
    function in Python, and -1 is used as delimiters to split strings.

    Characters after the last -1 delimiter are not part of any string and are dropped.

    :param ints: A list (or any other iterable) of int to be converted.
    :return: A new list of string.
    """
    # lru_cache needs hashable arguments, so the cache is keyed on a tuple of
    # plain ints; decorating a function that receives a list fails on every call.
    key = tuple(int(number) for number in ints)
    # hand out a fresh list so callers cannot corrupt the cached value
    return list(_decode_int_tuple(key))
import datetime
import logging
import os
import pickle
import time
import warnings
from collections import OrderedDict, defaultdict, deque
import torch
import torch.distributed as dist
from accelerate.logging import get_logger
from torch import nn
from terminaltables import AsciiTable
class SmoothedValue(object):
    """Keep a sliding window of recent values next to a running global total,
    and expose statistics over either of them.
    """

    def __init__(self, window_size=20, fmt=None):
        # default rendering: windowed median followed by the global average
        self.fmt = "{median:.4f} ({global_avg:.4f})" if fmt is None else fmt
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0

    def update(self, value, n=1):
        """Record ``value``; it counts ``n`` times towards the global average."""
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Sum ``count`` and ``total`` across processes.
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        packed = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
        dist.barrier()
        dist.all_reduce(packed)
        summed_count, summed_total = packed.tolist()
        self.count = int(summed_count)
        self.total = summed_total

    @property
    def median(self):
        """Median of the values currently in the window."""
        window = torch.tensor(list(self.deque))
        return window.median().item()

    @property
    def avg(self):
        """Mean of the values currently in the window."""
        window = torch.tensor(list(self.deque), dtype=torch.float32)
        return window.mean().item()

    @property
    def global_avg(self):
        """Weighted mean of every value recorded so far."""
        return self.total / self.count

    @property
    def max(self):
        """Largest value currently in the window."""
        return max(self.deque)

    @property
    def value(self):
        """Most recently recorded value."""
        return self.deque[-1]

    def __str__(self):
        statistics = {
            "median": self.median,
            "avg": self.avg,
            "global_avg": self.global_avg,
            "max": self.max,
            "value": self.value,
        }
        return self.fmt.format(**statistics)
def all_gather(data):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors)

    Args:
        data: any picklable object

    Returns:
        list[data]: list of data gathered from each rank
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]

    # serialize the object into a byte tensor on the GPU
    payload = pickle.dumps(data)
    byte_tensor = torch.ByteTensor(torch.ByteStorage.from_buffer(payload)).to("cuda")

    # exchange payload sizes so every rank knows how much each peer sends
    local_size = torch.tensor([byte_tensor.numel()], device="cuda")
    size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    # torch all_gather does not support gathering tensors of different
    # shapes, so every payload is padded up to the largest one
    tensor_list = [torch.empty((max_size,), dtype=torch.uint8, device="cuda") for _ in size_list]
    if local_size != max_size:
        padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
        byte_tensor = torch.cat((byte_tensor, padding), dim=0)
    dist.all_gather(tensor_list, byte_tensor)

    # cut the padding off again and deserialize each rank's payload
    data_list = []
    for size, received in zip(size_list, tensor_list):
        raw = received.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(raw))
    return data_list
def reduce_dict(input_dict, average=True):
    """
    Reduce the values in the dictionary from all processes so that all
    processes have the same result.

    Args:
        input_dict (dict): all the values will be reduced
        average (bool): whether to do average or sum

    Returns:
        a dict with the same fields as input_dict, after reduction. With fewer
        than two processes ``input_dict`` itself is returned.
    """
    world_size = get_world_size()
    if world_size < 2:
        return input_dict

    with torch.no_grad():
        # sort the keys so that they are consistent across processes
        names = sorted(input_dict.keys())
        values = torch.stack([input_dict[name] for name in names], dim=0)
        dist.all_reduce(values)
        if average:
            values /= world_size
        reduced_dict = dict(zip(names, values))
    return reduced_dict
class MetricLogger(object):
    """Collect named ``SmoothedValue`` meters and log them periodically while
    iterating over a dataset.
    """

    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Feed one new number into the meter named by each keyword."""
        for name, number in kwargs.items():
            if isinstance(number, torch.Tensor):
                number = number.item()
            assert isinstance(number, (float, int))
            self.meters[name].update(number)

    def __getattr__(self, attr):
        # only reached when normal lookup fails: expose meters as attributes
        own = self.__dict__
        if attr in own["meters"]:
            return own["meters"][attr]
        if attr in own:
            return own[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr))

    def __str__(self):
        rendered = ["{}: {}".format(name, str(meter)) for name, meter in self.meters.items()]
        return self.delimiter.join(rendered)

    def synchronize_between_processes(self):
        """Synchronize every meter across processes."""
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        """Register ``meter`` under ``name``, replacing any existing one."""
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Yield the items of ``iterable`` and log progress every ``print_freq``
        iterations as well as on the last one.

        Each log line carries the ETA, all meters, the iteration and data
        loading times and, when CUDA is available, the current and peak
        allocated memory in MB. The total time is printed once at the end.
        """
        logger = logging.getLogger(os.path.basename(os.getcwd()) + "." + __name__)
        header = header or ""
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt="{avg:.4f}")
        data_time = SmoothedValue(fmt="{avg:.4f}")

        # pad the iteration counter to the width of the total count
        space_fmt = ":" + str(len(str(len(iterable)))) + "d"
        template_parts = [
            header,
            "[{0" + space_fmt + "}/{1}]",
            "eta: {eta}",
            "{meters}",
            "time: {time}",
            "data: {data}",
        ]
        if torch.cuda.is_available():
            template_parts += ["mem: {memory:.0f}", "max mem: {max_memory:.0f}"]
        log_msg = self.delimiter.join(template_parts)
        MB = 1024.0 * 1024.0

        for i, obj in enumerate(iterable):
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                fields = {
                    "eta": str(datetime.timedelta(seconds=int(eta_seconds))),
                    "meters": str(self),
                    "time": str(iter_time),
                    "data": str(data_time),
                }
                if torch.cuda.is_available():
                    fields["memory"] = torch.cuda.memory_allocated() / MB
                    fields["max_memory"] = torch.cuda.max_memory_allocated() / MB
                logger.info(log_msg.format(i, len(iterable), **fields))
            end = time.time()

        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print("{} Total time: {} ({:.4f} s / it)".format(header, total_time_str, total_time / len(iterable)))
class HighestCheckpoint:
    """Save the model's weights whenever a tracked metric reaches a new best."""

    def __init__(self, accelerator, model):
        self.accelerate = accelerator
        self.model = model
        self.meters = {}  # metric name -> best value seen so far

    def update(self, **kwargs):
        """Compare each keyword metric with its best value and checkpoint on
        a new high (ties included) to ``best_<metric>.pth`` in the project dir.
        """
        logger = get_logger(os.path.basename(os.getcwd()) + "." + __name__)
        for metric, score in kwargs.items():
            if isinstance(score, torch.Tensor):
                score = score.item()
            assert isinstance(score, (float, int))

            # the first observation of a metric becomes its own baseline
            best_so_far = self.meters.setdefault(metric, score)
            if not score >= best_so_far:
                continue

            self.meters[metric] = score
            save_path = os.path.join(self.accelerate.project_dir, f"best_{metric}.pth")
            model_state_dict = self.accelerate.get_state_dict(self.model, unwrap=True)
            self.accelerate.save(model_state_dict, save_path)
            logger.info(f"the best {metric} model is saved to {save_path}")
class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
    """Exponential moving average of a model's parameters and buffers.

    Each update applies
    ``ema_avg = decay * avg_model_param + (1 - decay) * model_param``.
    The averaging itself is done by
    `torch.optim.swa_utils.AveragedModel <https://pytorch.org/docs/stable/optim.html#custom-averaging-strategies>`_.
    """

    def __init__(self, model, decay, device="cpu"):
        keep = decay
        blend = 1 - decay

        def ema_avg(avg_model_param, model_param, num_averaged):
            # num_averaged is part of the callback signature but not needed here
            return keep * avg_model_param + blend * model_param

        super().__init__(model, device=device, avg_fn=ema_avg, use_buffers=True)
def setup_for_distributed(is_master):
    """
    Replace the built-in ``print`` so that it only produces output in the
    master process, or when called with ``force=True``.
    """
    import builtins

    original_print = builtins.print

    def print(*args, **kwargs):
        # "force" is our own keyword and must not reach the real print
        forced = kwargs.pop("force", False)
        if not (is_master or forced):
            return
        original_print(*args, **kwargs)

    builtins.print = print
def is_dist_avail_and_initialized():
    """Return True only if ``torch.distributed`` is both available and initialized."""
    # short-circuits, so is_initialized() is only asked when the package is available
    return True if dist.is_available() and dist.is_initialized() else False
def get_world_size():
    """Return the number of distributed processes, or 1 outside distributed mode."""
    if is_dist_avail_and_initialized():
        return dist.get_world_size()
    return 1
def get_rank():
    """Return this process's distributed rank, or 0 outside distributed mode."""
    if is_dist_avail_and_initialized():
        return dist.get_rank()
    return 0
def is_main_process():
    """Return True when this process has rank 0."""
    rank = get_rank()
    return rank == 0
def save_on_master(*args, **kwargs):
    """Forward all arguments to ``torch.save``, but only on the main process."""
    if not is_main_process():
        return
    torch.save(*args, **kwargs)
def init_distributed_mode(args):
    """Initialize ``torch.distributed`` from environment variables.

    Reads RANK / WORLD_SIZE / LOCAL_RANK when the first two are set, otherwise
    SLURM_PROCID. With neither present, ``args.distributed`` is set to False
    and nothing else happens. The SLURM branch does not set
    ``args.world_size``, and ``args.dist_url`` is read in both branches, so
    those must already be present on ``args``.

    :param args: namespace that receives ``rank``, ``gpu``, ``distributed``,
        ``dist_backend`` (and ``world_size`` in the RANK/WORLD_SIZE branch)
    """
    env = os.environ
    if "RANK" in env and "WORLD_SIZE" in env:
        args.rank = int(env["RANK"])
        args.world_size = int(env["WORLD_SIZE"])
        args.gpu = int(env["LOCAL_RANK"])
    elif "SLURM_PROCID" in env:
        args.rank = int(env["SLURM_PROCID"])
        # spread ranks over the GPUs visible on this machine
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print("Not using distributed mode")
        args.distributed = False
        return

    args.distributed = True
    torch.cuda.set_device(args.gpu)
    args.dist_backend = "nccl"
    print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True)
    torch.distributed.init_process_group(
        backend=args.dist_backend,
        init_method=args.dist_url,
        world_size=args.world_size,
        rank=args.rank,
    )
    torch.distributed.barrier()
    # silence print() on every rank except 0
    setup_for_distributed(args.rank == 0)
def filter_mismatched_weights(model_state_dict, weight_state_dict):
    """Replace loaded weights whose shape differs from the model's.

    For every key present in both dicts with different shapes, the entry in
    ``weight_state_dict`` is overwritten in place with the model's own value.

    :param model_state_dict: state dict of the model
    :param weight_state_dict: state dict being loaded; modified in place
    :return: ``(weight_state_dict, mismatch_keys)`` where ``mismatch_keys`` maps
        each replaced key to ``[shape in model, shape in loaded weights]``
    """
    mismatch_keys = {}
    for key, model_value in list(model_state_dict.items()):
        if key not in weight_state_dict:
            continue
        loaded_value = weight_state_dict[key]
        if model_value.shape == loaded_value.shape:
            continue
        weight_state_dict[key] = model_value
        mismatch_keys[key] = [model_value.shape, loaded_value.shape]
    return weight_state_dict, mismatch_keys
def load_checkpoint(file_name, map_location="cpu"):
    """Load a checkpoint from a URL, a local path or an in-memory state dict.

    :param file_name: an "http://" / "https://" URL, a path to an existing local
        file, or an ``OrderedDict`` that is returned as it is
    :param map_location: forwarded to the torch loading function, defaults to "cpu"
    :return: the loaded object, or None for a string that is neither a URL nor
        an existing path (a warning is issued) and for any other input type
    """
    if isinstance(file_name, OrderedDict):
        return file_name
    if not isinstance(file_name, str):
        return None

    if file_name.startswith(("http://", "https://")):
        return torch.hub.load_state_dict_from_url(file_name, map_location=map_location)
    if os.path.exists(file_name):
        # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources
        return torch.load(file_name, map_location=map_location)

    warnings.warn("Given string, only url and local path of weight are supported, skip loading.")
    return None
def load_state_dict(model: nn.Module, state_dict: OrderedDict):
    """Load ``state_dict`` into ``model`` non-strictly and report every difference.

    Weights whose shape does not match the model are skipped (the model keeps
    its own value) and listed in a table; missing and unexpected keys are
    logged as warnings.

    :param model: the module that receives the weights
    :param state_dict: the weights to load; may be modified in place for
        mismatched keys. Loading is skipped with a warning when it is None
    """
    logger = logging.getLogger(os.path.basename(os.getcwd()) + "." + __name__)
    if state_dict is None:
        # Logger.warn is a deprecated alias that was removed in Python 3.13
        logger.warning("State dict is None, skip loading")
        return

    # load _classes_ for inference: register a placeholder buffer of the right
    # shape so that the stored value has somewhere to be loaded into
    if "_classes_" in state_dict:
        dummy_classes = torch.zeros_like(state_dict["_classes_"])
        model.register_buffer("_classes_", dummy_classes)

    # initialize keys list
    matched_state_dict, mismatch_keys = filter_mismatched_weights(model.state_dict(), state_dict)
    incompatible_keys = model.load_state_dict(matched_state_dict, strict=False)
    missing_keys = incompatible_keys.missing_keys
    unexpected_keys = incompatible_keys.unexpected_keys

    if len(mismatch_keys) == 0 and len(missing_keys) == 0 and len(unexpected_keys) == 0:
        logger.info(incompatible_keys)
        return

    logger.warning("The model and loaded state dict do not match exactly")
    if len(missing_keys) != 0:
        logger.warning(f"Missing keys: {', '.join(missing_keys)}\n")
    if len(unexpected_keys) != 0:
        logger.warning(f"Unexpected keys: {', '.join(unexpected_keys)}\n")
    if len(mismatch_keys) != 0:
        mismatch_tables = [["key name", "shape in model", "shape in state dict"]]

        def format_shape(shape):
            # render like a tuple, including the trailing comma for one dimension
            shape_str = "("
            shape_str += ", ".join([str(i) for i in shape])
            shape_str += ",)" if len(shape) == 1 else ")"
            return shape_str

        for key, (shape_model, shape_state_dict) in mismatch_keys.items():
            mismatch_tables.append([key, format_shape(shape_model), format_shape(shape_state_dict)])
        mismatch_tables = AsciiTable(mismatch_tables)
        mismatch_tables.inner_footing_row_border = True
        logger.warning(f"Size mismatch keys: {', '.join(mismatch_keys)}\n" + mismatch_tables.table + "\n")
import copy
import os
from functools import partial
from typing import List, Tuple, Union
import cv2
import numpy as np
from torch.utils.data import DataLoader
from tqdm import tqdm
from datasets.coco import CocoDetection
def label_colormap(n_label=256, value=None):
    """Label colormap.

    :param n_label: Number of labels, defaults to 256
    :param value: Value scale or value of label color in HSV space, defaults to None
    :return: Label id to colormap, numpy.ndarray, (N, 3), numpy.uint8
    """
    labels = np.arange(n_label, dtype=np.uint8)

    # column k of ``octets`` holds the label id shifted right by 3*k bits
    octets = np.repeat(labels[:, None], 8, axis=1)
    octets = np.right_shift(octets, np.arange(0, 24, 3)).astype(np.uint8)

    # bits[..., -1 - b] is bit b (least significant first) of every octet
    bits = np.unpackbits(octets).reshape(octets.shape + (8,))
    shifts = np.arange(8)[::-1]

    # bits 0, 1 and 2 of each octet feed the red, green and blue channel
    channels = [
        np.bitwise_or.reduce(np.left_shift(bits[..., -1 - channel], shifts), axis=1)
        for channel in range(3)
    ]
    cmap = np.stack(channels, axis=1).astype(np.uint8)

    if value is None:
        return cmap

    # adjust the V channel of every color except the first one
    hsv = cv2.cvtColor(cmap.reshape(1, -1, 3), cv2.COLOR_RGB2HSV)
    if isinstance(value, float):
        hsv[:, 1:, 2] = hsv[:, 1:, 2].astype(float) * value
    else:
        assert isinstance(value, int)
        hsv[:, 1:, 2] = value
    return cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB).reshape(-1, 3)
def generate_color_palette(n: int, contrast: bool = False):
    """Build a palette of ``n`` label colors, optionally with contrast variants.

    :param n: number of colors
    :param contrast: when True also return a lighter and a darker version of
        every color, defaults to False
    :return: the colors alone, or the tuple ``(colors, light_colors,
        dark_colors)`` when ``contrast`` is True
    """
    base_colors = label_colormap(n)
    base_hsv = cv2.cvtColor(base_colors[None], cv2.COLOR_RGB2HSV)[0]
    if not contrast:
        return base_colors

    # darker variant: halve the V channel; lighter variant: halved V plus 128
    darker_hsv = base_hsv.copy()
    darker_hsv[:, -1] //= 2
    lighter_hsv = darker_hsv.copy()
    lighter_hsv[:, -1] += 128

    darker_rgb = cv2.cvtColor(darker_hsv[None], cv2.COLOR_HSV2RGB)[0]
    lighter_rgb = cv2.cvtColor(lighter_hsv[None], cv2.COLOR_HSV2RGB)[0]
    return base_colors, lighter_rgb, darker_rgb
def plot_bounding_boxes_on_image_cv2(
    image: np.ndarray,
    boxes: Union[np.ndarray, List[float]],
    labels: Union[np.ndarray, List[int]],
    scores: Union[np.ndarray, List[float]] = None,
    classes: List[str] = None,
    show_conf: float = 0.5,
    font_scale: float = 1.0,
    box_thick: int = 3,
    fill_alpha: float = 0.2,
    text_box_color: Tuple[int] = (255, 255, 255),
    text_font_color: Tuple[int] = None,
    text_alpha: float = 0.5,
):
    """Given an image, plot bounding boxes, labels on it.

    :param image: input image with dtype uint8, format RGB and shape (h, w, c)
    :param boxes: boxes with format (x1, y1, x2, y2) and shape (n, 4)
    :param labels: label index with dtype int and shape (n,)
    :param scores: confidence score with shape (n,), defaults to None
    :param classes: a list containing all classes, label i will be converted
        to classes[i] to show if given, else #i will be plotted, defaults to None
    :param show_conf: only boxes with score > show_conf are drawn, ignored when
        scores is None, defaults to 0.5
    :param font_scale: scale factor to set font size, defaults to 1.0
    :param box_thick: scale factor to set box border weight, defaults to 3
    :param fill_alpha: alpha to filling the area in the bounding box, defaults to 0.2
    :param text_box_color: background color of the text box, defaults to (255, 255, 255)
    :param text_font_color: text color, will be set automatically if not given, defaults to None
    :param text_alpha: alpha to filling the area in the text box, defaults to 0.5
    :return: the image with boxes and labels drawn on it; the input image itself
        is returned unchanged when there is nothing to draw
    """
    if len(labels) == 0:
        return image

    # convert to numpy array if given list as input
    if any(not isinstance(t, np.ndarray) for t in (boxes, labels)):
        boxes, labels = map(np.array, (boxes, labels))
    if scores is not None and not isinstance(scores, np.ndarray):
        scores = np.array(scores)
    boxes = boxes.astype(np.int32)  # convert to int32, compatible with cv2

    # check input format for boxes, labels, class and scores
    assert len(boxes) == len(labels), "The number of boxes and labels must be equal"
    assert boxes.shape[-1] == 4, "Boxes must have 4 elements (x1, y1, x2, y2) and x2 > x1, y2 > y1"
    assert classes is None or max(labels) <= len(classes) - 1, "#classes less than label index"
    assert scores is None or len(scores) == len(labels), "#scores and #labels must be equal"

    # filter low confident predictions
    if scores is not None:
        keep = scores > show_conf
        boxes, labels, scores = boxes[keep], labels[keep], scores[keep]
        if len(labels) == 0:
            # Everything was filtered out; without this guard ``max(labels)``
            # below fails on an empty array when ``classes`` is not given.
            return image

    # NOTE(review): palette colors are RGB triples but are drawn while the image
    # is in BGR order, so their red/blue channels look swapped in the returned
    # RGB image -- confirm whether this is intended.
    image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

    # get classes if not given
    if classes is None:
        classes = [str(i) for i in range(max(labels) + 1)]

    # generate color palette and map colors and names to each bounding box
    palette, light_palette, dark_palette = (
        p.tolist() for p in generate_color_palette(len(classes), contrast=True)
    )
    colors = [palette[i] for i in labels]
    light_colors = [light_palette[i] for i in labels]
    dark_colors = [dark_palette[i] for i in labels]
    names = [classes[i] for i in labels]

    # draw bounding boxes filling on a copy, then blend it over the clean image
    filled = image.copy()
    for box, color in zip(boxes, colors):
        cv2.rectangle(filled, (box[0], box[1]), (box[2], box[3]), color=color, thickness=-1)
    image = cv2.addWeighted(image, 1 - fill_alpha, filled, fill_alpha, 0)

    # draw label
    font = cv2.FONT_HERSHEY_SIMPLEX
    without_text = image.copy()
    for i, (color, name, box) in enumerate(zip(colors, names, boxes)):
        # get label text
        text = name if scores is None else f"{name}, {scores[i]:.3f}"

        # calculate box region and baseline height
        (text_w, text_h), baseline_height = cv2.getTextSize(text, font, font_scale, int(2 * font_scale))
        text_w, text_h = int(text_w), int(text_h)

        # draw text box, it sits on top of the bounding box
        box_left = box[0]
        box_top = box[1] - text_h - baseline_height - 3
        box_right = box[0] + text_w
        box_bottom = box[1] - 3
        cv2.rectangle(image, (box_left, box_top), (box_right, box_bottom), color=text_box_color, thickness=-1)

        # draw text label
        font_color = text_font_color if text_font_color is not None else color
        text_thickness = int(2 * font_scale**1.5)
        cv2.putText(image, text, (box_left, box[1] - baseline_height), font, font_scale, font_color, text_thickness)
    image = cv2.addWeighted(without_text, 1 - text_alpha, image, text_alpha, 0)

    # draw bounding boxes with corner line
    for dark_color, light_color, box in zip(dark_colors, light_colors, boxes):
        x1, y1, x2, y2 = box
        # draw bounding boxes border
        cv2.rectangle(image, (x1, y1), (x2, y2), color=dark_color, thickness=box_thick)

        # calculate corner line length
        if x2 - x1 <= 20 or y2 - y1 <= 20:
            length = 1
        else:
            length = int(min(x2 - x1, y2 - y1) * 0.2)

        # every corner gets one horizontal and one vertical stroke pointing
        # into the box: top-left, top-right, bottom-left, bottom-right
        for cy, dy in ((y1, length), (y2, -length)):
            for cx, dx in ((x1, length), (x2, -length)):
                cv2.line(image, (cx, cy), (cx + dx, cy), light_color, thickness=box_thick)
                cv2.line(image, (cx, cy), (cx, cy + dy), light_color, thickness=box_thick)

    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    return image
def visualize_coco_bounding_boxes(
    data_loader: DataLoader,
    show_conf: float = 0.0,
    show_dir: str = None,
    font_scale: float = 1.0,
    box_thick: int = 3,
    fill_alpha: float = 0.2,
    text_box_color: Tuple[int] = (255, 255, 255),
    text_font_color: Tuple[int] = None,
    text_alpha: float = 0.5,
):
    """Given a DataLoader of CocoDetection, plot bounding boxes, labels and save into given dir.

    The drawing and saving happen inside the loader's ``collate_fn``, which this
    function replaces on ``data_loader`` and does not restore afterwards.

    :param data_loader: DataLoader of CocoDetection.
    :param show_conf: Only results with confidence > show_conf will be plot, defaults to 0.0
    :param show_dir: directory to save visualization results, it is created if
        missing and must be given although it defaults to None
    :param font_scale: scale factor to set font size, defaults to 1.0
    :param box_thick: scale factor to set box border weight, defaults to 3
    :param fill_alpha: alpha to filling the area in the bounding box, defaults to 0.2
    :param text_box_color: background color of the text box, defaults to (255, 255, 255)
    :param text_font_color: text color, will be set automatically if not given, defaults to None
    :param text_alpha: alpha to filling the area in the text box, defaults to 0.5
    :raises TypeError: if ``show_dir`` is None
    """
    assert data_loader.batch_size in (None, 1), "batch_size of DataLoader for visualization must be 1"
    assert isinstance(data_loader.dataset, CocoDetection), "Only CocoDetection dataset is supported"
    if show_dir is None:
        # os.makedirs(None) would raise the same exception type, only with a
        # message that does not name the offending argument.
        raise TypeError("show_dir must be a directory path, got None")
    os.makedirs(show_dir, exist_ok=True)

    dataset: CocoDetection = data_loader.dataset
    # Build a dense name table indexable by category id; ids that do not
    # appear in the annotation file are shown as "none".
    num_ids = max(dataset.coco.cats.keys()) + 1
    classes = tuple(dataset.coco.cats.get(c, {"name": "none"})["name"] for c in range(num_ids))

    # multi-process on Windows does not support pickle local functions
    # we use functools.partial on global functions to workaround it
    data_loader.collate_fn = partial(
        _visualize_batch_in_coco,
        classes=classes,
        show_conf=show_conf,
        font_scale=font_scale,
        box_thick=box_thick,
        fill_alpha=fill_alpha,
        text_box_color=text_box_color,
        text_font_color=text_font_color,
        text_alpha=text_alpha,
        dataset=dataset,
        show_dir=show_dir,
    )

    # Iterating is what triggers the collate_fn; the batches carry no result.
    for _ in tqdm(data_loader):
        pass
def _visualize_batch_in_coco(
    batch: Tuple[np.ndarray, dict],
    dataset: CocoDetection,
    classes: List[str],
    show_conf: float = 0.0,
    show_dir: str = None,
    font_scale: float = 1.0,
    box_thick: int = 3,
    fill_alpha: float = 0.2,
    text_box_color: Tuple[int] = (255, 255, 255),
    text_font_color: Tuple[int] = None,
    text_alpha: float = 0.5,
):
    """Draw the annotations of the first sample of ``batch`` and write the result to ``show_dir``.

    Used as ``collate_fn`` by ``visualize_coco_bounding_boxes``; it returns nothing.

    :param batch: sequence whose first item is an ``(image, output)`` pair, where
        ``output`` provides "boxes", "labels", "image_id" and optionally "scores"
    :param dataset: dataset whose ``coco`` index resolves the image file name
    :param classes: class names indexed by label
    :param show_conf: score threshold forwarded to the plotting function, defaults to 0.0
    :param show_dir: directory the rendered image is written to
    :param font_scale: forwarded to the plotting function, defaults to 1.0
    :param box_thick: forwarded to the plotting function, defaults to 3
    :param fill_alpha: forwarded to the plotting function, defaults to 0.2
    :param text_box_color: forwarded to the plotting function, defaults to (255, 255, 255)
    :param text_font_color: forwarded to the plotting function, defaults to None
    :param text_alpha: forwarded to the plotting function, defaults to 0.5
    """
    sample_image, target = batch[0]

    # assumes the image is a (c, h, w) tensor -- TODO confirm against the dataset
    hwc_image = sample_image.numpy().transpose(1, 2, 0)
    # The plotting function returns the channel order it was given, so swapping
    # here makes its output directly usable by cv2.imwrite.
    bgr_image = cv2.cvtColor(hwc_image, cv2.COLOR_RGB2BGR)

    # plot bounding boxes on image
    plot_options = dict(
        classes=classes,
        show_conf=show_conf,
        font_scale=font_scale,
        box_thick=box_thick,
        fill_alpha=fill_alpha,
        text_box_color=text_box_color,
        text_font_color=text_font_color,
        text_alpha=text_alpha,
    )
    rendered = plot_bounding_boxes_on_image_cv2(
        image=bgr_image,
        boxes=target["boxes"],
        labels=target["labels"],
        scores=target.get("scores"),
        **plot_options,
    )

    # save under the original file name, flattened into show_dir
    image_info = dataset.coco.loadImgs([target["image_id"]])[0]
    save_path = os.path.join(show_dir, os.path.basename(image_info["file_name"]))
    cv2.imwrite(save_path, rendered)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment