Commit 3144257c authored by mashun1's avatar mashun1
Browse files

catvton

parents
# Copyright (c) Facebook, Inc. and its affiliates.
import logging
import os
from fvcore.common.timer import Timer
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.structures import BoxMode
from detectron2.utils.file_io import PathManager
from .builtin_meta import _get_coco_instances_meta
from .lvis_v0_5_categories import LVIS_CATEGORIES as LVIS_V0_5_CATEGORIES
from .lvis_v1_categories import LVIS_CATEGORIES as LVIS_V1_CATEGORIES
from .lvis_v1_category_image_count import LVIS_CATEGORY_IMAGE_COUNT as LVIS_V1_CATEGORY_IMAGE_COUNT
"""
This file contains functions to parse LVIS-format annotations into dicts in the
"Detectron2 format".
"""
logger = logging.getLogger(__name__)
__all__ = ["load_lvis_json", "register_lvis_instances", "get_lvis_instances_meta"]
def register_lvis_instances(name, metadata, json_file, image_root):
"""
Register a dataset in LVIS's json annotation format for instance detection and segmentation.
Args:
name (str): a name that identifies the dataset, e.g. "lvis_v0.5_train".
metadata (dict): extra metadata associated with this dataset. It can be an empty dict.
json_file (str): path to the json instance annotation file.
image_root (str or path-like): directory which contains all the images.
"""
DatasetCatalog.register(name, lambda: load_lvis_json(json_file, image_root, name))
MetadataCatalog.get(name).set(
json_file=json_file, image_root=image_root, evaluator_type="lvis", **metadata
)
def load_lvis_json(json_file, image_root, dataset_name=None, extra_annotation_keys=None):
"""
Load a json file in LVIS's annotation format.
Args:
json_file (str): full path to the LVIS json annotation file.
image_root (str): the directory where the images in this json file exists.
dataset_name (str): the name of the dataset (e.g., "lvis_v0.5_train").
If provided, this function will put "thing_classes" into the metadata
associated with this dataset.
extra_annotation_keys (list[str]): list of per-annotation keys that should also be
loaded into the dataset dict (besides "bbox", "bbox_mode", "category_id",
"segmentation"). The values for these keys will be returned as-is.
Returns:
list[dict]: a list of dicts in Detectron2 standard format. (See
`Using Custom Datasets </tutorials/datasets.html>`_ )
Notes:
1. This function does not read the image files.
The results do not have the "image" field.
"""
from lvis import LVIS
json_file = PathManager.get_local_path(json_file)
timer = Timer()
lvis_api = LVIS(json_file)
if timer.seconds() > 1:
logger.info("Loading {} takes {:.2f} seconds.".format(json_file, timer.seconds()))
if dataset_name is not None:
meta = get_lvis_instances_meta(dataset_name)
MetadataCatalog.get(dataset_name).set(**meta)
# sort indices for reproducible results
img_ids = sorted(lvis_api.imgs.keys())
# imgs is a list of dicts, each looks something like:
# {'license': 4,
# 'url': 'http://farm6.staticflickr.com/5454/9413846304_881d5e5c3b_z.jpg',
# 'file_name': 'COCO_val2014_000000001268.jpg',
# 'height': 427,
# 'width': 640,
# 'date_captured': '2013-11-17 05:57:24',
# 'id': 1268}
imgs = lvis_api.load_imgs(img_ids)
# anns is a list[list[dict]], where each dict is an annotation
# record for an object. The inner list enumerates the objects in an image
# and the outer list enumerates over images. Example of anns[0]:
# [{'segmentation': [[192.81,
# 247.09,
# ...
# 219.03,
# 249.06]],
# 'area': 1035.749,
# 'image_id': 1268,
# 'bbox': [192.81, 224.8, 74.73, 33.43],
# 'category_id': 16,
# 'id': 42986},
# ...]
anns = [lvis_api.img_ann_map[img_id] for img_id in img_ids]
# Sanity check that each annotation has a unique id
ann_ids = [ann["id"] for anns_per_image in anns for ann in anns_per_image]
assert len(set(ann_ids)) == len(ann_ids), "Annotation ids in '{}' are not unique".format(
json_file
)
imgs_anns = list(zip(imgs, anns))
logger.info("Loaded {} images in the LVIS format from {}".format(len(imgs_anns), json_file))
if extra_annotation_keys:
logger.info(
"The following extra annotation keys will be loaded: {} ".format(extra_annotation_keys)
)
else:
extra_annotation_keys = []
def get_file_name(img_root, img_dict):
# Determine the path including the split folder ("train2017", "val2017", "test2017") from
# the coco_url field. Example:
# 'coco_url': 'http://images.cocodataset.org/train2017/000000155379.jpg'
split_folder, file_name = img_dict["coco_url"].split("/")[-2:]
return os.path.join(img_root + split_folder, file_name)
dataset_dicts = []
for img_dict, anno_dict_list in imgs_anns:
record = {}
record["file_name"] = get_file_name(image_root, img_dict)
record["height"] = img_dict["height"]
record["width"] = img_dict["width"]
record["not_exhaustive_category_ids"] = img_dict.get("not_exhaustive_category_ids", [])
record["neg_category_ids"] = img_dict.get("neg_category_ids", [])
image_id = record["image_id"] = img_dict["id"]
objs = []
for anno in anno_dict_list:
# Check that the image_id in this annotation is the same as
# the image_id we're looking at.
# This fails only when the data parsing logic or the annotation file is buggy.
assert anno["image_id"] == image_id
obj = {"bbox": anno["bbox"], "bbox_mode": BoxMode.XYWH_ABS}
# LVIS data loader can be used to load COCO dataset categories. In this case `meta`
# variable will have a field with COCO-specific category mapping.
if dataset_name is not None and "thing_dataset_id_to_contiguous_id" in meta:
obj["category_id"] = meta["thing_dataset_id_to_contiguous_id"][anno["category_id"]]
else:
obj["category_id"] = anno["category_id"] - 1 # Convert 1-indexed to 0-indexed
segm = anno["segmentation"] # list[list[float]]
# filter out invalid polygons (< 3 points)
valid_segm = [poly for poly in segm if len(poly) % 2 == 0 and len(poly) >= 6]
assert len(segm) == len(
valid_segm
), "Annotation contains an invalid polygon with < 3 points"
assert len(segm) > 0
obj["segmentation"] = segm
for extra_ann_key in extra_annotation_keys:
obj[extra_ann_key] = anno[extra_ann_key]
objs.append(obj)
record["annotations"] = objs
dataset_dicts.append(record)
return dataset_dicts
def get_lvis_instances_meta(dataset_name):
"""
Load LVIS metadata.
Args:
dataset_name (str): LVIS dataset name without the split name (e.g., "lvis_v0.5").
Returns:
dict: LVIS metadata with keys: thing_classes
"""
if "cocofied" in dataset_name:
return _get_coco_instances_meta()
if "v0.5" in dataset_name:
return _get_lvis_instances_meta_v0_5()
elif "v1" in dataset_name:
return _get_lvis_instances_meta_v1()
raise ValueError("No built-in metadata for dataset {}".format(dataset_name))
def _get_lvis_instances_meta_v0_5():
assert len(LVIS_V0_5_CATEGORIES) == 1230
cat_ids = [k["id"] for k in LVIS_V0_5_CATEGORIES]
assert min(cat_ids) == 1 and max(cat_ids) == len(
cat_ids
), "Category ids are not in [1, #categories], as expected"
# Ensure that the category list is sorted by id
lvis_categories = sorted(LVIS_V0_5_CATEGORIES, key=lambda x: x["id"])
thing_classes = [k["synonyms"][0] for k in lvis_categories]
meta = {"thing_classes": thing_classes}
return meta
def _get_lvis_instances_meta_v1():
assert len(LVIS_V1_CATEGORIES) == 1203
cat_ids = [k["id"] for k in LVIS_V1_CATEGORIES]
assert min(cat_ids) == 1 and max(cat_ids) == len(
cat_ids
), "Category ids are not in [1, #categories], as expected"
# Ensure that the category list is sorted by id
lvis_categories = sorted(LVIS_V1_CATEGORIES, key=lambda x: x["id"])
thing_classes = [k["synonyms"][0] for k in lvis_categories]
meta = {
"thing_classes": thing_classes,
"class_image_count": LVIS_V1_CATEGORY_IMAGE_COUNT,
}
return meta
def main() -> None:
global logger
"""
Test the LVIS json dataset loader.
Usage:
python -m detectron2.data.datasets.lvis \
path/to/json path/to/image_root dataset_name vis_limit
"""
import sys
import detectron2.data.datasets # noqa # add pre-defined metadata
import numpy as np
from detectron2.utils.logger import setup_logger
from detectron2.utils.visualizer import Visualizer
from PIL import Image
logger = setup_logger(name=__name__)
meta = MetadataCatalog.get(sys.argv[3])
dicts = load_lvis_json(sys.argv[1], sys.argv[2], sys.argv[3])
logger.info("Done loading {} samples.".format(len(dicts)))
dirname = "lvis-data-vis"
os.makedirs(dirname, exist_ok=True)
for d in dicts[: int(sys.argv[4])]:
img = np.array(Image.open(d["file_name"]))
visualizer = Visualizer(img, metadata=meta)
vis = visualizer.draw_dataset_dict(d)
fpath = os.path.join(dirname, os.path.basename(d["file_name"]))
vis.save(fpath)
if __name__ == "__main__":
main() # pragma: no cover
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
# Copyright (c) Facebook, Inc. and its affiliates.
# Autogen with
# with open("lvis_v1_train.json", "r") as f:
# a = json.load(f)
# c = a["categories"]
# for x in c:
# del x["name"]
# del x["instance_count"]
# del x["def"]
# del x["synonyms"]
# del x["frequency"]
# del x["synset"]
# LVIS_CATEGORY_IMAGE_COUNT = repr(c) + " # noqa"
# with open("/tmp/lvis_category_image_count.py", "wt") as f:
# f.write(f"LVIS_CATEGORY_IMAGE_COUNT = {LVIS_CATEGORY_IMAGE_COUNT}")
# Then paste the contents of that file below
# fmt: off
LVIS_CATEGORY_IMAGE_COUNT = [{'id': 1, 'image_count': 64}, {'id': 2, 'image_count': 364}, {'id': 3, 'image_count': 1911}, {'id': 4, 'image_count': 149}, {'id': 5, 'image_count': 29}, {'id': 6, 'image_count': 26}, {'id': 7, 'image_count': 59}, {'id': 8, 'image_count': 22}, {'id': 9, 'image_count': 12}, {'id': 10, 'image_count': 28}, {'id': 11, 'image_count': 505}, {'id': 12, 'image_count': 1207}, {'id': 13, 'image_count': 4}, {'id': 14, 'image_count': 10}, {'id': 15, 'image_count': 500}, {'id': 16, 'image_count': 33}, {'id': 17, 'image_count': 3}, {'id': 18, 'image_count': 44}, {'id': 19, 'image_count': 561}, {'id': 20, 'image_count': 8}, {'id': 21, 'image_count': 9}, {'id': 22, 'image_count': 33}, {'id': 23, 'image_count': 1883}, {'id': 24, 'image_count': 98}, {'id': 25, 'image_count': 70}, {'id': 26, 'image_count': 46}, {'id': 27, 'image_count': 117}, {'id': 28, 'image_count': 41}, {'id': 29, 'image_count': 1395}, {'id': 30, 'image_count': 7}, {'id': 31, 'image_count': 1}, {'id': 32, 'image_count': 314}, {'id': 33, 'image_count': 31}, {'id': 34, 'image_count': 1905}, {'id': 35, 'image_count': 1859}, {'id': 36, 'image_count': 1623}, {'id': 37, 'image_count': 47}, {'id': 38, 'image_count': 3}, {'id': 39, 'image_count': 3}, {'id': 40, 'image_count': 1}, {'id': 41, 'image_count': 305}, {'id': 42, 'image_count': 6}, {'id': 43, 'image_count': 210}, {'id': 44, 'image_count': 36}, {'id': 45, 'image_count': 1787}, {'id': 46, 'image_count': 17}, {'id': 47, 'image_count': 51}, {'id': 48, 'image_count': 138}, {'id': 49, 'image_count': 3}, {'id': 50, 'image_count': 1470}, {'id': 51, 'image_count': 3}, {'id': 52, 'image_count': 2}, {'id': 53, 'image_count': 186}, {'id': 54, 'image_count': 76}, {'id': 55, 'image_count': 26}, {'id': 56, 'image_count': 303}, {'id': 57, 'image_count': 738}, {'id': 58, 'image_count': 1799}, {'id': 59, 'image_count': 1934}, {'id': 60, 'image_count': 1609}, {'id': 61, 'image_count': 1622}, {'id': 62, 'image_count': 41}, {'id': 63, 'image_count': 4}, {'id': 64, 'image_count': 11}, {'id': 65, 'image_count': 270}, {'id': 66, 'image_count': 349}, {'id': 67, 'image_count': 42}, {'id': 68, 'image_count': 823}, {'id': 69, 'image_count': 6}, {'id': 70, 'image_count': 48}, {'id': 71, 'image_count': 3}, {'id': 72, 'image_count': 42}, {'id': 73, 'image_count': 24}, {'id': 74, 'image_count': 16}, {'id': 75, 'image_count': 605}, {'id': 76, 'image_count': 646}, {'id': 77, 'image_count': 1765}, {'id': 78, 'image_count': 2}, {'id': 79, 'image_count': 125}, {'id': 80, 'image_count': 1420}, {'id': 81, 'image_count': 140}, {'id': 82, 'image_count': 4}, {'id': 83, 'image_count': 322}, {'id': 84, 'image_count': 60}, {'id': 85, 'image_count': 2}, {'id': 86, 'image_count': 231}, {'id': 87, 'image_count': 333}, {'id': 88, 'image_count': 1941}, {'id': 89, 'image_count': 367}, {'id': 90, 'image_count': 1922}, {'id': 91, 'image_count': 18}, {'id': 92, 'image_count': 81}, {'id': 93, 'image_count': 1}, {'id': 94, 'image_count': 1852}, {'id': 95, 'image_count': 430}, {'id': 96, 'image_count': 247}, {'id': 97, 'image_count': 94}, {'id': 98, 'image_count': 21}, {'id': 99, 'image_count': 1821}, {'id': 100, 'image_count': 16}, {'id': 101, 'image_count': 12}, {'id': 102, 'image_count': 25}, {'id': 103, 'image_count': 41}, {'id': 104, 'image_count': 244}, {'id': 105, 'image_count': 7}, {'id': 106, 'image_count': 1}, {'id': 107, 'image_count': 40}, {'id': 108, 'image_count': 40}, {'id': 109, 'image_count': 104}, {'id': 110, 'image_count': 1671}, {'id': 111, 'image_count': 49}, {'id': 112, 'image_count': 243}, {'id': 113, 'image_count': 2}, {'id': 114, 'image_count': 242}, {'id': 115, 'image_count': 271}, {'id': 116, 'image_count': 104}, {'id': 117, 'image_count': 8}, {'id': 118, 'image_count': 1758}, {'id': 119, 'image_count': 1}, {'id': 120, 'image_count': 48}, {'id': 121, 'image_count': 14}, {'id': 122, 'image_count': 40}, {'id': 123, 'image_count': 1}, {'id': 124, 'image_count': 37}, {'id': 125, 'image_count': 1510}, {'id': 126, 'image_count': 6}, {'id': 127, 'image_count': 1903}, {'id': 128, 'image_count': 70}, {'id': 129, 'image_count': 86}, {'id': 130, 'image_count': 7}, {'id': 131, 'image_count': 5}, {'id': 132, 'image_count': 1406}, {'id': 133, 'image_count': 1901}, {'id': 134, 'image_count': 15}, {'id': 135, 'image_count': 28}, {'id': 136, 'image_count': 6}, {'id': 137, 'image_count': 494}, {'id': 138, 'image_count': 234}, {'id': 139, 'image_count': 1922}, {'id': 140, 'image_count': 1}, {'id': 141, 'image_count': 35}, {'id': 142, 'image_count': 5}, {'id': 143, 'image_count': 1828}, {'id': 144, 'image_count': 8}, {'id': 145, 'image_count': 63}, {'id': 146, 'image_count': 1668}, {'id': 147, 'image_count': 4}, {'id': 148, 'image_count': 95}, {'id': 149, 'image_count': 17}, {'id': 150, 'image_count': 1567}, {'id': 151, 'image_count': 2}, {'id': 152, 'image_count': 103}, {'id': 153, 'image_count': 50}, {'id': 154, 'image_count': 1309}, {'id': 155, 'image_count': 6}, {'id': 156, 'image_count': 92}, {'id': 157, 'image_count': 19}, {'id': 158, 'image_count': 37}, {'id': 159, 'image_count': 4}, {'id': 160, 'image_count': 709}, {'id': 161, 'image_count': 9}, {'id': 162, 'image_count': 82}, {'id': 163, 'image_count': 15}, {'id': 164, 'image_count': 3}, {'id': 165, 'image_count': 61}, {'id': 166, 'image_count': 51}, {'id': 167, 'image_count': 5}, {'id': 168, 'image_count': 13}, {'id': 169, 'image_count': 642}, {'id': 170, 'image_count': 24}, {'id': 171, 'image_count': 255}, {'id': 172, 'image_count': 9}, {'id': 173, 'image_count': 1808}, {'id': 174, 'image_count': 31}, {'id': 175, 'image_count': 158}, {'id': 176, 'image_count': 80}, {'id': 177, 'image_count': 1884}, {'id': 178, 'image_count': 158}, {'id': 179, 'image_count': 2}, {'id': 180, 'image_count': 12}, {'id': 181, 'image_count': 1659}, {'id': 182, 'image_count': 7}, {'id': 183, 'image_count': 834}, {'id': 184, 'image_count': 57}, {'id': 185, 'image_count': 174}, {'id': 186, 'image_count': 95}, {'id': 187, 'image_count': 27}, {'id': 188, 'image_count': 22}, {'id': 189, 'image_count': 1391}, {'id': 190, 'image_count': 90}, {'id': 191, 'image_count': 40}, {'id': 192, 'image_count': 445}, {'id': 193, 'image_count': 21}, {'id': 194, 'image_count': 1132}, {'id': 195, 'image_count': 177}, {'id': 196, 'image_count': 4}, {'id': 197, 'image_count': 17}, {'id': 198, 'image_count': 84}, {'id': 199, 'image_count': 55}, {'id': 200, 'image_count': 30}, {'id': 201, 'image_count': 25}, {'id': 202, 'image_count': 2}, {'id': 203, 'image_count': 125}, {'id': 204, 'image_count': 1135}, {'id': 205, 'image_count': 19}, {'id': 206, 'image_count': 72}, {'id': 207, 'image_count': 1926}, {'id': 208, 'image_count': 159}, {'id': 209, 'image_count': 7}, {'id': 210, 'image_count': 1}, {'id': 211, 'image_count': 13}, {'id': 212, 'image_count': 35}, {'id': 213, 'image_count': 18}, {'id': 214, 'image_count': 8}, {'id': 215, 'image_count': 6}, {'id': 216, 'image_count': 35}, {'id': 217, 'image_count': 1222}, {'id': 218, 'image_count': 103}, {'id': 219, 'image_count': 28}, {'id': 220, 'image_count': 63}, {'id': 221, 'image_count': 28}, {'id': 222, 'image_count': 5}, {'id': 223, 'image_count': 7}, {'id': 224, 'image_count': 14}, {'id': 225, 'image_count': 1918}, {'id': 226, 'image_count': 133}, {'id': 227, 'image_count': 16}, {'id': 228, 'image_count': 27}, {'id': 229, 'image_count': 110}, {'id': 230, 'image_count': 1895}, {'id': 231, 'image_count': 4}, {'id': 232, 'image_count': 1927}, {'id': 233, 'image_count': 8}, {'id': 234, 'image_count': 1}, {'id': 235, 'image_count': 263}, {'id': 236, 'image_count': 10}, {'id': 237, 'image_count': 2}, {'id': 238, 'image_count': 3}, {'id': 239, 'image_count': 87}, {'id': 240, 'image_count': 9}, {'id': 241, 'image_count': 71}, {'id': 242, 'image_count': 13}, {'id': 243, 'image_count': 18}, {'id': 244, 'image_count': 2}, {'id': 245, 'image_count': 5}, {'id': 246, 'image_count': 45}, {'id': 247, 'image_count': 1}, {'id': 248, 'image_count': 23}, {'id': 249, 'image_count': 32}, {'id': 250, 'image_count': 4}, {'id': 251, 'image_count': 1}, {'id': 252, 'image_count': 858}, {'id': 253, 'image_count': 661}, {'id': 254, 'image_count': 168}, {'id': 255, 'image_count': 210}, {'id': 256, 'image_count': 65}, {'id': 257, 'image_count': 4}, {'id': 258, 'image_count': 2}, {'id': 259, 'image_count': 159}, {'id': 260, 'image_count': 31}, {'id': 261, 'image_count': 811}, {'id': 262, 'image_count': 1}, {'id': 263, 'image_count': 42}, {'id': 264, 'image_count': 27}, {'id': 265, 'image_count': 2}, {'id': 266, 'image_count': 5}, {'id': 267, 'image_count': 95}, {'id': 268, 'image_count': 32}, {'id': 269, 'image_count': 1}, {'id': 270, 'image_count': 1}, {'id': 271, 'image_count': 1844}, {'id': 272, 'image_count': 897}, {'id': 273, 'image_count': 31}, {'id': 274, 'image_count': 23}, {'id': 275, 'image_count': 1}, {'id': 276, 'image_count': 202}, {'id': 277, 'image_count': 746}, {'id': 278, 'image_count': 44}, {'id': 279, 'image_count': 14}, {'id': 280, 'image_count': 26}, {'id': 281, 'image_count': 1}, {'id': 282, 'image_count': 2}, {'id': 283, 'image_count': 25}, {'id': 284, 'image_count': 238}, {'id': 285, 'image_count': 592}, {'id': 286, 'image_count': 26}, {'id': 287, 'image_count': 5}, {'id': 288, 'image_count': 42}, {'id': 289, 'image_count': 13}, {'id': 290, 'image_count': 46}, {'id': 291, 'image_count': 1}, {'id': 292, 'image_count': 8}, {'id': 293, 'image_count': 34}, {'id': 294, 'image_count': 5}, {'id': 295, 'image_count': 1}, {'id': 296, 'image_count': 1871}, {'id': 297, 'image_count': 717}, {'id': 298, 'image_count': 1010}, {'id': 299, 'image_count': 679}, {'id': 300, 'image_count': 3}, {'id': 301, 'image_count': 4}, {'id': 302, 'image_count': 1}, {'id': 303, 'image_count': 166}, {'id': 304, 'image_count': 2}, {'id': 305, 'image_count': 266}, {'id': 306, 'image_count': 101}, {'id': 307, 'image_count': 6}, {'id': 308, 'image_count': 14}, {'id': 309, 'image_count': 133}, {'id': 310, 'image_count': 2}, {'id': 311, 'image_count': 38}, {'id': 312, 'image_count': 95}, {'id': 313, 'image_count': 1}, {'id': 314, 'image_count': 12}, {'id': 315, 'image_count': 49}, {'id': 316, 'image_count': 5}, {'id': 317, 'image_count': 5}, {'id': 318, 'image_count': 16}, {'id': 319, 'image_count': 216}, {'id': 320, 'image_count': 12}, {'id': 321, 'image_count': 1}, {'id': 322, 'image_count': 54}, {'id': 323, 'image_count': 5}, {'id': 324, 'image_count': 245}, {'id': 325, 'image_count': 12}, {'id': 326, 'image_count': 7}, {'id': 327, 'image_count': 35}, {'id': 328, 'image_count': 36}, {'id': 329, 'image_count': 32}, {'id': 330, 'image_count': 1027}, {'id': 331, 'image_count': 10}, {'id': 332, 'image_count': 12}, {'id': 333, 'image_count': 1}, {'id': 334, 'image_count': 67}, {'id': 335, 'image_count': 71}, {'id': 336, 'image_count': 30}, {'id': 337, 'image_count': 48}, {'id': 338, 'image_count': 249}, {'id': 339, 'image_count': 13}, {'id': 340, 'image_count': 29}, {'id': 341, 'image_count': 14}, {'id': 342, 'image_count': 236}, {'id': 343, 'image_count': 15}, {'id': 344, 'image_count': 1521}, {'id': 345, 'image_count': 25}, {'id': 346, 'image_count': 249}, {'id': 347, 'image_count': 139}, {'id': 348, 'image_count': 2}, {'id': 349, 'image_count': 2}, {'id': 350, 'image_count': 1890}, {'id': 351, 'image_count': 1240}, {'id': 352, 'image_count': 1}, {'id': 353, 'image_count': 9}, {'id': 354, 'image_count': 1}, {'id': 355, 'image_count': 3}, {'id': 356, 'image_count': 11}, {'id': 357, 'image_count': 4}, {'id': 358, 'image_count': 236}, {'id': 359, 'image_count': 44}, {'id': 360, 'image_count': 19}, {'id': 361, 'image_count': 1100}, {'id': 362, 'image_count': 7}, {'id': 363, 'image_count': 69}, {'id': 364, 'image_count': 2}, {'id': 365, 'image_count': 8}, {'id': 366, 'image_count': 5}, {'id': 367, 'image_count': 227}, {'id': 368, 'image_count': 6}, {'id': 369, 'image_count': 106}, {'id': 370, 'image_count': 81}, {'id': 371, 'image_count': 17}, {'id': 372, 'image_count': 134}, {'id': 373, 'image_count': 312}, {'id': 374, 'image_count': 8}, {'id': 375, 'image_count': 271}, {'id': 376, 'image_count': 2}, {'id': 377, 'image_count': 103}, {'id': 378, 'image_count': 1938}, {'id': 379, 'image_count': 574}, {'id': 380, 'image_count': 120}, {'id': 381, 'image_count': 2}, {'id': 382, 'image_count': 2}, {'id': 383, 'image_count': 13}, {'id': 384, 'image_count': 29}, {'id': 385, 'image_count': 1710}, {'id': 386, 'image_count': 66}, {'id': 387, 'image_count': 1008}, {'id': 388, 'image_count': 1}, {'id': 389, 'image_count': 3}, {'id': 390, 'image_count': 1942}, {'id': 391, 'image_count': 19}, {'id': 392, 'image_count': 1488}, {'id': 393, 'image_count': 46}, {'id': 394, 'image_count': 106}, {'id': 395, 'image_count': 115}, {'id': 396, 'image_count': 19}, {'id': 397, 'image_count': 2}, {'id': 398, 'image_count': 1}, {'id': 399, 'image_count': 28}, {'id': 400, 'image_count': 9}, {'id': 401, 'image_count': 192}, {'id': 402, 'image_count': 12}, {'id': 403, 'image_count': 21}, {'id': 404, 'image_count': 247}, {'id': 405, 'image_count': 6}, {'id': 406, 'image_count': 64}, {'id': 407, 'image_count': 7}, {'id': 408, 'image_count': 40}, {'id': 409, 'image_count': 542}, {'id': 410, 'image_count': 2}, {'id': 411, 'image_count': 1898}, {'id': 412, 'image_count': 36}, {'id': 413, 'image_count': 4}, {'id': 414, 'image_count': 1}, {'id': 415, 'image_count': 191}, {'id': 416, 'image_count': 6}, {'id': 417, 'image_count': 41}, {'id': 418, 'image_count': 39}, {'id': 419, 'image_count': 46}, {'id': 420, 'image_count': 1}, {'id': 421, 'image_count': 1451}, {'id': 422, 'image_count': 1878}, {'id': 423, 'image_count': 11}, {'id': 424, 'image_count': 82}, {'id': 425, 'image_count': 18}, {'id': 426, 'image_count': 1}, {'id': 427, 'image_count': 7}, {'id': 428, 'image_count': 3}, {'id': 429, 'image_count': 575}, {'id': 430, 'image_count': 1907}, {'id': 431, 'image_count': 8}, {'id': 432, 'image_count': 4}, {'id': 433, 'image_count': 32}, {'id': 434, 'image_count': 11}, {'id': 435, 'image_count': 4}, {'id': 436, 'image_count': 54}, {'id': 437, 'image_count': 202}, {'id': 438, 'image_count': 32}, {'id': 439, 'image_count': 3}, {'id': 440, 'image_count': 130}, {'id': 441, 'image_count': 119}, {'id': 442, 'image_count': 141}, {'id': 443, 'image_count': 29}, {'id': 444, 'image_count': 525}, {'id': 445, 'image_count': 1323}, {'id': 446, 'image_count': 2}, {'id': 447, 'image_count': 113}, {'id': 448, 'image_count': 16}, {'id': 449, 'image_count': 7}, {'id': 450, 'image_count': 35}, {'id': 451, 'image_count': 1908}, {'id': 452, 'image_count': 353}, {'id': 453, 'image_count': 18}, {'id': 454, 'image_count': 14}, {'id': 455, 'image_count': 77}, {'id': 456, 'image_count': 8}, {'id': 457, 'image_count': 37}, {'id': 458, 'image_count': 1}, {'id': 459, 'image_count': 346}, {'id': 460, 'image_count': 19}, {'id': 461, 'image_count': 1779}, {'id': 462, 'image_count': 23}, {'id': 463, 'image_count': 25}, {'id': 464, 'image_count': 67}, {'id': 465, 'image_count': 19}, {'id': 466, 'image_count': 28}, {'id': 467, 'image_count': 4}, {'id': 468, 'image_count': 27}, {'id': 469, 'image_count': 1861}, {'id': 470, 'image_count': 11}, {'id': 471, 'image_count': 13}, {'id': 472, 'image_count': 13}, {'id': 473, 'image_count': 32}, {'id': 474, 'image_count': 1767}, {'id': 475, 'image_count': 42}, {'id': 476, 'image_count': 17}, {'id': 477, 'image_count': 128}, {'id': 478, 'image_count': 1}, {'id': 479, 'image_count': 9}, {'id': 480, 'image_count': 10}, {'id': 481, 'image_count': 4}, {'id': 482, 'image_count': 9}, {'id': 483, 'image_count': 18}, {'id': 484, 'image_count': 41}, {'id': 485, 'image_count': 28}, {'id': 486, 'image_count': 3}, {'id': 487, 'image_count': 65}, {'id': 488, 'image_count': 9}, {'id': 489, 'image_count': 23}, {'id': 490, 'image_count': 24}, {'id': 491, 'image_count': 1}, {'id': 492, 'image_count': 2}, {'id': 493, 'image_count': 59}, {'id': 494, 'image_count': 48}, {'id': 495, 'image_count': 17}, {'id': 496, 'image_count': 1877}, {'id': 497, 'image_count': 18}, {'id': 498, 'image_count': 1920}, {'id': 499, 'image_count': 50}, {'id': 500, 'image_count': 1890}, {'id': 501, 'image_count': 99}, {'id': 502, 'image_count': 1530}, {'id': 503, 'image_count': 3}, {'id': 504, 'image_count': 11}, {'id': 505, 'image_count': 19}, {'id': 506, 'image_count': 3}, {'id': 507, 'image_count': 63}, {'id': 508, 'image_count': 5}, {'id': 509, 'image_count': 6}, {'id': 510, 'image_count': 233}, {'id': 511, 'image_count': 54}, {'id': 512, 'image_count': 36}, {'id': 513, 'image_count': 10}, {'id': 514, 'image_count': 124}, {'id': 515, 'image_count': 101}, {'id': 516, 'image_count': 3}, {'id': 517, 'image_count': 363}, {'id': 518, 'image_count': 3}, {'id': 519, 'image_count': 30}, {'id': 520, 'image_count': 18}, {'id': 521, 'image_count': 199}, {'id': 522, 'image_count': 97}, {'id': 523, 'image_count': 32}, {'id': 524, 'image_count': 121}, {'id': 525, 'image_count': 16}, {'id': 526, 'image_count': 12}, {'id': 527, 'image_count': 2}, {'id': 528, 'image_count': 214}, {'id': 529, 'image_count': 48}, {'id': 530, 'image_count': 26}, {'id': 531, 'image_count': 13}, {'id': 532, 'image_count': 4}, {'id': 533, 'image_count': 11}, {'id': 534, 'image_count': 123}, {'id': 535, 'image_count': 7}, {'id': 536, 'image_count': 200}, {'id': 537, 'image_count': 91}, {'id': 538, 'image_count': 9}, {'id': 539, 'image_count': 72}, {'id': 540, 'image_count': 1886}, {'id': 541, 'image_count': 4}, {'id': 542, 'image_count': 1}, {'id': 543, 'image_count': 1}, {'id': 544, 'image_count': 1932}, {'id': 545, 'image_count': 4}, {'id': 546, 'image_count': 56}, {'id': 547, 'image_count': 854}, {'id': 548, 'image_count': 755}, {'id': 549, 'image_count': 1843}, {'id': 550, 'image_count': 96}, {'id': 551, 'image_count': 7}, {'id': 552, 'image_count': 74}, {'id': 553, 'image_count': 66}, {'id': 554, 'image_count': 57}, {'id': 555, 'image_count': 44}, {'id': 556, 'image_count': 1905}, {'id': 557, 'image_count': 4}, {'id': 558, 'image_count': 90}, {'id': 559, 'image_count': 1635}, {'id': 560, 'image_count': 8}, {'id': 561, 'image_count': 5}, {'id': 562, 'image_count': 50}, {'id': 563, 'image_count': 545}, {'id': 564, 'image_count': 20}, {'id': 565, 'image_count': 193}, {'id': 566, 'image_count': 285}, {'id': 567, 'image_count': 3}, {'id': 568, 'image_count': 1}, {'id': 569, 'image_count': 1904}, {'id': 570, 'image_count': 294}, {'id': 571, 'image_count': 3}, {'id': 572, 'image_count': 5}, {'id': 573, 'image_count': 24}, {'id': 574, 'image_count': 2}, {'id': 575, 'image_count': 2}, {'id': 576, 'image_count': 16}, {'id': 577, 'image_count': 8}, {'id': 578, 'image_count': 154}, {'id': 579, 'image_count': 66}, {'id': 580, 'image_count': 1}, {'id': 581, 'image_count': 24}, {'id': 582, 'image_count': 1}, {'id': 583, 'image_count': 4}, {'id': 584, 'image_count': 75}, {'id': 585, 'image_count': 6}, {'id': 586, 'image_count': 126}, {'id': 587, 'image_count': 24}, {'id': 588, 'image_count': 22}, {'id': 589, 'image_count': 1872}, {'id': 590, 'image_count': 16}, {'id': 591, 'image_count': 423}, {'id': 592, 'image_count': 1927}, {'id': 593, 'image_count': 38}, {'id': 594, 'image_count': 3}, {'id': 595, 'image_count': 1945}, {'id': 596, 'image_count': 35}, {'id': 597, 'image_count': 1}, {'id': 598, 'image_count': 13}, {'id': 599, 'image_count': 9}, {'id': 600, 'image_count': 14}, {'id': 601, 'image_count': 37}, {'id': 602, 'image_count': 3}, {'id': 603, 'image_count': 4}, {'id': 604, 'image_count': 100}, {'id': 605, 'image_count': 195}, {'id': 606, 'image_count': 1}, {'id': 607, 'image_count': 12}, {'id': 608, 'image_count': 24}, {'id': 609, 'image_count': 489}, {'id': 610, 'image_count': 10}, {'id': 611, 'image_count': 1689}, {'id': 612, 'image_count': 42}, {'id': 613, 'image_count': 81}, {'id': 614, 'image_count': 894}, {'id': 615, 'image_count': 1868}, {'id': 616, 'image_count': 7}, {'id': 617, 'image_count': 1567}, {'id': 618, 'image_count': 10}, {'id': 619, 'image_count': 8}, {'id': 620, 'image_count': 7}, {'id': 621, 'image_count': 629}, {'id': 622, 'image_count': 89}, {'id': 623, 'image_count': 15}, {'id': 624, 'image_count': 134}, {'id': 625, 'image_count': 4}, {'id': 626, 'image_count': 1802}, {'id': 627, 'image_count': 595}, {'id': 628, 'image_count': 1210}, {'id': 629, 'image_count': 48}, {'id': 630, 'image_count': 418}, {'id': 631, 'image_count': 1846}, {'id': 632, 'image_count': 5}, {'id': 633, 'image_count': 221}, {'id': 634, 'image_count': 10}, {'id': 635, 'image_count': 7}, {'id': 636, 'image_count': 76}, {'id': 637, 'image_count': 22}, {'id': 638, 'image_count': 10}, {'id': 639, 'image_count': 341}, {'id': 640, 'image_count': 1}, {'id': 641, 'image_count': 705}, {'id': 642, 'image_count': 1900}, {'id': 643, 'image_count': 188}, {'id': 644, 'image_count': 227}, {'id': 645, 'image_count': 861}, {'id': 646, 'image_count': 6}, {'id': 647, 'image_count': 115}, {'id': 648, 'image_count': 5}, {'id': 649, 'image_count': 43}, {'id': 650, 'image_count': 14}, {'id': 651, 'image_count': 6}, {'id': 652, 'image_count': 15}, {'id': 653, 'image_count': 1167}, {'id': 654, 'image_count': 15}, {'id': 655, 'image_count': 994}, {'id': 656, 'image_count': 28}, {'id': 657, 'image_count': 2}, {'id': 658, 'image_count': 338}, {'id': 659, 'image_count': 334}, {'id': 660, 'image_count': 15}, {'id': 661, 'image_count': 102}, {'id': 662, 'image_count': 1}, {'id': 663, 'image_count': 8}, {'id': 664, 'image_count': 1}, {'id': 665, 'image_count': 1}, {'id': 666, 'image_count': 28}, {'id': 667, 'image_count': 91}, {'id': 668, 'image_count': 260}, {'id': 669, 'image_count': 131}, {'id': 670, 'image_count': 128}, {'id': 671, 'image_count': 3}, {'id': 672, 'image_count': 10}, {'id': 673, 'image_count': 39}, {'id': 674, 'image_count': 2}, {'id': 675, 'image_count': 925}, {'id': 676, 'image_count': 354}, {'id': 677, 'image_count': 31}, {'id': 678, 'image_count': 10}, {'id': 679, 'image_count': 215}, {'id': 680, 'image_count': 71}, {'id': 681, 'image_count': 43}, {'id': 682, 'image_count': 28}, {'id': 683, 'image_count': 34}, {'id': 684, 'image_count': 16}, {'id': 685, 'image_count': 273}, {'id': 686, 'image_count': 2}, {'id': 687, 'image_count': 999}, {'id': 688, 'image_count': 4}, {'id': 689, 'image_count': 107}, {'id': 690, 'image_count': 2}, {'id': 691, 'image_count': 1}, {'id': 692, 'image_count': 454}, {'id': 693, 'image_count': 9}, {'id': 694, 'image_count': 1901}, {'id': 695, 'image_count': 61}, {'id': 696, 'image_count': 91}, {'id': 697, 'image_count': 46}, {'id': 698, 'image_count': 1402}, {'id': 699, 'image_count': 74}, {'id': 700, 'image_count': 421}, {'id': 701, 'image_count': 226}, {'id': 702, 'image_count': 10}, {'id': 703, 'image_count': 1720}, {'id': 704, 'image_count': 261}, {'id': 705, 'image_count': 1337}, {'id': 706, 'image_count': 293}, {'id': 707, 'image_count': 62}, {'id': 708, 'image_count': 814}, {'id': 709, 'image_count': 407}, {'id': 710, 'image_count': 6}, {'id': 711, 'image_count': 16}, {'id': 712, 'image_count': 7}, {'id': 713, 'image_count': 1791}, {'id': 714, 'image_count': 2}, {'id': 715, 'image_count': 1915}, {'id': 716, 'image_count': 1940}, {'id': 717, 'image_count': 13}, {'id': 718, 'image_count': 16}, {'id': 719, 'image_count': 448}, {'id': 720, 'image_count': 12}, {'id': 721, 'image_count': 18}, {'id': 722, 'image_count': 4}, {'id': 723, 'image_count': 71}, {'id': 724, 'image_count': 189}, {'id': 725, 'image_count': 74}, {'id': 726, 'image_count': 103}, {'id': 727, 'image_count': 3}, {'id': 728, 'image_count': 110}, {'id': 729, 'image_count': 5}, {'id': 730, 'image_count': 9}, {'id': 731, 'image_count': 15}, {'id': 732, 'image_count': 25}, {'id': 733, 'image_count': 7}, {'id': 734, 'image_count': 647}, {'id': 735, 'image_count': 824}, {'id': 736, 'image_count': 100}, {'id': 737, 'image_count': 47}, {'id': 738, 'image_count': 121}, {'id': 739, 'image_count': 731}, {'id': 740, 'image_count': 73}, {'id': 741, 'image_count': 49}, {'id': 742, 'image_count': 23}, {'id': 743, 'image_count': 4}, {'id': 744, 'image_count': 62}, {'id': 745, 'image_count': 118}, {'id': 746, 'image_count': 99}, {'id': 747, 'image_count': 40}, {'id': 748, 'image_count': 1036}, {'id': 749, 'image_count': 105}, {'id': 750, 'image_count': 21}, {'id': 751, 'image_count': 229}, {'id': 752, 'image_count': 7}, {'id': 753, 'image_count': 72}, {'id': 754, 'image_count': 9}, {'id': 755, 'image_count': 10}, {'id': 756, 'image_count': 328}, {'id': 757, 'image_count': 468}, {'id': 758, 'image_count': 1}, {'id': 759, 'image_count': 2}, {'id': 760, 'image_count': 24}, {'id': 761, 'image_count': 11}, {'id': 762, 'image_count': 72}, {'id': 763, 'image_count': 17}, {'id': 764, 'image_count': 10}, {'id': 765, 'image_count': 17}, {'id': 766, 'image_count': 489}, {'id': 767, 'image_count': 47}, {'id': 768, 'image_count': 93}, {'id': 769, 'image_count': 1}, {'id': 770, 'image_count': 12}, {'id': 771, 'image_count': 228}, {'id': 772, 'image_count': 5}, {'id': 773, 'image_count': 76}, {'id': 774, 'image_count': 71}, {'id': 775, 'image_count': 30}, {'id': 776, 'image_count': 109}, {'id': 777, 'image_count': 14}, {'id': 778, 'image_count': 1}, {'id': 779, 'image_count': 8}, {'id': 780, 'image_count': 26}, {'id': 781, 'image_count': 339}, {'id': 782, 'image_count': 153}, {'id': 783, 'image_count': 2}, {'id': 784, 'image_count': 3}, {'id': 785, 'image_count': 8}, {'id': 786, 'image_count': 47}, {'id': 787, 'image_count': 8}, {'id': 788, 'image_count': 6}, {'id': 789, 'image_count': 116}, {'id': 790, 'image_count': 69}, {'id': 791, 'image_count': 13}, {'id': 792, 'image_count': 6}, {'id': 793, 'image_count': 1928}, {'id': 794, 'image_count': 79}, {'id': 795, 'image_count': 14}, {'id': 796, 'image_count': 7}, {'id': 797, 'image_count': 20}, {'id': 798, 'image_count': 114}, {'id': 799, 'image_count': 221}, {'id': 800, 'image_count': 502}, {'id': 801, 'image_count': 62}, {'id': 802, 'image_count': 87}, {'id': 803, 'image_count': 4}, {'id': 804, 'image_count': 1912}, {'id': 805, 'image_count': 7}, {'id': 806, 'image_count': 186}, {'id': 807, 'image_count': 18}, {'id': 808, 'image_count': 4}, {'id': 809, 'image_count': 3}, {'id': 810, 'image_count': 7}, {'id': 811, 'image_count': 1413}, {'id': 812, 'image_count': 7}, {'id': 813, 'image_count': 12}, {'id': 814, 'image_count': 248}, {'id': 815, 'image_count': 4}, {'id': 816, 'image_count': 1881}, {'id': 817, 'image_count': 529}, {'id': 818, 'image_count': 1932}, {'id': 819, 'image_count': 50}, {'id': 820, 'image_count': 3}, {'id': 821, 'image_count': 28}, {'id': 822, 'image_count': 10}, {'id': 823, 'image_count': 5}, {'id': 824, 'image_count': 5}, {'id': 825, 'image_count': 18}, {'id': 826, 'image_count': 14}, {'id': 827, 'image_count': 1890}, {'id': 828, 'image_count': 660}, {'id': 829, 'image_count': 8}, {'id': 830, 'image_count': 25}, {'id': 831, 'image_count': 10}, {'id': 832, 'image_count': 218}, {'id': 833, 'image_count': 36}, {'id': 834, 'image_count': 16}, {'id': 835, 'image_count': 808}, {'id': 836, 'image_count': 479}, {'id': 837, 'image_count': 1404}, {'id': 838, 'image_count': 307}, {'id': 839, 'image_count': 57}, {'id': 840, 'image_count': 28}, {'id': 841, 'image_count': 80}, {'id': 842, 'image_count': 11}, {'id': 843, 'image_count': 92}, {'id': 844, 'image_count': 20}, {'id': 845, 'image_count': 194}, {'id': 846, 'image_count': 23}, {'id': 847, 'image_count': 52}, {'id': 848, 'image_count': 673}, {'id': 849, 'image_count': 2}, {'id': 850, 'image_count': 2}, {'id': 851, 'image_count': 1}, {'id': 852, 'image_count': 2}, {'id': 853, 'image_count': 8}, {'id': 854, 'image_count': 80}, {'id': 855, 'image_count': 3}, {'id': 856, 'image_count': 3}, {'id': 857, 'image_count': 15}, {'id': 858, 'image_count': 2}, {'id': 859, 'image_count': 10}, {'id': 860, 'image_count': 386}, {'id': 861, 'image_count': 65}, {'id': 862, 'image_count': 3}, {'id': 863, 'image_count': 35}, {'id': 864, 'image_count': 5}, {'id': 865, 'image_count': 180}, {'id': 866, 'image_count': 99}, {'id': 867, 'image_count': 49}, {'id': 868, 'image_count': 28}, {'id': 869, 'image_count': 1}, {'id': 870, 'image_count': 52}, {'id': 871, 'image_count': 36}, {'id': 872, 'image_count': 70}, {'id': 873, 'image_count': 6}, {'id': 874, 'image_count': 29}, {'id': 875, 'image_count': 24}, {'id': 876, 'image_count': 1115}, {'id': 877, 'image_count': 61}, {'id': 878, 'image_count': 18}, {'id': 879, 'image_count': 18}, {'id': 880, 'image_count': 665}, {'id': 881, 'image_count': 1096}, {'id': 882, 'image_count': 29}, {'id': 883, 'image_count': 8}, {'id': 884, 'image_count': 14}, {'id': 885, 'image_count': 1622}, {'id': 886, 'image_count': 2}, {'id': 887, 'image_count': 3}, {'id': 888, 'image_count': 32}, {'id': 889, 'image_count': 55}, {'id': 890, 'image_count': 1}, {'id': 891, 'image_count': 10}, {'id': 892, 'image_count': 10}, {'id': 893, 'image_count': 47}, {'id': 894, 'image_count': 3}, {'id': 895, 'image_count': 29}, {'id': 896, 'image_count': 342}, {'id': 897, 'image_count': 25}, {'id': 898, 'image_count': 1469}, {'id': 899, 'image_count': 521}, {'id': 900, 'image_count': 347}, {'id': 901, 'image_count': 35}, {'id': 902, 'image_count': 7}, {'id': 903, 'image_count': 207}, {'id': 904, 'image_count': 108}, {'id': 905, 'image_count': 2}, {'id': 906, 'image_count': 34}, {'id': 907, 'image_count': 12}, {'id': 908, 'image_count': 10}, {'id': 909, 'image_count': 13}, {'id': 910, 'image_count': 361}, {'id': 911, 'image_count': 1023}, {'id': 912, 'image_count': 782}, {'id': 913, 'image_count': 2}, {'id': 914, 'image_count': 5}, {'id': 915, 'image_count': 247}, {'id': 916, 'image_count': 221}, {'id': 917, 'image_count': 4}, {'id': 918, 'image_count': 8}, {'id': 919, 'image_count': 158}, {'id': 920, 'image_count': 3}, {'id': 921, 'image_count': 752}, {'id': 922, 'image_count': 64}, {'id': 923, 'image_count': 707}, {'id': 924, 'image_count': 143}, {'id': 925, 'image_count': 1}, {'id': 926, 'image_count': 49}, {'id': 927, 'image_count': 126}, {'id': 928, 'image_count': 76}, {'id': 929, 'image_count': 11}, {'id': 930, 'image_count': 11}, {'id': 931, 'image_count': 4}, {'id': 932, 'image_count': 39}, {'id': 933, 'image_count': 11}, {'id': 934, 'image_count': 13}, {'id': 935, 'image_count': 91}, {'id': 936, 'image_count': 14}, {'id': 937, 'image_count': 5}, {'id': 938, 'image_count': 3}, {'id': 939, 'image_count': 10}, {'id': 940, 'image_count': 18}, {'id': 941, 'image_count': 9}, {'id': 942, 'image_count': 6}, {'id': 943, 'image_count': 951}, {'id': 944, 'image_count': 2}, {'id': 945, 'image_count': 1}, {'id': 946, 'image_count': 19}, {'id': 947, 'image_count': 1942}, {'id': 948, 'image_count': 1916}, {'id': 949, 'image_count': 139}, {'id': 950, 'image_count': 43}, {'id': 951, 'image_count': 1969}, {'id': 952, 'image_count': 5}, {'id': 953, 'image_count': 134}, {'id': 954, 'image_count': 74}, {'id': 955, 'image_count': 381}, {'id': 956, 'image_count': 1}, {'id': 957, 'image_count': 381}, {'id': 958, 'image_count': 6}, {'id': 959, 'image_count': 1826}, {'id': 960, 'image_count': 28}, {'id': 961, 'image_count': 1635}, {'id': 962, 'image_count': 1967}, {'id': 963, 'image_count': 16}, {'id': 964, 'image_count': 1926}, {'id': 965, 'image_count': 1789}, {'id': 966, 'image_count': 401}, {'id': 967, 'image_count': 1968}, {'id': 968, 'image_count': 1167}, {'id': 969, 'image_count': 1}, {'id': 970, 'image_count': 56}, {'id': 971, 'image_count': 17}, {'id': 972, 'image_count': 1}, {'id': 973, 'image_count': 58}, {'id': 974, 'image_count': 9}, {'id': 975, 'image_count': 8}, {'id': 976, 'image_count': 1124}, {'id': 977, 'image_count': 31}, {'id': 978, 'image_count': 16}, {'id': 979, 'image_count': 491}, {'id': 980, 'image_count': 432}, {'id': 981, 'image_count': 1945}, {'id': 982, 'image_count': 1899}, {'id': 983, 'image_count': 5}, {'id': 984, 'image_count': 28}, {'id': 985, 'image_count': 7}, {'id': 986, 'image_count': 146}, {'id': 987, 'image_count': 1}, {'id': 988, 'image_count': 25}, {'id': 989, 'image_count': 22}, {'id': 990, 'image_count': 1}, {'id': 991, 'image_count': 10}, {'id': 992, 'image_count': 9}, {'id': 993, 'image_count': 308}, {'id': 994, 'image_count': 4}, {'id': 995, 'image_count': 1969}, {'id': 996, 'image_count': 45}, {'id': 997, 'image_count': 12}, {'id': 998, 'image_count': 1}, {'id': 999, 'image_count': 85}, {'id': 1000, 'image_count': 1127}, {'id': 1001, 'image_count': 11}, {'id': 1002, 'image_count': 60}, {'id': 1003, 'image_count': 1}, {'id': 1004, 'image_count': 16}, {'id': 1005, 'image_count': 1}, {'id': 1006, 'image_count': 65}, {'id': 1007, 'image_count': 13}, {'id': 1008, 'image_count': 655}, {'id': 1009, 'image_count': 51}, {'id': 1010, 'image_count': 1}, {'id': 1011, 'image_count': 673}, {'id': 1012, 'image_count': 5}, {'id': 1013, 'image_count': 36}, {'id': 1014, 'image_count': 54}, {'id': 1015, 'image_count': 5}, {'id': 1016, 'image_count': 8}, {'id': 1017, 'image_count': 305}, {'id': 1018, 'image_count': 297}, {'id': 1019, 'image_count': 1053}, {'id': 1020, 'image_count': 223}, {'id': 1021, 'image_count': 1037}, {'id': 1022, 'image_count': 63}, {'id': 1023, 'image_count': 1881}, {'id': 1024, 'image_count': 507}, {'id': 1025, 'image_count': 333}, {'id': 1026, 'image_count': 1911}, {'id': 1027, 'image_count': 1765}, {'id': 1028, 'image_count': 1}, {'id': 1029, 'image_count': 5}, {'id': 1030, 'image_count': 1}, {'id': 1031, 'image_count': 9}, {'id': 1032, 'image_count': 2}, {'id': 1033, 'image_count': 151}, {'id': 1034, 'image_count': 82}, {'id': 1035, 'image_count': 1931}, {'id': 1036, 'image_count': 41}, {'id': 1037, 'image_count': 1895}, {'id': 1038, 'image_count': 24}, {'id': 1039, 'image_count': 22}, {'id': 1040, 'image_count': 35}, {'id': 1041, 'image_count': 69}, {'id': 1042, 'image_count': 962}, {'id': 1043, 'image_count': 588}, {'id': 1044, 'image_count': 21}, {'id': 1045, 'image_count': 825}, {'id': 1046, 'image_count': 52}, {'id': 1047, 'image_count': 5}, {'id': 1048, 'image_count': 5}, {'id': 1049, 'image_count': 5}, {'id': 1050, 'image_count': 1860}, {'id': 1051, 'image_count': 56}, {'id': 1052, 'image_count': 1582}, {'id': 1053, 'image_count': 7}, {'id': 1054, 'image_count': 2}, {'id': 1055, 'image_count': 1562}, {'id': 1056, 'image_count': 1885}, {'id': 1057, 'image_count': 1}, {'id': 1058, 'image_count': 5}, {'id': 1059, 'image_count': 137}, {'id': 1060, 'image_count': 1094}, {'id': 1061, 'image_count': 134}, {'id': 1062, 'image_count': 29}, {'id': 1063, 'image_count': 22}, {'id': 1064, 'image_count': 522}, {'id': 1065, 'image_count': 50}, {'id': 1066, 'image_count': 68}, {'id': 1067, 'image_count': 16}, {'id': 1068, 'image_count': 40}, {'id': 1069, 'image_count': 35}, {'id': 1070, 'image_count': 135}, {'id': 1071, 'image_count': 1413}, {'id': 1072, 'image_count': 772}, {'id': 1073, 'image_count': 50}, {'id': 1074, 'image_count': 1015}, {'id': 1075, 'image_count': 1}, {'id': 1076, 'image_count': 65}, {'id': 1077, 'image_count': 1900}, {'id': 1078, 'image_count': 1302}, {'id': 1079, 'image_count': 1977}, {'id': 1080, 'image_count': 2}, {'id': 1081, 'image_count': 29}, {'id': 1082, 'image_count': 36}, {'id': 1083, 'image_count': 138}, {'id': 1084, 'image_count': 4}, {'id': 1085, 'image_count': 67}, {'id': 1086, 'image_count': 26}, {'id': 1087, 'image_count': 25}, {'id': 1088, 'image_count': 33}, {'id': 1089, 'image_count': 37}, {'id': 1090, 'image_count': 50}, {'id': 1091, 'image_count': 270}, {'id': 1092, 'image_count': 12}, {'id': 1093, 'image_count': 316}, {'id': 1094, 'image_count': 41}, {'id': 1095, 'image_count': 224}, {'id': 1096, 'image_count': 105}, {'id': 1097, 'image_count': 1925}, {'id': 1098, 'image_count': 1021}, {'id': 1099, 'image_count': 1213}, {'id': 1100, 'image_count': 172}, {'id': 1101, 'image_count': 28}, {'id': 1102, 'image_count': 745}, {'id': 1103, 'image_count': 187}, {'id': 1104, 'image_count': 147}, {'id': 1105, 'image_count': 136}, {'id': 1106, 'image_count': 34}, {'id': 1107, 'image_count': 41}, {'id': 1108, 'image_count': 636}, {'id': 1109, 'image_count': 570}, {'id': 1110, 'image_count': 1149}, {'id': 1111, 'image_count': 61}, {'id': 1112, 'image_count': 1890}, {'id': 1113, 'image_count': 18}, {'id': 1114, 'image_count': 143}, {'id': 1115, 'image_count': 1517}, {'id': 1116, 'image_count': 7}, {'id': 1117, 'image_count': 943}, {'id': 1118, 'image_count': 6}, {'id': 1119, 'image_count': 1}, {'id': 1120, 'image_count': 11}, {'id': 1121, 'image_count': 101}, {'id': 1122, 'image_count': 1909}, {'id': 1123, 'image_count': 800}, {'id': 1124, 'image_count': 1}, {'id': 1125, 'image_count': 44}, {'id': 1126, 'image_count': 3}, {'id': 1127, 'image_count': 44}, {'id': 1128, 'image_count': 31}, {'id': 1129, 'image_count': 7}, {'id': 1130, 'image_count': 20}, {'id': 1131, 'image_count': 11}, {'id': 1132, 'image_count': 13}, {'id': 1133, 'image_count': 1924}, {'id': 1134, 'image_count': 113}, {'id': 1135, 'image_count': 2}, {'id': 1136, 'image_count': 139}, {'id': 1137, 'image_count': 12}, {'id': 1138, 'image_count': 37}, {'id': 1139, 'image_count': 1866}, {'id': 1140, 'image_count': 47}, {'id': 1141, 'image_count': 1468}, {'id': 1142, 'image_count': 729}, {'id': 1143, 'image_count': 24}, {'id': 1144, 'image_count': 1}, {'id': 1145, 'image_count': 10}, {'id': 1146, 'image_count': 3}, {'id': 1147, 'image_count': 14}, {'id': 1148, 'image_count': 4}, {'id': 1149, 'image_count': 29}, {'id': 1150, 'image_count': 4}, {'id': 1151, 'image_count': 70}, {'id': 1152, 'image_count': 46}, {'id': 1153, 'image_count': 14}, {'id': 1154, 'image_count': 48}, {'id': 1155, 'image_count': 1855}, {'id': 1156, 'image_count': 113}, {'id': 1157, 'image_count': 1}, {'id': 1158, 'image_count': 1}, {'id': 1159, 'image_count': 10}, {'id': 1160, 'image_count': 54}, {'id': 1161, 'image_count': 1923}, {'id': 1162, 'image_count': 630}, {'id': 1163, 'image_count': 31}, {'id': 1164, 'image_count': 69}, {'id': 1165, 'image_count': 7}, {'id': 1166, 'image_count': 11}, {'id': 1167, 'image_count': 1}, {'id': 1168, 'image_count': 30}, {'id': 1169, 'image_count': 50}, {'id': 1170, 'image_count': 45}, {'id': 1171, 'image_count': 28}, {'id': 1172, 'image_count': 114}, {'id': 1173, 'image_count': 193}, {'id': 1174, 'image_count': 21}, {'id': 1175, 'image_count': 91}, {'id': 1176, 'image_count': 31}, {'id': 1177, 'image_count': 1469}, {'id': 1178, 'image_count': 1924}, {'id': 1179, 'image_count': 87}, {'id': 1180, 'image_count': 77}, {'id': 1181, 'image_count': 11}, {'id': 1182, 'image_count': 47}, {'id': 1183, 'image_count': 21}, {'id': 1184, 'image_count': 47}, {'id': 1185, 'image_count': 70}, {'id': 1186, 'image_count': 1838}, {'id': 1187, 'image_count': 19}, {'id': 1188, 'image_count': 531}, {'id': 1189, 'image_count': 11}, {'id': 1190, 'image_count': 941}, {'id': 1191, 'image_count': 113}, {'id': 1192, 'image_count': 26}, {'id': 1193, 'image_count': 5}, {'id': 1194, 'image_count': 56}, {'id': 1195, 'image_count': 73}, {'id': 1196, 'image_count': 32}, {'id': 1197, 'image_count': 128}, {'id': 1198, 'image_count': 623}, {'id': 1199, 'image_count': 12}, {'id': 1200, 'image_count': 52}, {'id': 1201, 'image_count': 11}, {'id': 1202, 'image_count': 1674}, {'id': 1203, 'image_count': 81}] # noqa
# fmt: on
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.
import numpy as np
import os
import xml.etree.ElementTree as ET
from typing import List, Tuple, Union
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.structures import BoxMode
from detectron2.utils.file_io import PathManager
__all__ = ["load_voc_instances", "register_pascal_voc"]
# fmt: off
CLASS_NAMES = (
"aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat",
"chair", "cow", "diningtable", "dog", "horse", "motorbike", "person",
"pottedplant", "sheep", "sofa", "train", "tvmonitor"
)
# fmt: on
def load_voc_instances(dirname: str, split: str, class_names: Union[List[str], Tuple[str, ...]]):
"""
Load Pascal VOC detection annotations to Detectron2 format.
Args:
dirname: Contain "Annotations", "ImageSets", "JPEGImages"
split (str): one of "train", "test", "val", "trainval"
class_names: list or tuple of class names
"""
with PathManager.open(os.path.join(dirname, "ImageSets", "Main", split + ".txt")) as f:
fileids = np.loadtxt(f, dtype=str)
# Needs to read many small annotation files. Makes sense at local
annotation_dirname = PathManager.get_local_path(os.path.join(dirname, "Annotations/"))
dicts = []
for fileid in fileids:
anno_file = os.path.join(annotation_dirname, fileid + ".xml")
jpeg_file = os.path.join(dirname, "JPEGImages", fileid + ".jpg")
with PathManager.open(anno_file) as f:
tree = ET.parse(f)
r = {
"file_name": jpeg_file,
"image_id": fileid,
"height": int(tree.findall("./size/height")[0].text),
"width": int(tree.findall("./size/width")[0].text),
}
instances = []
for obj in tree.findall("object"):
cls = obj.find("name").text
# We include "difficult" samples in training.
# Based on limited experiments, they don't hurt accuracy.
# difficult = int(obj.find("difficult").text)
# if difficult == 1:
# continue
bbox = obj.find("bndbox")
bbox = [float(bbox.find(x).text) for x in ["xmin", "ymin", "xmax", "ymax"]]
# Original annotations are integers in the range [1, W or H]
# Assuming they mean 1-based pixel indices (inclusive),
# a box with annotation (xmin=1, xmax=W) covers the whole image.
# In coordinate space this is represented by (xmin=0, xmax=W)
bbox[0] -= 1.0
bbox[1] -= 1.0
instances.append(
{"category_id": class_names.index(cls), "bbox": bbox, "bbox_mode": BoxMode.XYXY_ABS}
)
r["annotations"] = instances
dicts.append(r)
return dicts
def register_pascal_voc(name, dirname, split, year, class_names=CLASS_NAMES):
DatasetCatalog.register(name, lambda: load_voc_instances(dirname, split, class_names))
MetadataCatalog.get(name).set(
thing_classes=list(class_names), dirname=dirname, year=year, split=split
)
# Copyright (c) Facebook, Inc. and its affiliates.
from .coco import register_coco_instances # noqa
from .coco_panoptic import register_coco_panoptic_separated # noqa
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.
"""
Common data processing utilities that are used in a
typical object detection data pipeline.
"""
import logging
import numpy as np
from typing import List, Union
import pycocotools.mask as mask_util
import torch
from PIL import Image
from detectron2.structures import (
BitMasks,
Boxes,
BoxMode,
Instances,
Keypoints,
PolygonMasks,
RotatedBoxes,
polygons_to_bitmask,
)
from detectron2.utils.file_io import PathManager
from . import transforms as T
from .catalog import MetadataCatalog
__all__ = [
"SizeMismatchError",
"convert_image_to_rgb",
"check_image_size",
"transform_proposals",
"transform_instance_annotations",
"annotations_to_instances",
"annotations_to_instances_rotated",
"build_augmentation",
"build_transform_gen",
"create_keypoint_hflip_indices",
"filter_empty_instances",
"read_image",
]
class SizeMismatchError(ValueError):
"""
When loaded image has difference width/height compared with annotation.
"""
# https://en.wikipedia.org/wiki/YUV#SDTV_with_BT.601
_M_RGB2YUV = [[0.299, 0.587, 0.114], [-0.14713, -0.28886, 0.436], [0.615, -0.51499, -0.10001]]
_M_YUV2RGB = [[1.0, 0.0, 1.13983], [1.0, -0.39465, -0.58060], [1.0, 2.03211, 0.0]]
# https://www.exiv2.org/tags.html
_EXIF_ORIENT = 274 # exif 'Orientation' tag
def convert_PIL_to_numpy(image, format):
"""
Convert PIL image to numpy array of target format.
Args:
image (PIL.Image): a PIL image
format (str): the format of output image
Returns:
(np.ndarray): also see `read_image`
"""
if format is not None:
# PIL only supports RGB, so convert to RGB and flip channels over below
conversion_format = format
if format in ["BGR", "YUV-BT.601"]:
conversion_format = "RGB"
image = image.convert(conversion_format)
image = np.asarray(image)
# PIL squeezes out the channel dimension for "L", so make it HWC
if format == "L":
image = np.expand_dims(image, -1)
# handle formats not supported by PIL
elif format == "BGR":
# flip channels if needed
image = image[:, :, ::-1]
elif format == "YUV-BT.601":
image = image / 255.0
image = np.dot(image, np.array(_M_RGB2YUV).T)
return image
def convert_image_to_rgb(image, format):
"""
Convert an image from given format to RGB.
Args:
image (np.ndarray or Tensor): an HWC image
format (str): the format of input image, also see `read_image`
Returns:
(np.ndarray): (H,W,3) RGB image in 0-255 range, can be either float or uint8
"""
if isinstance(image, torch.Tensor):
image = image.cpu().numpy()
if format == "BGR":
image = image[:, :, [2, 1, 0]]
elif format == "YUV-BT.601":
image = np.dot(image, np.array(_M_YUV2RGB).T)
image = image * 255.0
else:
if format == "L":
image = image[:, :, 0]
image = image.astype(np.uint8)
image = np.asarray(Image.fromarray(image, mode=format).convert("RGB"))
return image
def _apply_exif_orientation(image):
"""
Applies the exif orientation correctly.
This code exists per the bug:
https://github.com/python-pillow/Pillow/issues/3973
with the function `ImageOps.exif_transpose`. The Pillow source raises errors with
various methods, especially `tobytes`
Function based on:
https://github.com/wkentaro/labelme/blob/v4.5.4/labelme/utils/image.py#L59
https://github.com/python-pillow/Pillow/blob/7.1.2/src/PIL/ImageOps.py#L527
Args:
image (PIL.Image): a PIL image
Returns:
(PIL.Image): the PIL image with exif orientation applied, if applicable
"""
if not hasattr(image, "getexif"):
return image
try:
exif = image.getexif()
except Exception: # https://github.com/facebookresearch/detectron2/issues/1885
exif = None
if exif is None:
return image
orientation = exif.get(_EXIF_ORIENT)
method = {
2: Image.FLIP_LEFT_RIGHT,
3: Image.ROTATE_180,
4: Image.FLIP_TOP_BOTTOM,
5: Image.TRANSPOSE,
6: Image.ROTATE_270,
7: Image.TRANSVERSE,
8: Image.ROTATE_90,
}.get(orientation)
if method is not None:
return image.transpose(method)
return image
def read_image(file_name, format=None):
"""
Read an image into the given format.
Will apply rotation and flipping if the image has such exif information.
Args:
file_name (str): image file path
format (str): one of the supported image modes in PIL, or "BGR" or "YUV-BT.601".
Returns:
image (np.ndarray):
an HWC image in the given format, which is 0-255, uint8 for
supported image modes in PIL or "BGR"; float (0-1 for Y) for YUV-BT.601.
"""
with PathManager.open(file_name, "rb") as f:
image = Image.open(f)
# work around this bug: https://github.com/python-pillow/Pillow/issues/3973
image = _apply_exif_orientation(image)
return convert_PIL_to_numpy(image, format)
def check_image_size(dataset_dict, image):
"""
Raise an error if the image does not match the size specified in the dict.
"""
if "width" in dataset_dict or "height" in dataset_dict:
image_wh = (image.shape[1], image.shape[0])
expected_wh = (dataset_dict["width"], dataset_dict["height"])
if not image_wh == expected_wh:
raise SizeMismatchError(
"Mismatched image shape{}, got {}, expect {}.".format(
(
" for image " + dataset_dict["file_name"]
if "file_name" in dataset_dict
else ""
),
image_wh,
expected_wh,
)
+ " Please check the width/height in your annotation."
)
# To ensure bbox always remap to original image size
if "width" not in dataset_dict:
dataset_dict["width"] = image.shape[1]
if "height" not in dataset_dict:
dataset_dict["height"] = image.shape[0]
def transform_proposals(dataset_dict, image_shape, transforms, *, proposal_topk, min_box_size=0):
"""
Apply transformations to the proposals in dataset_dict, if any.
Args:
dataset_dict (dict): a dict read from the dataset, possibly
contains fields "proposal_boxes", "proposal_objectness_logits", "proposal_bbox_mode"
image_shape (tuple): height, width
transforms (TransformList):
proposal_topk (int): only keep top-K scoring proposals
min_box_size (int): proposals with either side smaller than this
threshold are removed
The input dict is modified in-place, with abovementioned keys removed. A new
key "proposals" will be added. Its value is an `Instances`
object which contains the transformed proposals in its field
"proposal_boxes" and "objectness_logits".
"""
if "proposal_boxes" in dataset_dict:
# Transform proposal boxes
boxes = transforms.apply_box(
BoxMode.convert(
dataset_dict.pop("proposal_boxes"),
dataset_dict.pop("proposal_bbox_mode"),
BoxMode.XYXY_ABS,
)
)
boxes = Boxes(boxes)
objectness_logits = torch.as_tensor(
dataset_dict.pop("proposal_objectness_logits").astype("float32")
)
boxes.clip(image_shape)
keep = boxes.nonempty(threshold=min_box_size)
boxes = boxes[keep]
objectness_logits = objectness_logits[keep]
proposals = Instances(image_shape)
proposals.proposal_boxes = boxes[:proposal_topk]
proposals.objectness_logits = objectness_logits[:proposal_topk]
dataset_dict["proposals"] = proposals
def get_bbox(annotation):
"""
Get bbox from data
Args:
annotation (dict): dict of instance annotations for a single instance.
Returns:
bbox (ndarray): x1, y1, x2, y2 coordinates
"""
# bbox is 1d (per-instance bounding box)
bbox = BoxMode.convert(annotation["bbox"], annotation["bbox_mode"], BoxMode.XYXY_ABS)
return bbox
def transform_instance_annotations(
annotation, transforms, image_size, *, keypoint_hflip_indices=None
):
"""
Apply transforms to box, segmentation and keypoints annotations of a single instance.
It will use `transforms.apply_box` for the box, and
`transforms.apply_coords` for segmentation polygons & keypoints.
If you need anything more specially designed for each data structure,
you'll need to implement your own version of this function or the transforms.
Args:
annotation (dict): dict of instance annotations for a single instance.
It will be modified in-place.
transforms (TransformList or list[Transform]):
image_size (tuple): the height, width of the transformed image
keypoint_hflip_indices (ndarray[int]): see `create_keypoint_hflip_indices`.
Returns:
dict:
the same input dict with fields "bbox", "segmentation", "keypoints"
transformed according to `transforms`.
The "bbox_mode" field will be set to XYXY_ABS.
"""
if isinstance(transforms, (tuple, list)):
transforms = T.TransformList(transforms)
# bbox is 1d (per-instance bounding box)
bbox = BoxMode.convert(annotation["bbox"], annotation["bbox_mode"], BoxMode.XYXY_ABS)
# clip transformed bbox to image size
bbox = transforms.apply_box(np.array([bbox]))[0].clip(min=0)
annotation["bbox"] = np.minimum(bbox, list(image_size + image_size)[::-1])
annotation["bbox_mode"] = BoxMode.XYXY_ABS
if "segmentation" in annotation:
# each instance contains 1 or more polygons
segm = annotation["segmentation"]
if isinstance(segm, list):
# polygons
polygons = [np.asarray(p).reshape(-1, 2) for p in segm]
annotation["segmentation"] = [
p.reshape(-1) for p in transforms.apply_polygons(polygons)
]
elif isinstance(segm, dict):
# RLE
mask = mask_util.decode(segm)
mask = transforms.apply_segmentation(mask)
assert tuple(mask.shape[:2]) == image_size
annotation["segmentation"] = mask
else:
raise ValueError(
"Cannot transform segmentation of type '{}'!"
"Supported types are: polygons as list[list[float] or ndarray],"
" COCO-style RLE as a dict.".format(type(segm))
)
if "keypoints" in annotation:
keypoints = transform_keypoint_annotations(
annotation["keypoints"], transforms, image_size, keypoint_hflip_indices
)
annotation["keypoints"] = keypoints
return annotation
def transform_keypoint_annotations(keypoints, transforms, image_size, keypoint_hflip_indices=None):
"""
Transform keypoint annotations of an image.
If a keypoint is transformed out of image boundary, it will be marked "unlabeled" (visibility=0)
Args:
keypoints (list[float]): Nx3 float in Detectron2's Dataset format.
Each point is represented by (x, y, visibility).
transforms (TransformList):
image_size (tuple): the height, width of the transformed image
keypoint_hflip_indices (ndarray[int]): see `create_keypoint_hflip_indices`.
When `transforms` includes horizontal flip, will use the index
mapping to flip keypoints.
"""
# (N*3,) -> (N, 3)
keypoints = np.asarray(keypoints, dtype="float64").reshape(-1, 3)
keypoints_xy = transforms.apply_coords(keypoints[:, :2])
# Set all out-of-boundary points to "unlabeled"
inside = (keypoints_xy >= np.array([0, 0])) & (keypoints_xy <= np.array(image_size[::-1]))
inside = inside.all(axis=1)
keypoints[:, :2] = keypoints_xy
keypoints[:, 2][~inside] = 0
# This assumes that HorizFlipTransform is the only one that does flip
do_hflip = sum(isinstance(t, T.HFlipTransform) for t in transforms.transforms) % 2 == 1
# Alternative way: check if probe points was horizontally flipped.
# probe = np.asarray([[0.0, 0.0], [image_width, 0.0]])
# probe_aug = transforms.apply_coords(probe.copy())
# do_hflip = np.sign(probe[1][0] - probe[0][0]) != np.sign(probe_aug[1][0] - probe_aug[0][0]) # noqa
# If flipped, swap each keypoint with its opposite-handed equivalent
if do_hflip:
if keypoint_hflip_indices is None:
raise ValueError("Cannot flip keypoints without providing flip indices!")
if len(keypoints) != len(keypoint_hflip_indices):
raise ValueError(
"Keypoint data has {} points, but metadata "
"contains {} points!".format(len(keypoints), len(keypoint_hflip_indices))
)
keypoints = keypoints[np.asarray(keypoint_hflip_indices, dtype=np.int32), :]
# Maintain COCO convention that if visibility == 0 (unlabeled), then x, y = 0
keypoints[keypoints[:, 2] == 0] = 0
return keypoints
def annotations_to_instances(annos, image_size, mask_format="polygon"):
"""
Create an :class:`Instances` object used by the models,
from instance annotations in the dataset dict.
Args:
annos (list[dict]): a list of instance annotations in one image, each
element for one instance.
image_size (tuple): height, width
Returns:
Instances:
It will contain fields "gt_boxes", "gt_classes",
"gt_masks", "gt_keypoints", if they can be obtained from `annos`.
This is the format that builtin models expect.
"""
boxes = (
np.stack(
[BoxMode.convert(obj["bbox"], obj["bbox_mode"], BoxMode.XYXY_ABS) for obj in annos]
)
if len(annos)
else np.zeros((0, 4))
)
target = Instances(image_size)
target.gt_boxes = Boxes(boxes)
classes = [int(obj["category_id"]) for obj in annos]
classes = torch.tensor(classes, dtype=torch.int64)
target.gt_classes = classes
if len(annos) and "segmentation" in annos[0]:
segms = [obj["segmentation"] for obj in annos]
if mask_format == "polygon":
try:
masks = PolygonMasks(segms)
except ValueError as e:
raise ValueError(
"Failed to use mask_format=='polygon' from the given annotations!"
) from e
else:
assert mask_format == "bitmask", mask_format
masks = []
for segm in segms:
if isinstance(segm, list):
# polygon
masks.append(polygons_to_bitmask(segm, *image_size))
elif isinstance(segm, dict):
# COCO RLE
masks.append(mask_util.decode(segm))
elif isinstance(segm, np.ndarray):
assert segm.ndim == 2, "Expect segmentation of 2 dimensions, got {}.".format(
segm.ndim
)
# mask array
masks.append(segm)
else:
raise ValueError(
"Cannot convert segmentation of type '{}' to BitMasks!"
"Supported types are: polygons as list[list[float] or ndarray],"
" COCO-style RLE as a dict, or a binary segmentation mask "
" in a 2D numpy array of shape HxW.".format(type(segm))
)
# torch.from_numpy does not support array with negative stride.
masks = BitMasks(
torch.stack([torch.from_numpy(np.ascontiguousarray(x)) for x in masks])
)
target.gt_masks = masks
if len(annos) and "keypoints" in annos[0]:
kpts = [obj.get("keypoints", []) for obj in annos]
target.gt_keypoints = Keypoints(kpts)
return target
def annotations_to_instances_rotated(annos, image_size):
"""
Create an :class:`Instances` object used by the models,
from instance annotations in the dataset dict.
Compared to `annotations_to_instances`, this function is for rotated boxes only
Args:
annos (list[dict]): a list of instance annotations in one image, each
element for one instance.
image_size (tuple): height, width
Returns:
Instances:
Containing fields "gt_boxes", "gt_classes",
if they can be obtained from `annos`.
This is the format that builtin models expect.
"""
boxes = [obj["bbox"] for obj in annos]
target = Instances(image_size)
boxes = target.gt_boxes = RotatedBoxes(boxes)
boxes.clip(image_size)
classes = [obj["category_id"] for obj in annos]
classes = torch.tensor(classes, dtype=torch.int64)
target.gt_classes = classes
return target
def filter_empty_instances(
instances, by_box=True, by_mask=True, box_threshold=1e-5, return_mask=False
):
"""
Filter out empty instances in an `Instances` object.
Args:
instances (Instances):
by_box (bool): whether to filter out instances with empty boxes
by_mask (bool): whether to filter out instances with empty masks
box_threshold (float): minimum width and height to be considered non-empty
return_mask (bool): whether to return boolean mask of filtered instances
Returns:
Instances: the filtered instances.
tensor[bool], optional: boolean mask of filtered instances
"""
assert by_box or by_mask
r = []
if by_box:
r.append(instances.gt_boxes.nonempty(threshold=box_threshold))
if instances.has("gt_masks") and by_mask:
r.append(instances.gt_masks.nonempty())
# TODO: can also filter visible keypoints
if not r:
return instances
m = r[0]
for x in r[1:]:
m = m & x
if return_mask:
return instances[m], m
return instances[m]
def create_keypoint_hflip_indices(dataset_names: Union[str, List[str]]) -> List[int]:
"""
Args:
dataset_names: list of dataset names
Returns:
list[int]: a list of size=#keypoints, storing the
horizontally-flipped keypoint indices.
"""
if isinstance(dataset_names, str):
dataset_names = [dataset_names]
check_metadata_consistency("keypoint_names", dataset_names)
check_metadata_consistency("keypoint_flip_map", dataset_names)
meta = MetadataCatalog.get(dataset_names[0])
names = meta.keypoint_names
# TODO flip -> hflip
flip_map = dict(meta.keypoint_flip_map)
flip_map.update({v: k for k, v in flip_map.items()})
flipped_names = [i if i not in flip_map else flip_map[i] for i in names]
flip_indices = [names.index(i) for i in flipped_names]
return flip_indices
def get_fed_loss_cls_weights(dataset_names: Union[str, List[str]], freq_weight_power=1.0):
"""
Get frequency weight for each class sorted by class id.
We now calcualte freqency weight using image_count to the power freq_weight_power.
Args:
dataset_names: list of dataset names
freq_weight_power: power value
"""
if isinstance(dataset_names, str):
dataset_names = [dataset_names]
check_metadata_consistency("class_image_count", dataset_names)
meta = MetadataCatalog.get(dataset_names[0])
class_freq_meta = meta.class_image_count
class_freq = torch.tensor(
[c["image_count"] for c in sorted(class_freq_meta, key=lambda x: x["id"])]
)
class_freq_weight = class_freq.float() ** freq_weight_power
return class_freq_weight
def gen_crop_transform_with_instance(crop_size, image_size, instance):
"""
Generate a CropTransform so that the cropping region contains
the center of the given instance.
Args:
crop_size (tuple): h, w in pixels
image_size (tuple): h, w
instance (dict): an annotation dict of one instance, in Detectron2's
dataset format.
"""
crop_size = np.asarray(crop_size, dtype=np.int32)
bbox = BoxMode.convert(instance["bbox"], instance["bbox_mode"], BoxMode.XYXY_ABS)
center_yx = (bbox[1] + bbox[3]) * 0.5, (bbox[0] + bbox[2]) * 0.5
assert (
image_size[0] >= center_yx[0] and image_size[1] >= center_yx[1]
), "The annotation bounding box is outside of the image!"
assert (
image_size[0] >= crop_size[0] and image_size[1] >= crop_size[1]
), "Crop size is larger than image size!"
min_yx = np.maximum(np.floor(center_yx).astype(np.int32) - crop_size, 0)
max_yx = np.maximum(np.asarray(image_size, dtype=np.int32) - crop_size, 0)
max_yx = np.minimum(max_yx, np.ceil(center_yx).astype(np.int32))
y0 = np.random.randint(min_yx[0], max_yx[0] + 1)
x0 = np.random.randint(min_yx[1], max_yx[1] + 1)
return T.CropTransform(x0, y0, crop_size[1], crop_size[0])
def check_metadata_consistency(key, dataset_names):
"""
Check that the datasets have consistent metadata.
Args:
key (str): a metadata key
dataset_names (list[str]): a list of dataset names
Raises:
AttributeError: if the key does not exist in the metadata
ValueError: if the given datasets do not have the same metadata values defined by key
"""
if len(dataset_names) == 0:
return
logger = logging.getLogger(__name__)
entries_per_dataset = [getattr(MetadataCatalog.get(d), key) for d in dataset_names]
for idx, entry in enumerate(entries_per_dataset):
if entry != entries_per_dataset[0]:
logger.error(
"Metadata '{}' for dataset '{}' is '{}'".format(key, dataset_names[idx], str(entry))
)
logger.error(
"Metadata '{}' for dataset '{}' is '{}'".format(
key, dataset_names[0], str(entries_per_dataset[0])
)
)
raise ValueError("Datasets have different metadata '{}'!".format(key))
def build_augmentation(cfg, is_train):
"""
Create a list of default :class:`Augmentation` from config.
Now it includes resizing and flipping.
Returns:
list[Augmentation]
"""
if is_train:
min_size = cfg.INPUT.MIN_SIZE_TRAIN
max_size = cfg.INPUT.MAX_SIZE_TRAIN
sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING
else:
min_size = cfg.INPUT.MIN_SIZE_TEST
max_size = cfg.INPUT.MAX_SIZE_TEST
sample_style = "choice"
augmentation = [T.ResizeShortestEdge(min_size, max_size, sample_style)]
if is_train and cfg.INPUT.RANDOM_FLIP != "none":
augmentation.append(
T.RandomFlip(
horizontal=cfg.INPUT.RANDOM_FLIP == "horizontal",
vertical=cfg.INPUT.RANDOM_FLIP == "vertical",
)
)
return augmentation
build_transform_gen = build_augmentation
"""
Alias for backward-compatibility.
"""
# Copyright (c) Facebook, Inc. and its affiliates.
from .distributed_sampler import (
InferenceSampler,
RandomSubsetTrainingSampler,
RepeatFactorTrainingSampler,
TrainingSampler,
)
from .grouped_batch_sampler import GroupedBatchSampler
__all__ = [
"GroupedBatchSampler",
"TrainingSampler",
"RandomSubsetTrainingSampler",
"InferenceSampler",
"RepeatFactorTrainingSampler",
]
# Copyright (c) Facebook, Inc. and its affiliates.
import itertools
import logging
import math
from collections import defaultdict
from typing import Optional
import torch
from torch.utils.data.sampler import Sampler
from detectron2.utils import comm
logger = logging.getLogger(__name__)
class TrainingSampler(Sampler):
"""
In training, we only care about the "infinite stream" of training data.
So this sampler produces an infinite stream of indices and
all workers cooperate to correctly shuffle the indices and sample different indices.
The samplers in each worker effectively produces `indices[worker_id::num_workers]`
where `indices` is an infinite stream of indices consisting of
`shuffle(range(size)) + shuffle(range(size)) + ...` (if shuffle is True)
or `range(size) + range(size) + ...` (if shuffle is False)
Note that this sampler does not shard based on pytorch DataLoader worker id.
A sampler passed to pytorch DataLoader is used only with map-style dataset
and will not be executed inside workers.
But if this sampler is used in a way that it gets execute inside a dataloader
worker, then extra work needs to be done to shard its outputs based on worker id.
This is required so that workers don't produce identical data.
:class:`ToIterableDataset` implements this logic.
This note is true for all samplers in detectron2.
"""
def __init__(self, size: int, shuffle: bool = True, seed: Optional[int] = None):
"""
Args:
size (int): the total number of data of the underlying dataset to sample from
shuffle (bool): whether to shuffle the indices or not
seed (int): the initial seed of the shuffle. Must be the same
across all workers. If None, will use a random seed shared
among workers (require synchronization among all workers).
"""
if not isinstance(size, int):
raise TypeError(f"TrainingSampler(size=) expects an int. Got type {type(size)}.")
if size <= 0:
raise ValueError(f"TrainingSampler(size=) expects a positive int. Got {size}.")
self._size = size
self._shuffle = shuffle
if seed is None:
seed = comm.shared_random_seed()
self._seed = int(seed)
self._rank = comm.get_rank()
self._world_size = comm.get_world_size()
def __iter__(self):
start = self._rank
yield from itertools.islice(self._infinite_indices(), start, None, self._world_size)
def _infinite_indices(self):
g = torch.Generator()
if self._seed is not None:
g.manual_seed(self._seed)
while True:
if self._shuffle:
yield from torch.randperm(self._size, generator=g).tolist()
else:
yield from torch.arange(self._size).tolist()
class RandomSubsetTrainingSampler(TrainingSampler):
"""
Similar to TrainingSampler, but only sample a random subset of indices.
This is useful when you want to estimate the accuracy vs data-number curves by
training the model with different subset_ratio.
"""
def __init__(
self,
size: int,
subset_ratio: float,
shuffle: bool = True,
seed_shuffle: Optional[int] = None,
seed_subset: Optional[int] = None,
):
"""
Args:
size (int): the total number of data of the underlying dataset to sample from
subset_ratio (float): the ratio of subset data to sample from the underlying dataset
shuffle (bool): whether to shuffle the indices or not
seed_shuffle (int): the initial seed of the shuffle. Must be the same
across all workers. If None, will use a random seed shared
among workers (require synchronization among all workers).
seed_subset (int): the seed to randomize the subset to be sampled.
Must be the same across all workers. If None, will use a random seed shared
among workers (require synchronization among all workers).
"""
super().__init__(size=size, shuffle=shuffle, seed=seed_shuffle)
assert 0.0 < subset_ratio <= 1.0
self._size_subset = int(size * subset_ratio)
assert self._size_subset > 0
if seed_subset is None:
seed_subset = comm.shared_random_seed()
self._seed_subset = int(seed_subset)
# randomly generate the subset indexes to be sampled from
g = torch.Generator()
g.manual_seed(self._seed_subset)
indexes_randperm = torch.randperm(self._size, generator=g)
self._indexes_subset = indexes_randperm[: self._size_subset]
logger.info("Using RandomSubsetTrainingSampler......")
logger.info(f"Randomly sample {self._size_subset} data from the original {self._size} data")
def _infinite_indices(self):
g = torch.Generator()
g.manual_seed(self._seed) # self._seed equals seed_shuffle from __init__()
while True:
if self._shuffle:
# generate a random permutation to shuffle self._indexes_subset
randperm = torch.randperm(self._size_subset, generator=g)
yield from self._indexes_subset[randperm].tolist()
else:
yield from self._indexes_subset.tolist()
class RepeatFactorTrainingSampler(Sampler):
"""
Similar to TrainingSampler, but a sample may appear more times than others based
on its "repeat factor". This is suitable for training on class imbalanced datasets like LVIS.
"""
def __init__(self, repeat_factors, *, shuffle=True, seed=None):
"""
Args:
repeat_factors (Tensor): a float vector, the repeat factor for each indice. When it's
full of ones, it is equivalent to ``TrainingSampler(len(repeat_factors), ...)``.
shuffle (bool): whether to shuffle the indices or not
seed (int): the initial seed of the shuffle. Must be the same
across all workers. If None, will use a random seed shared
among workers (require synchronization among all workers).
"""
self._shuffle = shuffle
if seed is None:
seed = comm.shared_random_seed()
self._seed = int(seed)
self._rank = comm.get_rank()
self._world_size = comm.get_world_size()
# Split into whole number (_int_part) and fractional (_frac_part) parts.
self._int_part = torch.trunc(repeat_factors)
self._frac_part = repeat_factors - self._int_part
@staticmethod
def repeat_factors_from_category_frequency(dataset_dicts, repeat_thresh, sqrt=True):
"""
Compute (fractional) per-image repeat factors based on category frequency.
The repeat factor for an image is a function of the frequency of the rarest
category labeled in that image. The "frequency of category c" in [0, 1] is defined
as the fraction of images in the training set (without repeats) in which category c
appears.
See :paper:`lvis` (>= v2) Appendix B.2.
Args:
dataset_dicts (list[dict]): annotations in Detectron2 dataset format.
repeat_thresh (float): frequency threshold below which data is repeated.
If the frequency is half of `repeat_thresh`, the image will be
repeated twice.
sqrt (bool): if True, apply :func:`math.sqrt` to the repeat factor.
Returns:
torch.Tensor:
the i-th element is the repeat factor for the dataset image at index i.
"""
# 1. For each category c, compute the fraction of images that contain it: f(c)
category_freq = defaultdict(int)
for dataset_dict in dataset_dicts: # For each image (without repeats)
cat_ids = {ann["category_id"] for ann in dataset_dict["annotations"]}
for cat_id in cat_ids:
category_freq[cat_id] += 1
num_images = len(dataset_dicts)
for k, v in category_freq.items():
category_freq[k] = v / num_images
# 2. For each category c, compute the category-level repeat factor:
# r(c) = max(1, sqrt(t / f(c)))
category_rep = {
cat_id: max(
1.0,
(math.sqrt(repeat_thresh / cat_freq) if sqrt else (repeat_thresh / cat_freq)),
)
for cat_id, cat_freq in category_freq.items()
}
for cat_id in sorted(category_rep.keys()):
logger.info(
f"Cat ID {cat_id}: freq={category_freq[cat_id]:.2f}, rep={category_rep[cat_id]:.2f}"
)
# 3. For each image I, compute the image-level repeat factor:
# r(I) = max_{c in I} r(c)
rep_factors = []
for dataset_dict in dataset_dicts:
cat_ids = {ann["category_id"] for ann in dataset_dict["annotations"]}
rep_factor = max({category_rep[cat_id] for cat_id in cat_ids}, default=1.0)
rep_factors.append(rep_factor)
return torch.tensor(rep_factors, dtype=torch.float32)
def _get_epoch_indices(self, generator):
"""
Create a list of dataset indices (with repeats) to use for one epoch.
Args:
generator (torch.Generator): pseudo random number generator used for
stochastic rounding.
Returns:
torch.Tensor: list of dataset indices to use in one epoch. Each index
is repeated based on its calculated repeat factor.
"""
# Since repeat factors are fractional, we use stochastic rounding so
# that the target repeat factor is achieved in expectation over the
# course of training
rands = torch.rand(len(self._frac_part), generator=generator)
rep_factors = self._int_part + (rands < self._frac_part).float()
# Construct a list of indices in which we repeat images as specified
indices = []
for dataset_index, rep_factor in enumerate(rep_factors):
indices.extend([dataset_index] * int(rep_factor.item()))
return torch.tensor(indices, dtype=torch.int64)
def __iter__(self):
start = self._rank
yield from itertools.islice(self._infinite_indices(), start, None, self._world_size)
def _infinite_indices(self):
g = torch.Generator()
g.manual_seed(self._seed)
while True:
# Sample indices with repeats determined by stochastic rounding; each
# "epoch" may have a slightly different size due to the rounding.
indices = self._get_epoch_indices(g)
if self._shuffle:
randperm = torch.randperm(len(indices), generator=g)
yield from indices[randperm].tolist()
else:
yield from indices.tolist()
class InferenceSampler(Sampler):
"""
Produce indices for inference across all workers.
Inference needs to run on the __exact__ set of samples,
therefore when the total number of samples is not divisible by the number of workers,
this sampler produces different number of samples on different workers.
"""
def __init__(self, size: int):
"""
Args:
size (int): the total number of data of the underlying dataset to sample from
"""
self._size = size
assert size > 0
self._rank = comm.get_rank()
self._world_size = comm.get_world_size()
self._local_indices = self._get_local_indices(size, self._world_size, self._rank)
@staticmethod
def _get_local_indices(total_size, world_size, rank):
shard_size = total_size // world_size
left = total_size % world_size
shard_sizes = [shard_size + int(r < left) for r in range(world_size)]
begin = sum(shard_sizes[:rank])
end = min(sum(shard_sizes[: rank + 1]), total_size)
return range(begin, end)
def __iter__(self):
yield from self._local_indices
def __len__(self):
return len(self._local_indices)
# Copyright (c) Facebook, Inc. and its affiliates.
import numpy as np
from torch.utils.data.sampler import BatchSampler, Sampler
class GroupedBatchSampler(BatchSampler):
"""
Wraps another sampler to yield a mini-batch of indices.
It enforces that the batch only contain elements from the same group.
It also tries to provide mini-batches which follows an ordering which is
as close as possible to the ordering from the original sampler.
"""
def __init__(self, sampler, group_ids, batch_size):
"""
Args:
sampler (Sampler): Base sampler.
group_ids (list[int]): If the sampler produces indices in range [0, N),
`group_ids` must be a list of `N` ints which contains the group id of each sample.
The group ids must be a set of integers in the range [0, num_groups).
batch_size (int): Size of mini-batch.
"""
if not isinstance(sampler, Sampler):
raise ValueError(
"sampler should be an instance of "
"torch.utils.data.Sampler, but got sampler={}".format(sampler)
)
self.sampler = sampler
self.group_ids = np.asarray(group_ids)
assert self.group_ids.ndim == 1
self.batch_size = batch_size
groups = np.unique(self.group_ids).tolist()
# buffer the indices of each group until batch size is reached
self.buffer_per_group = {k: [] for k in groups}
def __iter__(self):
for idx in self.sampler:
group_id = self.group_ids[idx]
group_buffer = self.buffer_per_group[group_id]
group_buffer.append(idx)
if len(group_buffer) == self.batch_size:
yield group_buffer[:] # yield a copy of the list
del group_buffer[:]
def __len__(self):
raise NotImplementedError("len() of GroupedBatchSampler is not well-defined.")
# Copyright (c) Facebook, Inc. and its affiliates.
from fvcore.transforms.transform import Transform, TransformList # order them first
from fvcore.transforms.transform import *
from .transform import *
from .augmentation import *
from .augmentation_impl import *
__all__ = [k for k in globals().keys() if not k.startswith("_")]
from detectron2.utils.env import fixup_module_metadata
fixup_module_metadata(__name__, globals(), __all__)
del fixup_module_metadata
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.
import inspect
import numpy as np
import pprint
from typing import Any, List, Optional, Tuple, Union
from fvcore.transforms.transform import Transform, TransformList
"""
See "Data Augmentation" tutorial for an overview of the system:
https://detectron2.readthedocs.io/tutorials/augmentation.html
"""
__all__ = [
"Augmentation",
"AugmentationList",
"AugInput",
"TransformGen",
"apply_transform_gens",
"StandardAugInput",
"apply_augmentations",
]
def _check_img_dtype(img):
assert isinstance(img, np.ndarray), "[Augmentation] Needs an numpy array, but got a {}!".format(
type(img)
)
assert not isinstance(img.dtype, np.integer) or (
img.dtype == np.uint8
), "[Augmentation] Got image of type {}, use uint8 or floating points instead!".format(
img.dtype
)
assert img.ndim in [2, 3], img.ndim
def _get_aug_input_args(aug, aug_input) -> List[Any]:
"""
Get the arguments to be passed to ``aug.get_transform`` from the input ``aug_input``.
"""
if aug.input_args is None:
# Decide what attributes are needed automatically
prms = list(inspect.signature(aug.get_transform).parameters.items())
# The default behavior is: if there is one parameter, then its "image"
# (work automatically for majority of use cases, and also avoid BC breaking),
# Otherwise, use the argument names.
if len(prms) == 1:
names = ("image",)
else:
names = []
for name, prm in prms:
if prm.kind in (
inspect.Parameter.VAR_POSITIONAL,
inspect.Parameter.VAR_KEYWORD,
):
raise TypeError(
f""" \
The default implementation of `{type(aug)}.__call__` does not allow \
`{type(aug)}.get_transform` to use variable-length arguments (*args, **kwargs)! \
If arguments are unknown, reimplement `__call__` instead. \
"""
)
names.append(name)
aug.input_args = tuple(names)
args = []
for f in aug.input_args:
try:
args.append(getattr(aug_input, f))
except AttributeError as e:
raise AttributeError(
f"{type(aug)}.get_transform needs input attribute '{f}', "
f"but it is not an attribute of {type(aug_input)}!"
) from e
return args
class Augmentation:
"""
Augmentation defines (often random) policies/strategies to generate :class:`Transform`
from data. It is often used for pre-processing of input data.
A "policy" that generates a :class:`Transform` may, in the most general case,
need arbitrary information from input data in order to determine what transforms
to apply. Therefore, each :class:`Augmentation` instance defines the arguments
needed by its :meth:`get_transform` method. When called with the positional arguments,
the :meth:`get_transform` method executes the policy.
Note that :class:`Augmentation` defines the policies to create a :class:`Transform`,
but not how to execute the actual transform operations to those data.
Its :meth:`__call__` method will use :meth:`AugInput.transform` to execute the transform.
The returned `Transform` object is meant to describe deterministic transformation, which means
it can be re-applied on associated data, e.g. the geometry of an image and its segmentation
masks need to be transformed together.
(If such re-application is not needed, then determinism is not a crucial requirement.)
"""
input_args: Optional[Tuple[str]] = None
"""
Stores the attribute names needed by :meth:`get_transform`, e.g. ``("image", "sem_seg")``.
By default, it is just a tuple of argument names in :meth:`self.get_transform`, which often only
contain "image". As long as the argument name convention is followed, there is no need for
users to touch this attribute.
"""
def _init(self, params=None):
if params:
for k, v in params.items():
if k != "self" and not k.startswith("_"):
setattr(self, k, v)
def get_transform(self, *args) -> Transform:
"""
Execute the policy based on input data, and decide what transform to apply to inputs.
Args:
args: Any fixed-length positional arguments. By default, the name of the arguments
should exist in the :class:`AugInput` to be used.
Returns:
Transform: Returns the deterministic transform to apply to the input.
Examples:
::
class MyAug:
# if a policy needs to know both image and semantic segmentation
def get_transform(image, sem_seg) -> T.Transform:
pass
tfm: Transform = MyAug().get_transform(image, sem_seg)
new_image = tfm.apply_image(image)
Notes:
Users can freely use arbitrary new argument names in custom
:meth:`get_transform` method, as long as they are available in the
input data. In detectron2 we use the following convention:
* image: (H,W) or (H,W,C) ndarray of type uint8 in range [0, 255], or
floating point in range [0, 1] or [0, 255].
* boxes: (N,4) ndarray of float32. It represents the instance bounding boxes
of N instances. Each is in XYXY format in unit of absolute coordinates.
* sem_seg: (H,W) ndarray of type uint8. Each element is an integer label of pixel.
We do not specify convention for other types and do not include builtin
:class:`Augmentation` that uses other types in detectron2.
"""
raise NotImplementedError
def __call__(self, aug_input) -> Transform:
"""
Augment the given `aug_input` **in-place**, and return the transform that's used.
This method will be called to apply the augmentation. In most augmentation, it
is enough to use the default implementation, which calls :meth:`get_transform`
using the inputs. But a subclass can overwrite it to have more complicated logic.
Args:
aug_input (AugInput): an object that has attributes needed by this augmentation
(defined by ``self.get_transform``). Its ``transform`` method will be called
to in-place transform it.
Returns:
Transform: the transform that is applied on the input.
"""
args = _get_aug_input_args(self, aug_input)
tfm = self.get_transform(*args)
assert isinstance(tfm, (Transform, TransformList)), (
f"{type(self)}.get_transform must return an instance of Transform! "
f"Got {type(tfm)} instead."
)
aug_input.transform(tfm)
return tfm
def _rand_range(self, low=1.0, high=None, size=None):
"""
Uniform float random number between low and high.
"""
if high is None:
low, high = 0, low
if size is None:
size = []
return np.random.uniform(low, high, size)
def __repr__(self):
"""
Produce something like:
"MyAugmentation(field1={self.field1}, field2={self.field2})"
"""
try:
sig = inspect.signature(self.__init__)
classname = type(self).__name__
argstr = []
for name, param in sig.parameters.items():
assert (
param.kind != param.VAR_POSITIONAL and param.kind != param.VAR_KEYWORD
), "The default __repr__ doesn't support *args or **kwargs"
assert hasattr(self, name), (
"Attribute {} not found! "
"Default __repr__ only works if attributes match the constructor.".format(name)
)
attr = getattr(self, name)
default = param.default
if default is attr:
continue
attr_str = pprint.pformat(attr)
if "\n" in attr_str:
# don't show it if pformat decides to use >1 lines
attr_str = "..."
argstr.append("{}={}".format(name, attr_str))
return "{}({})".format(classname, ", ".join(argstr))
except AssertionError:
return super().__repr__()
__str__ = __repr__
class _TransformToAug(Augmentation):
def __init__(self, tfm: Transform):
self.tfm = tfm
def get_transform(self, *args):
return self.tfm
def __repr__(self):
return repr(self.tfm)
__str__ = __repr__
def _transform_to_aug(tfm_or_aug):
"""
Wrap Transform into Augmentation.
Private, used internally to implement augmentations.
"""
assert isinstance(tfm_or_aug, (Transform, Augmentation)), tfm_or_aug
if isinstance(tfm_or_aug, Augmentation):
return tfm_or_aug
else:
return _TransformToAug(tfm_or_aug)
class AugmentationList(Augmentation):
"""
Apply a sequence of augmentations.
It has ``__call__`` method to apply the augmentations.
Note that :meth:`get_transform` method is impossible (will throw error if called)
for :class:`AugmentationList`, because in order to apply a sequence of augmentations,
the kth augmentation must be applied first, to provide inputs needed by the (k+1)th
augmentation.
"""
def __init__(self, augs):
"""
Args:
augs (list[Augmentation or Transform]):
"""
super().__init__()
self.augs = [_transform_to_aug(x) for x in augs]
def __call__(self, aug_input) -> TransformList:
tfms = []
for x in self.augs:
tfm = x(aug_input)
tfms.append(tfm)
return TransformList(tfms)
def __repr__(self):
msgs = [str(x) for x in self.augs]
return "AugmentationList[{}]".format(", ".join(msgs))
__str__ = __repr__
class AugInput:
"""
Input that can be used with :meth:`Augmentation.__call__`.
This is a standard implementation for the majority of use cases.
This class provides the standard attributes **"image", "boxes", "sem_seg"**
defined in :meth:`__init__` and they may be needed by different augmentations.
Most augmentation policies do not need attributes beyond these three.
After applying augmentations to these attributes (using :meth:`AugInput.transform`),
the returned transforms can then be used to transform other data structures that users have.
Examples:
::
input = AugInput(image, boxes=boxes)
tfms = augmentation(input)
transformed_image = input.image
transformed_boxes = input.boxes
transformed_other_data = tfms.apply_other(other_data)
An extended project that works with new data types may implement augmentation policies
that need other inputs. An algorithm may need to transform inputs in a way different
from the standard approach defined in this class. In those rare situations, users can
implement a class similar to this class, that satify the following condition:
* The input must provide access to these data in the form of attribute access
(``getattr``). For example, if an :class:`Augmentation` to be applied needs "image"
and "sem_seg" arguments, its input must have the attribute "image" and "sem_seg".
* The input must have a ``transform(tfm: Transform) -> None`` method which
in-place transforms all its attributes.
"""
# TODO maybe should support more builtin data types here
def __init__(
self,
image: np.ndarray,
*,
boxes: Optional[np.ndarray] = None,
sem_seg: Optional[np.ndarray] = None,
):
"""
Args:
image (ndarray): (H,W) or (H,W,C) ndarray of type uint8 in range [0, 255], or
floating point in range [0, 1] or [0, 255]. The meaning of C is up
to users.
boxes (ndarray or None): Nx4 float32 boxes in XYXY_ABS mode
sem_seg (ndarray or None): HxW uint8 semantic segmentation mask. Each element
is an integer label of pixel.
"""
_check_img_dtype(image)
self.image = image
self.boxes = boxes
self.sem_seg = sem_seg
def transform(self, tfm: Transform) -> None:
"""
In-place transform all attributes of this class.
By "in-place", it means after calling this method, accessing an attribute such
as ``self.image`` will return transformed data.
"""
self.image = tfm.apply_image(self.image)
if self.boxes is not None:
self.boxes = tfm.apply_box(self.boxes)
if self.sem_seg is not None:
self.sem_seg = tfm.apply_segmentation(self.sem_seg)
def apply_augmentations(
self, augmentations: List[Union[Augmentation, Transform]]
) -> TransformList:
"""
Equivalent of ``AugmentationList(augmentations)(self)``
"""
return AugmentationList(augmentations)(self)
def apply_augmentations(augmentations: List[Union[Transform, Augmentation]], inputs):
"""
Use ``T.AugmentationList(augmentations)(inputs)`` instead.
"""
if isinstance(inputs, np.ndarray):
# handle the common case of image-only Augmentation, also for backward compatibility
image_only = True
inputs = AugInput(inputs)
else:
image_only = False
tfms = inputs.apply_augmentations(augmentations)
return inputs.image if image_only else inputs, tfms
apply_transform_gens = apply_augmentations
"""
Alias for backward-compatibility.
"""
TransformGen = Augmentation
"""
Alias for Augmentation, since it is something that generates :class:`Transform`s
"""
StandardAugInput = AugInput
"""
Alias for compatibility. It's not worth the complexity to have two classes.
"""
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.
"""
Implement many useful :class:`Augmentation`.
"""
import numpy as np
import sys
from numpy import random
from typing import Tuple
import torch
from fvcore.transforms.transform import (
BlendTransform,
CropTransform,
HFlipTransform,
NoOpTransform,
PadTransform,
Transform,
TransformList,
VFlipTransform,
)
from PIL import Image
from detectron2.structures import Boxes, pairwise_iou
from .augmentation import Augmentation, _transform_to_aug
from .transform import ExtentTransform, ResizeTransform, RotationTransform
__all__ = [
"FixedSizeCrop",
"RandomApply",
"RandomBrightness",
"RandomContrast",
"RandomCrop",
"RandomExtent",
"RandomFlip",
"RandomSaturation",
"RandomLighting",
"RandomRotation",
"Resize",
"ResizeScale",
"ResizeShortestEdge",
"RandomCrop_CategoryAreaConstraint",
"RandomResize",
"MinIoURandomCrop",
]
class RandomApply(Augmentation):
"""
Randomly apply an augmentation with a given probability.
"""
def __init__(self, tfm_or_aug, prob=0.5):
"""
Args:
tfm_or_aug (Transform, Augmentation): the transform or augmentation
to be applied. It can either be a `Transform` or `Augmentation`
instance.
prob (float): probability between 0.0 and 1.0 that
the wrapper transformation is applied
"""
super().__init__()
self.aug = _transform_to_aug(tfm_or_aug)
assert 0.0 <= prob <= 1.0, f"Probablity must be between 0.0 and 1.0 (given: {prob})"
self.prob = prob
def get_transform(self, *args):
do = self._rand_range() < self.prob
if do:
return self.aug.get_transform(*args)
else:
return NoOpTransform()
def __call__(self, aug_input):
do = self._rand_range() < self.prob
if do:
return self.aug(aug_input)
else:
return NoOpTransform()
class RandomFlip(Augmentation):
"""
Flip the image horizontally or vertically with the given probability.
"""
def __init__(self, prob=0.5, *, horizontal=True, vertical=False):
"""
Args:
prob (float): probability of flip.
horizontal (boolean): whether to apply horizontal flipping
vertical (boolean): whether to apply vertical flipping
"""
super().__init__()
if horizontal and vertical:
raise ValueError("Cannot do both horiz and vert. Please use two Flip instead.")
if not horizontal and not vertical:
raise ValueError("At least one of horiz or vert has to be True!")
self._init(locals())
def get_transform(self, image):
h, w = image.shape[:2]
do = self._rand_range() < self.prob
if do:
if self.horizontal:
return HFlipTransform(w)
elif self.vertical:
return VFlipTransform(h)
else:
return NoOpTransform()
class Resize(Augmentation):
"""Resize image to a fixed target size"""
def __init__(self, shape, interp=Image.BILINEAR):
"""
Args:
shape: (h, w) tuple or a int
interp: PIL interpolation method
"""
if isinstance(shape, int):
shape = (shape, shape)
shape = tuple(shape)
self._init(locals())
def get_transform(self, image):
return ResizeTransform(
image.shape[0], image.shape[1], self.shape[0], self.shape[1], self.interp
)
class ResizeShortestEdge(Augmentation):
"""
Resize the image while keeping the aspect ratio unchanged.
It attempts to scale the shorter edge to the given `short_edge_length`,
as long as the longer edge does not exceed `max_size`.
If `max_size` is reached, then downscale so that the longer edge does not exceed max_size.
"""
@torch.jit.unused
def __init__(
self, short_edge_length, max_size=sys.maxsize, sample_style="range", interp=Image.BILINEAR
):
"""
Args:
short_edge_length (list[int]): If ``sample_style=="range"``,
a [min, max] interval from which to sample the shortest edge length.
If ``sample_style=="choice"``, a list of shortest edge lengths to sample from.
max_size (int): maximum allowed longest edge length.
sample_style (str): either "range" or "choice".
"""
super().__init__()
assert sample_style in ["range", "choice"], sample_style
self.is_range = sample_style == "range"
if isinstance(short_edge_length, int):
short_edge_length = (short_edge_length, short_edge_length)
if self.is_range:
assert len(short_edge_length) == 2, (
"short_edge_length must be two values using 'range' sample style."
f" Got {short_edge_length}!"
)
self._init(locals())
@torch.jit.unused
def get_transform(self, image):
h, w = image.shape[:2]
if self.is_range:
size = np.random.randint(self.short_edge_length[0], self.short_edge_length[1] + 1)
else:
size = np.random.choice(self.short_edge_length)
if size == 0:
return NoOpTransform()
newh, neww = ResizeShortestEdge.get_output_shape(h, w, size, self.max_size)
return ResizeTransform(h, w, newh, neww, self.interp)
@staticmethod
def get_output_shape(
oldh: int, oldw: int, short_edge_length: int, max_size: int
) -> Tuple[int, int]:
"""
Compute the output size given input size and target short edge length.
"""
h, w = oldh, oldw
size = short_edge_length * 1.0
scale = size / min(h, w)
if h < w:
newh, neww = size, scale * w
else:
newh, neww = scale * h, size
if max(newh, neww) > max_size:
scale = max_size * 1.0 / max(newh, neww)
newh = newh * scale
neww = neww * scale
neww = int(neww + 0.5)
newh = int(newh + 0.5)
return (newh, neww)
class ResizeScale(Augmentation):
"""
Takes target size as input and randomly scales the given target size between `min_scale`
and `max_scale`. It then scales the input image such that it fits inside the scaled target
box, keeping the aspect ratio constant.
This implements the resize part of the Google's 'resize_and_crop' data augmentation:
https://github.com/tensorflow/tpu/blob/master/models/official/detection/utils/input_utils.py#L127
"""
def __init__(
self,
min_scale: float,
max_scale: float,
target_height: int,
target_width: int,
interp: int = Image.BILINEAR,
):
"""
Args:
min_scale: minimum image scale range.
max_scale: maximum image scale range.
target_height: target image height.
target_width: target image width.
interp: image interpolation method.
"""
super().__init__()
self._init(locals())
def _get_resize(self, image: np.ndarray, scale: float) -> Transform:
input_size = image.shape[:2]
# Compute new target size given a scale.
target_size = (self.target_height, self.target_width)
target_scale_size = np.multiply(target_size, scale)
# Compute actual rescaling applied to input image and output size.
output_scale = np.minimum(
target_scale_size[0] / input_size[0], target_scale_size[1] / input_size[1]
)
output_size = np.round(np.multiply(input_size, output_scale)).astype(int)
return ResizeTransform(
input_size[0], input_size[1], int(output_size[0]), int(output_size[1]), self.interp
)
def get_transform(self, image: np.ndarray) -> Transform:
random_scale = np.random.uniform(self.min_scale, self.max_scale)
return self._get_resize(image, random_scale)
class RandomRotation(Augmentation):
"""
This method returns a copy of this image, rotated the given
number of degrees counter clockwise around the given center.
"""
def __init__(self, angle, expand=True, center=None, sample_style="range", interp=None):
"""
Args:
angle (list[float]): If ``sample_style=="range"``,
a [min, max] interval from which to sample the angle (in degrees).
If ``sample_style=="choice"``, a list of angles to sample from
expand (bool): choose if the image should be resized to fit the whole
rotated image (default), or simply cropped
center (list[[float, float]]): If ``sample_style=="range"``,
a [[minx, miny], [maxx, maxy]] relative interval from which to sample the center,
[0, 0] being the top left of the image and [1, 1] the bottom right.
If ``sample_style=="choice"``, a list of centers to sample from
Default: None, which means that the center of rotation is the center of the image
center has no effect if expand=True because it only affects shifting
"""
super().__init__()
assert sample_style in ["range", "choice"], sample_style
self.is_range = sample_style == "range"
if isinstance(angle, (float, int)):
angle = (angle, angle)
if center is not None and isinstance(center[0], (float, int)):
center = (center, center)
self._init(locals())
def get_transform(self, image):
h, w = image.shape[:2]
center = None
if self.is_range:
angle = np.random.uniform(self.angle[0], self.angle[1])
if self.center is not None:
center = (
np.random.uniform(self.center[0][0], self.center[1][0]),
np.random.uniform(self.center[0][1], self.center[1][1]),
)
else:
angle = np.random.choice(self.angle)
if self.center is not None:
center = np.random.choice(self.center)
if center is not None:
center = (w * center[0], h * center[1]) # Convert to absolute coordinates
if angle % 360 == 0:
return NoOpTransform()
return RotationTransform(h, w, angle, expand=self.expand, center=center, interp=self.interp)
class FixedSizeCrop(Augmentation):
"""
If `crop_size` is smaller than the input image size, then it uses a random crop of
the crop size. If `crop_size` is larger than the input image size, then it pads
the right and the bottom of the image to the crop size if `pad` is True, otherwise
it returns the smaller image.
"""
def __init__(
self,
crop_size: Tuple[int],
pad: bool = True,
pad_value: float = 128.0,
seg_pad_value: int = 255,
):
"""
Args:
crop_size: target image (height, width).
pad: if True, will pad images smaller than `crop_size` up to `crop_size`
pad_value: the padding value to the image.
seg_pad_value: the padding value to the segmentation mask.
"""
super().__init__()
self._init(locals())
def _get_crop(self, image: np.ndarray) -> Transform:
# Compute the image scale and scaled size.
input_size = image.shape[:2]
output_size = self.crop_size
# Add random crop if the image is scaled up.
max_offset = np.subtract(input_size, output_size)
max_offset = np.maximum(max_offset, 0)
offset = np.multiply(max_offset, np.random.uniform(0.0, 1.0))
offset = np.round(offset).astype(int)
return CropTransform(
offset[1], offset[0], output_size[1], output_size[0], input_size[1], input_size[0]
)
def _get_pad(self, image: np.ndarray) -> Transform:
# Compute the image scale and scaled size.
input_size = image.shape[:2]
output_size = self.crop_size
# Add padding if the image is scaled down.
pad_size = np.subtract(output_size, input_size)
pad_size = np.maximum(pad_size, 0)
original_size = np.minimum(input_size, output_size)
return PadTransform(
0,
0,
pad_size[1],
pad_size[0],
original_size[1],
original_size[0],
self.pad_value,
self.seg_pad_value,
)
def get_transform(self, image: np.ndarray) -> TransformList:
transforms = [self._get_crop(image)]
if self.pad:
transforms.append(self._get_pad(image))
return TransformList(transforms)
class RandomCrop(Augmentation):
"""
Randomly crop a rectangle region out of an image.
"""
def __init__(self, crop_type: str, crop_size):
"""
Args:
crop_type (str): one of "relative_range", "relative", "absolute", "absolute_range".
crop_size (tuple[float, float]): two floats, explained below.
- "relative": crop a (H * crop_size[0], W * crop_size[1]) region from an input image of
size (H, W). crop size should be in (0, 1]
- "relative_range": uniformly sample two values from [crop_size[0], 1]
and [crop_size[1]], 1], and use them as in "relative" crop type.
- "absolute" crop a (crop_size[0], crop_size[1]) region from input image.
crop_size must be smaller than the input image size.
- "absolute_range", for an input of size (H, W), uniformly sample H_crop in
[crop_size[0], min(H, crop_size[1])] and W_crop in [crop_size[0], min(W, crop_size[1])].
Then crop a region (H_crop, W_crop).
"""
# TODO style of relative_range and absolute_range are not consistent:
# one takes (h, w) but another takes (min, max)
super().__init__()
assert crop_type in ["relative_range", "relative", "absolute", "absolute_range"]
self._init(locals())
def get_transform(self, image):
h, w = image.shape[:2]
croph, cropw = self.get_crop_size((h, w))
assert h >= croph and w >= cropw, "Shape computation in {} has bugs.".format(self)
h0 = np.random.randint(h - croph + 1)
w0 = np.random.randint(w - cropw + 1)
return CropTransform(w0, h0, cropw, croph)
def get_crop_size(self, image_size):
"""
Args:
image_size (tuple): height, width
Returns:
crop_size (tuple): height, width in absolute pixels
"""
h, w = image_size
if self.crop_type == "relative":
ch, cw = self.crop_size
return int(h * ch + 0.5), int(w * cw + 0.5)
elif self.crop_type == "relative_range":
crop_size = np.asarray(self.crop_size, dtype=np.float32)
ch, cw = crop_size + np.random.rand(2) * (1 - crop_size)
return int(h * ch + 0.5), int(w * cw + 0.5)
elif self.crop_type == "absolute":
return (min(self.crop_size[0], h), min(self.crop_size[1], w))
elif self.crop_type == "absolute_range":
assert self.crop_size[0] <= self.crop_size[1]
ch = np.random.randint(min(h, self.crop_size[0]), min(h, self.crop_size[1]) + 1)
cw = np.random.randint(min(w, self.crop_size[0]), min(w, self.crop_size[1]) + 1)
return ch, cw
else:
raise NotImplementedError("Unknown crop type {}".format(self.crop_type))
class RandomCrop_CategoryAreaConstraint(Augmentation):
"""
Similar to :class:`RandomCrop`, but find a cropping window such that no single category
occupies a ratio of more than `single_category_max_area` in semantic segmentation ground
truth, which can cause unstability in training. The function attempts to find such a valid
cropping window for at most 10 times.
"""
def __init__(
self,
crop_type: str,
crop_size,
single_category_max_area: float = 1.0,
ignored_category: int = None,
):
"""
Args:
crop_type, crop_size: same as in :class:`RandomCrop`
single_category_max_area: the maximum allowed area ratio of a
category. Set to 1.0 to disable
ignored_category: allow this category in the semantic segmentation
ground truth to exceed the area ratio. Usually set to the category
that's ignored in training.
"""
self.crop_aug = RandomCrop(crop_type, crop_size)
self._init(locals())
def get_transform(self, image, sem_seg):
if self.single_category_max_area >= 1.0:
return self.crop_aug.get_transform(image)
else:
h, w = sem_seg.shape
for _ in range(10):
crop_size = self.crop_aug.get_crop_size((h, w))
y0 = np.random.randint(h - crop_size[0] + 1)
x0 = np.random.randint(w - crop_size[1] + 1)
sem_seg_temp = sem_seg[y0 : y0 + crop_size[0], x0 : x0 + crop_size[1]]
labels, cnt = np.unique(sem_seg_temp, return_counts=True)
if self.ignored_category is not None:
cnt = cnt[labels != self.ignored_category]
if len(cnt) > 1 and np.max(cnt) < np.sum(cnt) * self.single_category_max_area:
break
crop_tfm = CropTransform(x0, y0, crop_size[1], crop_size[0])
return crop_tfm
class RandomExtent(Augmentation):
"""
Outputs an image by cropping a random "subrect" of the source image.
The subrect can be parameterized to include pixels outside the source image,
in which case they will be set to zeros (i.e. black). The size of the output
image will vary with the size of the random subrect.
"""
def __init__(self, scale_range, shift_range):
"""
Args:
output_size (h, w): Dimensions of output image
scale_range (l, h): Range of input-to-output size scaling factor
shift_range (x, y): Range of shifts of the cropped subrect. The rect
is shifted by [w / 2 * Uniform(-x, x), h / 2 * Uniform(-y, y)],
where (w, h) is the (width, height) of the input image. Set each
component to zero to crop at the image's center.
"""
super().__init__()
self._init(locals())
def get_transform(self, image):
img_h, img_w = image.shape[:2]
# Initialize src_rect to fit the input image.
src_rect = np.array([-0.5 * img_w, -0.5 * img_h, 0.5 * img_w, 0.5 * img_h])
# Apply a random scaling to the src_rect.
src_rect *= np.random.uniform(self.scale_range[0], self.scale_range[1])
# Apply a random shift to the coordinates origin.
src_rect[0::2] += self.shift_range[0] * img_w * (np.random.rand() - 0.5)
src_rect[1::2] += self.shift_range[1] * img_h * (np.random.rand() - 0.5)
# Map src_rect coordinates into image coordinates (center at corner).
src_rect[0::2] += 0.5 * img_w
src_rect[1::2] += 0.5 * img_h
return ExtentTransform(
src_rect=(src_rect[0], src_rect[1], src_rect[2], src_rect[3]),
output_size=(int(src_rect[3] - src_rect[1]), int(src_rect[2] - src_rect[0])),
)
class RandomContrast(Augmentation):
"""
Randomly transforms image contrast.
Contrast intensity is uniformly sampled in (intensity_min, intensity_max).
- intensity < 1 will reduce contrast
- intensity = 1 will preserve the input image
- intensity > 1 will increase contrast
See: https://pillow.readthedocs.io/en/3.0.x/reference/ImageEnhance.html
"""
def __init__(self, intensity_min, intensity_max):
"""
Args:
intensity_min (float): Minimum augmentation
intensity_max (float): Maximum augmentation
"""
super().__init__()
self._init(locals())
def get_transform(self, image):
w = np.random.uniform(self.intensity_min, self.intensity_max)
return BlendTransform(src_image=image.mean(), src_weight=1 - w, dst_weight=w)
class RandomBrightness(Augmentation):
"""
Randomly transforms image brightness.
Brightness intensity is uniformly sampled in (intensity_min, intensity_max).
- intensity < 1 will reduce brightness
- intensity = 1 will preserve the input image
- intensity > 1 will increase brightness
See: https://pillow.readthedocs.io/en/3.0.x/reference/ImageEnhance.html
"""
def __init__(self, intensity_min, intensity_max):
"""
Args:
intensity_min (float): Minimum augmentation
intensity_max (float): Maximum augmentation
"""
super().__init__()
self._init(locals())
def get_transform(self, image):
w = np.random.uniform(self.intensity_min, self.intensity_max)
return BlendTransform(src_image=0, src_weight=1 - w, dst_weight=w)
class RandomSaturation(Augmentation):
"""
Randomly transforms saturation of an RGB image.
Input images are assumed to have 'RGB' channel order.
Saturation intensity is uniformly sampled in (intensity_min, intensity_max).
- intensity < 1 will reduce saturation (make the image more grayscale)
- intensity = 1 will preserve the input image
- intensity > 1 will increase saturation
See: https://pillow.readthedocs.io/en/3.0.x/reference/ImageEnhance.html
"""
def __init__(self, intensity_min, intensity_max):
"""
Args:
intensity_min (float): Minimum augmentation (1 preserves input).
intensity_max (float): Maximum augmentation (1 preserves input).
"""
super().__init__()
self._init(locals())
def get_transform(self, image):
assert image.shape[-1] == 3, "RandomSaturation only works on RGB images"
w = np.random.uniform(self.intensity_min, self.intensity_max)
grayscale = image.dot([0.299, 0.587, 0.114])[:, :, np.newaxis]
return BlendTransform(src_image=grayscale, src_weight=1 - w, dst_weight=w)
class RandomLighting(Augmentation):
"""
The "lighting" augmentation described in AlexNet, using fixed PCA over ImageNet.
Input images are assumed to have 'RGB' channel order.
The degree of color jittering is randomly sampled via a normal distribution,
with standard deviation given by the scale parameter.
"""
def __init__(self, scale):
"""
Args:
scale (float): Standard deviation of principal component weighting.
"""
super().__init__()
self._init(locals())
self.eigen_vecs = np.array(
[[-0.5675, 0.7192, 0.4009], [-0.5808, -0.0045, -0.8140], [-0.5836, -0.6948, 0.4203]]
)
self.eigen_vals = np.array([0.2175, 0.0188, 0.0045])
def get_transform(self, image):
assert image.shape[-1] == 3, "RandomLighting only works on RGB images"
weights = np.random.normal(scale=self.scale, size=3)
return BlendTransform(
src_image=self.eigen_vecs.dot(weights * self.eigen_vals), src_weight=1.0, dst_weight=1.0
)
class RandomResize(Augmentation):
"""Randomly resize image to a target size in shape_list"""
def __init__(self, shape_list, interp=Image.BILINEAR):
"""
Args:
shape_list: a list of shapes in (h, w)
interp: PIL interpolation method
"""
self.shape_list = shape_list
self._init(locals())
def get_transform(self, image):
shape_idx = np.random.randint(low=0, high=len(self.shape_list))
h, w = self.shape_list[shape_idx]
return ResizeTransform(image.shape[0], image.shape[1], h, w, self.interp)
class MinIoURandomCrop(Augmentation):
"""Random crop the image & bboxes, the cropped patches have minimum IoU
requirement with original image & bboxes, the IoU threshold is randomly
selected from min_ious.
Args:
min_ious (tuple): minimum IoU threshold for all intersections with
bounding boxes
min_crop_size (float): minimum crop's size (i.e. h,w := a*h, a*w,
where a >= min_crop_size)
mode_trials: number of trials for sampling min_ious threshold
crop_trials: number of trials for sampling crop_size after cropping
"""
def __init__(
self,
min_ious=(0.1, 0.3, 0.5, 0.7, 0.9),
min_crop_size=0.3,
mode_trials=1000,
crop_trials=50,
):
self.min_ious = min_ious
self.sample_mode = (1, *min_ious, 0)
self.min_crop_size = min_crop_size
self.mode_trials = mode_trials
self.crop_trials = crop_trials
def get_transform(self, image, boxes):
"""Call function to crop images and bounding boxes with minimum IoU
constraint.
Args:
boxes: ground truth boxes in (x1, y1, x2, y2) format
"""
if boxes is None:
return NoOpTransform()
h, w, c = image.shape
for _ in range(self.mode_trials):
mode = random.choice(self.sample_mode)
self.mode = mode
if mode == 1:
return NoOpTransform()
min_iou = mode
for _ in range(self.crop_trials):
new_w = random.uniform(self.min_crop_size * w, w)
new_h = random.uniform(self.min_crop_size * h, h)
# h / w in [0.5, 2]
if new_h / new_w < 0.5 or new_h / new_w > 2:
continue
left = random.uniform(w - new_w)
top = random.uniform(h - new_h)
patch = np.array((int(left), int(top), int(left + new_w), int(top + new_h)))
# Line or point crop is not allowed
if patch[2] == patch[0] or patch[3] == patch[1]:
continue
overlaps = pairwise_iou(
Boxes(patch.reshape(-1, 4)), Boxes(boxes.reshape(-1, 4))
).reshape(-1)
if len(overlaps) > 0 and overlaps.min() < min_iou:
continue
# center of boxes should inside the crop img
# only adjust boxes and instance masks when the gt is not empty
if len(overlaps) > 0:
# adjust boxes
def is_center_of_bboxes_in_patch(boxes, patch):
center = (boxes[:, :2] + boxes[:, 2:]) / 2
mask = (
(center[:, 0] > patch[0])
* (center[:, 1] > patch[1])
* (center[:, 0] < patch[2])
* (center[:, 1] < patch[3])
)
return mask
mask = is_center_of_bboxes_in_patch(boxes, patch)
if not mask.any():
continue
return CropTransform(int(left), int(top), int(new_w), int(new_h))
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.
"""
See "Data Augmentation" tutorial for an overview of the system:
https://detectron2.readthedocs.io/tutorials/augmentation.html
"""
import numpy as np
import torch
import torch.nn.functional as F
from fvcore.transforms.transform import (
CropTransform,
HFlipTransform,
NoOpTransform,
Transform,
TransformList,
)
from PIL import Image
try:
import cv2 # noqa
except ImportError:
# OpenCV is an optional dependency at the moment
pass
__all__ = [
"ExtentTransform",
"ResizeTransform",
"RotationTransform",
"ColorTransform",
"PILColorTransform",
]
class ExtentTransform(Transform):
"""
Extracts a subregion from the source image and scales it to the output size.
The fill color is used to map pixels from the source rect that fall outside
the source image.
See: https://pillow.readthedocs.io/en/latest/PIL.html#PIL.ImageTransform.ExtentTransform
"""
def __init__(self, src_rect, output_size, interp=Image.BILINEAR, fill=0):
"""
Args:
src_rect (x0, y0, x1, y1): src coordinates
output_size (h, w): dst image size
interp: PIL interpolation methods
fill: Fill color used when src_rect extends outside image
"""
super().__init__()
self._set_attributes(locals())
def apply_image(self, img, interp=None):
h, w = self.output_size
if len(img.shape) > 2 and img.shape[2] == 1:
pil_image = Image.fromarray(img[:, :, 0], mode="L")
else:
pil_image = Image.fromarray(img)
pil_image = pil_image.transform(
size=(w, h),
method=Image.EXTENT,
data=self.src_rect,
resample=interp if interp else self.interp,
fill=self.fill,
)
ret = np.asarray(pil_image)
if len(img.shape) > 2 and img.shape[2] == 1:
ret = np.expand_dims(ret, -1)
return ret
def apply_coords(self, coords):
# Transform image center from source coordinates into output coordinates
# and then map the new origin to the corner of the output image.
h, w = self.output_size
x0, y0, x1, y1 = self.src_rect
new_coords = coords.astype(np.float32)
new_coords[:, 0] -= 0.5 * (x0 + x1)
new_coords[:, 1] -= 0.5 * (y0 + y1)
new_coords[:, 0] *= w / (x1 - x0)
new_coords[:, 1] *= h / (y1 - y0)
new_coords[:, 0] += 0.5 * w
new_coords[:, 1] += 0.5 * h
return new_coords
def apply_segmentation(self, segmentation):
segmentation = self.apply_image(segmentation, interp=Image.NEAREST)
return segmentation
class ResizeTransform(Transform):
"""
Resize the image to a target size.
"""
def __init__(self, h, w, new_h, new_w, interp=None):
"""
Args:
h, w (int): original image size
new_h, new_w (int): new image size
interp: PIL interpolation methods, defaults to bilinear.
"""
# TODO decide on PIL vs opencv
super().__init__()
if interp is None:
interp = Image.BILINEAR
self._set_attributes(locals())
def apply_image(self, img, interp=None):
assert img.shape[:2] == (self.h, self.w)
assert len(img.shape) <= 4
interp_method = interp if interp is not None else self.interp
if img.dtype == np.uint8:
if len(img.shape) > 2 and img.shape[2] == 1:
pil_image = Image.fromarray(img[:, :, 0], mode="L")
else:
pil_image = Image.fromarray(img)
pil_image = pil_image.resize((self.new_w, self.new_h), interp_method)
ret = np.asarray(pil_image)
if len(img.shape) > 2 and img.shape[2] == 1:
ret = np.expand_dims(ret, -1)
else:
# PIL only supports uint8
if any(x < 0 for x in img.strides):
img = np.ascontiguousarray(img)
img = torch.from_numpy(img)
shape = list(img.shape)
shape_4d = shape[:2] + [1] * (4 - len(shape)) + shape[2:]
img = img.view(shape_4d).permute(2, 3, 0, 1) # hw(c) -> nchw
_PIL_RESIZE_TO_INTERPOLATE_MODE = {
Image.NEAREST: "nearest",
Image.BILINEAR: "bilinear",
Image.BICUBIC: "bicubic",
}
mode = _PIL_RESIZE_TO_INTERPOLATE_MODE[interp_method]
align_corners = None if mode == "nearest" else False
img = F.interpolate(
img, (self.new_h, self.new_w), mode=mode, align_corners=align_corners
)
shape[:2] = (self.new_h, self.new_w)
ret = img.permute(2, 3, 0, 1).view(shape).numpy() # nchw -> hw(c)
return ret
def apply_coords(self, coords):
coords[:, 0] = coords[:, 0] * (self.new_w * 1.0 / self.w)
coords[:, 1] = coords[:, 1] * (self.new_h * 1.0 / self.h)
return coords
def apply_segmentation(self, segmentation):
segmentation = self.apply_image(segmentation, interp=Image.NEAREST)
return segmentation
def inverse(self):
return ResizeTransform(self.new_h, self.new_w, self.h, self.w, self.interp)
class RotationTransform(Transform):
"""
This method returns a copy of this image, rotated the given
number of degrees counter clockwise around its center.
"""
def __init__(self, h, w, angle, expand=True, center=None, interp=None):
"""
Args:
h, w (int): original image size
angle (float): degrees for rotation
expand (bool): choose if the image should be resized to fit the whole
rotated image (default), or simply cropped
center (tuple (width, height)): coordinates of the rotation center
if left to None, the center will be fit to the center of each image
center has no effect if expand=True because it only affects shifting
interp: cv2 interpolation method, default cv2.INTER_LINEAR
"""
super().__init__()
image_center = np.array((w / 2, h / 2))
if center is None:
center = image_center
if interp is None:
interp = cv2.INTER_LINEAR
abs_cos, abs_sin = (abs(np.cos(np.deg2rad(angle))), abs(np.sin(np.deg2rad(angle))))
if expand:
# find the new width and height bounds
bound_w, bound_h = np.rint(
[h * abs_sin + w * abs_cos, h * abs_cos + w * abs_sin]
).astype(int)
else:
bound_w, bound_h = w, h
self._set_attributes(locals())
self.rm_coords = self.create_rotation_matrix()
# Needed because of this problem https://github.com/opencv/opencv/issues/11784
self.rm_image = self.create_rotation_matrix(offset=-0.5)
def apply_image(self, img, interp=None):
"""
img should be a numpy array, formatted as Height * Width * Nchannels
"""
if len(img) == 0 or self.angle % 360 == 0:
return img
assert img.shape[:2] == (self.h, self.w)
interp = interp if interp is not None else self.interp
return cv2.warpAffine(img, self.rm_image, (self.bound_w, self.bound_h), flags=interp)
def apply_coords(self, coords):
"""
coords should be a N * 2 array-like, containing N couples of (x, y) points
"""
coords = np.asarray(coords, dtype=float)
if len(coords) == 0 or self.angle % 360 == 0:
return coords
return cv2.transform(coords[:, np.newaxis, :], self.rm_coords)[:, 0, :]
def apply_segmentation(self, segmentation):
segmentation = self.apply_image(segmentation, interp=cv2.INTER_NEAREST)
return segmentation
def create_rotation_matrix(self, offset=0):
center = (self.center[0] + offset, self.center[1] + offset)
rm = cv2.getRotationMatrix2D(tuple(center), self.angle, 1)
if self.expand:
# Find the coordinates of the center of rotation in the new image
# The only point for which we know the future coordinates is the center of the image
rot_im_center = cv2.transform(self.image_center[None, None, :] + offset, rm)[0, 0, :]
new_center = np.array([self.bound_w / 2, self.bound_h / 2]) + offset - rot_im_center
# shift the rotation center to the new coordinates
rm[:, 2] += new_center
return rm
def inverse(self):
"""
The inverse is to rotate it back with expand, and crop to get the original shape.
"""
if not self.expand: # Not possible to inverse if a part of the image is lost
raise NotImplementedError()
rotation = RotationTransform(
self.bound_h, self.bound_w, -self.angle, True, None, self.interp
)
crop = CropTransform(
(rotation.bound_w - self.w) // 2, (rotation.bound_h - self.h) // 2, self.w, self.h
)
return TransformList([rotation, crop])
class ColorTransform(Transform):
"""
Generic wrapper for any photometric transforms.
These transformations should only affect the color space and
not the coordinate space of the image (e.g. annotation
coordinates such as bounding boxes should not be changed)
"""
def __init__(self, op):
"""
Args:
op (Callable): operation to be applied to the image,
which takes in an ndarray and returns an ndarray.
"""
if not callable(op):
raise ValueError("op parameter should be callable")
super().__init__()
self._set_attributes(locals())
def apply_image(self, img):
return self.op(img)
def apply_coords(self, coords):
return coords
def inverse(self):
return NoOpTransform()
def apply_segmentation(self, segmentation):
return segmentation
class PILColorTransform(ColorTransform):
"""
Generic wrapper for PIL Photometric image transforms,
which affect the color space and not the coordinate
space of the image
"""
def __init__(self, op):
"""
Args:
op (Callable): operation to be applied to the image,
which takes in a PIL Image and returns a transformed
PIL Image.
For reference on possible operations see:
- https://pillow.readthedocs.io/en/stable/
"""
if not callable(op):
raise ValueError("op parameter should be callable")
super().__init__(op)
def apply_image(self, img):
img = Image.fromarray(img)
return np.asarray(super().apply_image(img))
def HFlip_rotated_box(transform, rotated_boxes):
"""
Apply the horizontal flip transform on rotated boxes.
Args:
rotated_boxes (ndarray): Nx5 floating point array of
(x_center, y_center, width, height, angle_degrees) format
in absolute coordinates.
"""
# Transform x_center
rotated_boxes[:, 0] = transform.width - rotated_boxes[:, 0]
# Transform angle
rotated_boxes[:, 4] = -rotated_boxes[:, 4]
return rotated_boxes
def Resize_rotated_box(transform, rotated_boxes):
"""
Apply the resizing transform on rotated boxes. For details of how these (approximation)
formulas are derived, please refer to :meth:`RotatedBoxes.scale`.
Args:
rotated_boxes (ndarray): Nx5 floating point array of
(x_center, y_center, width, height, angle_degrees) format
in absolute coordinates.
"""
scale_factor_x = transform.new_w * 1.0 / transform.w
scale_factor_y = transform.new_h * 1.0 / transform.h
rotated_boxes[:, 0] *= scale_factor_x
rotated_boxes[:, 1] *= scale_factor_y
theta = rotated_boxes[:, 4] * np.pi / 180.0
c = np.cos(theta)
s = np.sin(theta)
rotated_boxes[:, 2] *= np.sqrt(np.square(scale_factor_x * c) + np.square(scale_factor_y * s))
rotated_boxes[:, 3] *= np.sqrt(np.square(scale_factor_x * s) + np.square(scale_factor_y * c))
rotated_boxes[:, 4] = np.arctan2(scale_factor_x * s, scale_factor_y * c) * 180 / np.pi
return rotated_boxes
HFlipTransform.register_type("rotated_box", HFlip_rotated_box)
ResizeTransform.register_type("rotated_box", Resize_rotated_box)
# not necessary any more with latest fvcore
NoOpTransform.register_type("rotated_box", lambda t, x: x)
# Copyright (c) Facebook, Inc. and its affiliates.
from .launch import *
from .train_loop import *
__all__ = [k for k in globals().keys() if not k.startswith("_")]
# prefer to let hooks and defaults live in separate namespaces (therefore not in __all__)
# but still make them available here
from .hooks import *
from .defaults import (
create_ddp_model,
default_argument_parser,
default_setup,
default_writers,
DefaultPredictor,
DefaultTrainer,
)
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.
"""
This file contains components with some default boilerplate logic user may need
in training / testing. They will not work for everyone, but many users may find them useful.
The behavior of functions/classes in this file is subject to change,
since they are meant to represent the "common default behavior" people need in their projects.
"""
import argparse
import logging
import os
import sys
import weakref
from collections import OrderedDict
from typing import Optional
import torch
from fvcore.nn.precise_bn import get_bn_modules
from omegaconf import OmegaConf
from torch.nn.parallel import DistributedDataParallel
import detectron2.data.transforms as T
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import CfgNode, LazyConfig
from detectron2.data import (
MetadataCatalog,
build_detection_test_loader,
build_detection_train_loader,
)
from detectron2.evaluation import (
DatasetEvaluator,
inference_on_dataset,
print_csv_format,
verify_results,
)
from detectron2.modeling import build_model
from detectron2.solver import build_lr_scheduler, build_optimizer
from detectron2.utils import comm
from detectron2.utils.collect_env import collect_env_info
from detectron2.utils.env import seed_all_rng
from detectron2.utils.events import CommonMetricPrinter, JSONWriter, TensorboardXWriter
from detectron2.utils.file_io import PathManager
from detectron2.utils.logger import setup_logger
from . import hooks
from .train_loop import AMPTrainer, SimpleTrainer, TrainerBase
__all__ = [
"create_ddp_model",
"default_argument_parser",
"default_setup",
"default_writers",
"DefaultPredictor",
"DefaultTrainer",
]
def create_ddp_model(model, *, fp16_compression=False, **kwargs):
"""
Create a DistributedDataParallel model if there are >1 processes.
Args:
model: a torch.nn.Module
fp16_compression: add fp16 compression hooks to the ddp object.
See more at https://pytorch.org/docs/stable/ddp_comm_hooks.html#torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_hook
kwargs: other arguments of :module:`torch.nn.parallel.DistributedDataParallel`.
""" # noqa
if comm.get_world_size() == 1:
return model
if "device_ids" not in kwargs:
kwargs["device_ids"] = [comm.get_local_rank()]
ddp = DistributedDataParallel(model, **kwargs)
if fp16_compression:
from torch.distributed.algorithms.ddp_comm_hooks import default as comm_hooks
ddp.register_comm_hook(state=None, hook=comm_hooks.fp16_compress_hook)
return ddp
def default_argument_parser(epilog=None):
"""
Create a parser with some common arguments used by detectron2 users.
Args:
epilog (str): epilog passed to ArgumentParser describing the usage.
Returns:
argparse.ArgumentParser:
"""
parser = argparse.ArgumentParser(
epilog=epilog
or f"""
Examples:
Run on single machine:
$ {sys.argv[0]} --num-gpus 8 --config-file cfg.yaml
Change some config options:
$ {sys.argv[0]} --config-file cfg.yaml MODEL.WEIGHTS /path/to/weight.pth SOLVER.BASE_LR 0.001
Run on multiple machines:
(machine0)$ {sys.argv[0]} --machine-rank 0 --num-machines 2 --dist-url <URL> [--other-flags]
(machine1)$ {sys.argv[0]} --machine-rank 1 --num-machines 2 --dist-url <URL> [--other-flags]
""",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
parser.add_argument(
"--resume",
action="store_true",
help="Whether to attempt to resume from the checkpoint directory. "
"See documentation of `DefaultTrainer.resume_or_load()` for what it means.",
)
parser.add_argument("--eval-only", action="store_true", help="perform evaluation only")
parser.add_argument("--num-gpus", type=int, default=1, help="number of gpus *per machine*")
parser.add_argument("--num-machines", type=int, default=1, help="total number of machines")
parser.add_argument(
"--machine-rank", type=int, default=0, help="the rank of this machine (unique per machine)"
)
# PyTorch still may leave orphan processes in multi-gpu training.
# Therefore we use a deterministic way to obtain port,
# so that users are aware of orphan processes by seeing the port occupied.
port = 2**15 + 2**14 + hash(os.getuid() if sys.platform != "win32" else 1) % 2**14
parser.add_argument(
"--dist-url",
default="tcp://127.0.0.1:{}".format(port),
help="initialization URL for pytorch distributed backend. See "
"https://pytorch.org/docs/stable/distributed.html for details.",
)
parser.add_argument(
"opts",
help="""
Modify config options at the end of the command. For Yacs configs, use
space-separated "PATH.KEY VALUE" pairs.
For python-based LazyConfig, use "path.key=value".
""".strip(),
default=None,
nargs=argparse.REMAINDER,
)
return parser
def _try_get_key(cfg, *keys, default=None):
"""
Try select keys from cfg until the first key that exists. Otherwise return default.
"""
if isinstance(cfg, CfgNode):
cfg = OmegaConf.create(cfg.dump())
for k in keys:
none = object()
p = OmegaConf.select(cfg, k, default=none)
if p is not none:
return p
return default
def _highlight(code, filename):
try:
import pygments
except ImportError:
return code
from pygments.lexers import Python3Lexer, YamlLexer
from pygments.formatters import Terminal256Formatter
lexer = Python3Lexer() if filename.endswith(".py") else YamlLexer()
code = pygments.highlight(code, lexer, Terminal256Formatter(style="monokai"))
return code
def default_setup(cfg, args):
"""
Perform some basic common setups at the beginning of a job, including:
1. Set up the detectron2 logger
2. Log basic information about environment, cmdline arguments, and config
3. Backup the config to the output directory
Args:
cfg (CfgNode or omegaconf.DictConfig): the full config to be used
args (argparse.NameSpace): the command line arguments to be logged
"""
output_dir = _try_get_key(cfg, "OUTPUT_DIR", "output_dir", "train.output_dir")
if comm.is_main_process() and output_dir:
PathManager.mkdirs(output_dir)
rank = comm.get_rank()
setup_logger(output_dir, distributed_rank=rank, name="fvcore")
logger = setup_logger(output_dir, distributed_rank=rank)
logger.info("Rank of current process: {}. World size: {}".format(rank, comm.get_world_size()))
logger.info("Environment info:\n" + collect_env_info())
logger.info("Command line arguments: " + str(args))
if hasattr(args, "config_file") and args.config_file != "":
logger.info(
"Contents of args.config_file={}:\n{}".format(
args.config_file,
_highlight(PathManager.open(args.config_file, "r").read(), args.config_file),
)
)
if comm.is_main_process() and output_dir:
# Note: some of our scripts may expect the existence of
# config.yaml in output directory
path = os.path.join(output_dir, "config.yaml")
if isinstance(cfg, CfgNode):
logger.info("Running with full config:\n{}".format(_highlight(cfg.dump(), ".yaml")))
with PathManager.open(path, "w") as f:
f.write(cfg.dump())
else:
LazyConfig.save(cfg, path)
logger.info("Full config saved to {}".format(path))
# make sure each worker has a different, yet deterministic seed if specified
seed = _try_get_key(cfg, "SEED", "train.seed", default=-1)
seed_all_rng(None if seed < 0 else seed + rank)
# cudnn benchmark has large overhead. It shouldn't be used considering the small size of
# typical validation set.
if not (hasattr(args, "eval_only") and args.eval_only):
torch.backends.cudnn.benchmark = _try_get_key(
cfg, "CUDNN_BENCHMARK", "train.cudnn_benchmark", default=False
)
def default_writers(output_dir: str, max_iter: Optional[int] = None):
"""
Build a list of :class:`EventWriter` to be used.
It now consists of a :class:`CommonMetricPrinter`,
:class:`TensorboardXWriter` and :class:`JSONWriter`.
Args:
output_dir: directory to store JSON metrics and tensorboard events
max_iter: the total number of iterations
Returns:
list[EventWriter]: a list of :class:`EventWriter` objects.
"""
PathManager.mkdirs(output_dir)
return [
# It may not always print what you want to see, since it prints "common" metrics only.
CommonMetricPrinter(max_iter),
JSONWriter(os.path.join(output_dir, "metrics.json")),
TensorboardXWriter(output_dir),
]
class DefaultPredictor:
"""
Create a simple end-to-end predictor with the given config that runs on
single device for a single input image.
Compared to using the model directly, this class does the following additions:
1. Load checkpoint from `cfg.MODEL.WEIGHTS`.
2. Always take BGR image as the input and apply conversion defined by `cfg.INPUT.FORMAT`.
3. Apply resizing defined by `cfg.INPUT.{MIN,MAX}_SIZE_TEST`.
4. Take one input image and produce a single output, instead of a batch.
This is meant for simple demo purposes, so it does the above steps automatically.
This is not meant for benchmarks or running complicated inference logic.
If you'd like to do anything more complicated, please refer to its source code as
examples to build and use the model manually.
Attributes:
metadata (Metadata): the metadata of the underlying dataset, obtained from
cfg.DATASETS.TEST.
Examples:
::
pred = DefaultPredictor(cfg)
inputs = cv2.imread("input.jpg")
outputs = pred(inputs)
"""
def __init__(self, cfg):
self.cfg = cfg.clone() # cfg can be modified by model
self.model = build_model(self.cfg)
self.model.eval()
if len(cfg.DATASETS.TEST):
self.metadata = MetadataCatalog.get(cfg.DATASETS.TEST[0])
checkpointer = DetectionCheckpointer(self.model)
checkpointer.load(cfg.MODEL.WEIGHTS)
self.aug = T.ResizeShortestEdge(
[cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST
)
self.input_format = cfg.INPUT.FORMAT
assert self.input_format in ["RGB", "BGR"], self.input_format
def __call__(self, original_image):
"""
Args:
original_image (np.ndarray): an image of shape (H, W, C) (in BGR order).
Returns:
predictions (dict):
the output of the model for one image only.
See :doc:`/tutorials/models` for details about the format.
"""
with torch.no_grad(): # https://github.com/sphinx-doc/sphinx/issues/4258
# Apply pre-processing to image.
if self.input_format == "RGB":
# whether the model expects BGR inputs or RGB
original_image = original_image[:, :, ::-1]
height, width = original_image.shape[:2]
image = self.aug.get_transform(original_image).apply_image(original_image)
image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
image.to(self.cfg.MODEL.DEVICE)
inputs = {"image": image, "height": height, "width": width}
predictions = self.model([inputs])[0]
return predictions
class DefaultTrainer(TrainerBase):
"""
A trainer with default training logic. It does the following:
1. Create a :class:`SimpleTrainer` using model, optimizer, dataloader
defined by the given config. Create a LR scheduler defined by the config.
2. Load the last checkpoint or `cfg.MODEL.WEIGHTS`, if exists, when
`resume_or_load` is called.
3. Register a few common hooks defined by the config.
It is created to simplify the **standard model training workflow** and reduce code boilerplate
for users who only need the standard training workflow, with standard features.
It means this class makes *many assumptions* about your training logic that
may easily become invalid in a new research. In fact, any assumptions beyond those made in the
:class:`SimpleTrainer` are too much for research.
The code of this class has been annotated about restrictive assumptions it makes.
When they do not work for you, you're encouraged to:
1. Overwrite methods of this class, OR:
2. Use :class:`SimpleTrainer`, which only does minimal SGD training and
nothing else. You can then add your own hooks if needed. OR:
3. Write your own training loop similar to `tools/plain_train_net.py`.
See the :doc:`/tutorials/training` tutorials for more details.
Note that the behavior of this class, like other functions/classes in
this file, is not stable, since it is meant to represent the "common default behavior".
It is only guaranteed to work well with the standard models and training workflow in detectron2.
To obtain more stable behavior, write your own training logic with other public APIs.
Examples:
::
trainer = DefaultTrainer(cfg)
trainer.resume_or_load() # load last checkpoint or MODEL.WEIGHTS
trainer.train()
Attributes:
scheduler:
checkpointer (DetectionCheckpointer):
cfg (CfgNode):
"""
def __init__(self, cfg):
"""
Args:
cfg (CfgNode):
"""
super().__init__()
logger = logging.getLogger("detectron2")
if not logger.isEnabledFor(logging.INFO): # setup_logger is not called for d2
setup_logger()
cfg = DefaultTrainer.auto_scale_workers(cfg, comm.get_world_size())
# Assume these objects must be constructed in this order.
model = self.build_model(cfg)
optimizer = self.build_optimizer(cfg, model)
data_loader = self.build_train_loader(cfg)
model = create_ddp_model(model, broadcast_buffers=False)
self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(
model, data_loader, optimizer
)
self.scheduler = self.build_lr_scheduler(cfg, optimizer)
self.checkpointer = DetectionCheckpointer(
# Assume you want to save checkpoints together with logs/statistics
model,
cfg.OUTPUT_DIR,
trainer=weakref.proxy(self),
)
self.start_iter = 0
self.max_iter = cfg.SOLVER.MAX_ITER
self.cfg = cfg
self.register_hooks(self.build_hooks())
def resume_or_load(self, resume=True):
"""
If `resume==True` and `cfg.OUTPUT_DIR` contains the last checkpoint (defined by
a `last_checkpoint` file), resume from the file. Resuming means loading all
available states (eg. optimizer and scheduler) and update iteration counter
from the checkpoint. ``cfg.MODEL.WEIGHTS`` will not be used.
Otherwise, this is considered as an independent training. The method will load model
weights from the file `cfg.MODEL.WEIGHTS` (but will not load other states) and start
from iteration 0.
Args:
resume (bool): whether to do resume or not
"""
self.checkpointer.resume_or_load(self.cfg.MODEL.WEIGHTS, resume=resume)
if resume and self.checkpointer.has_checkpoint():
# The checkpoint stores the training iteration that just finished, thus we start
# at the next iteration
self.start_iter = self.iter + 1
def build_hooks(self):
"""
Build a list of default hooks, including timing, evaluation,
checkpointing, lr scheduling, precise BN, writing events.
Returns:
list[HookBase]:
"""
cfg = self.cfg.clone()
cfg.defrost()
cfg.DATALOADER.NUM_WORKERS = 0 # save some memory and time for PreciseBN
ret = [
hooks.IterationTimer(),
hooks.LRScheduler(),
(
hooks.PreciseBN(
# Run at the same freq as (but before) evaluation.
cfg.TEST.EVAL_PERIOD,
self.model,
# Build a new data loader to not affect training
self.build_train_loader(cfg),
cfg.TEST.PRECISE_BN.NUM_ITER,
)
if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model)
else None
),
]
# Do PreciseBN before checkpointer, because it updates the model and need to
# be saved by checkpointer.
# This is not always the best: if checkpointing has a different frequency,
# some checkpoints may have more precise statistics than others.
if comm.is_main_process():
ret.append(hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD))
def test_and_save_results():
self._last_eval_results = self.test(self.cfg, self.model)
return self._last_eval_results
# Do evaluation after checkpointer, because then if it fails,
# we can use the saved checkpoint to debug.
ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))
if comm.is_main_process():
# Here the default print/log frequency of each writer is used.
# run writers in the end, so that evaluation metrics are written
ret.append(hooks.PeriodicWriter(self.build_writers(), period=20))
return ret
def build_writers(self):
"""
Build a list of writers to be used using :func:`default_writers()`.
If you'd like a different list of writers, you can overwrite it in
your trainer.
Returns:
list[EventWriter]: a list of :class:`EventWriter` objects.
"""
return default_writers(self.cfg.OUTPUT_DIR, self.max_iter)
def train(self):
"""
Run training.
Returns:
OrderedDict of results, if evaluation is enabled. Otherwise None.
"""
super().train(self.start_iter, self.max_iter)
if len(self.cfg.TEST.EXPECTED_RESULTS) and comm.is_main_process():
assert hasattr(
self, "_last_eval_results"
), "No evaluation results obtained during training!"
verify_results(self.cfg, self._last_eval_results)
return self._last_eval_results
def run_step(self):
self._trainer.iter = self.iter
self._trainer.run_step()
def state_dict(self):
ret = super().state_dict()
ret["_trainer"] = self._trainer.state_dict()
return ret
def load_state_dict(self, state_dict):
super().load_state_dict(state_dict)
self._trainer.load_state_dict(state_dict["_trainer"])
@classmethod
def build_model(cls, cfg):
"""
Returns:
torch.nn.Module:
It now calls :func:`detectron2.modeling.build_model`.
Overwrite it if you'd like a different model.
"""
model = build_model(cfg)
logger = logging.getLogger(__name__)
logger.info("Model:\n{}".format(model))
return model
@classmethod
def build_optimizer(cls, cfg, model):
"""
Returns:
torch.optim.Optimizer:
It now calls :func:`detectron2.solver.build_optimizer`.
Overwrite it if you'd like a different optimizer.
"""
return build_optimizer(cfg, model)
@classmethod
def build_lr_scheduler(cls, cfg, optimizer):
"""
It now calls :func:`detectron2.solver.build_lr_scheduler`.
Overwrite it if you'd like a different scheduler.
"""
return build_lr_scheduler(cfg, optimizer)
@classmethod
def build_train_loader(cls, cfg):
"""
Returns:
iterable
It now calls :func:`detectron2.data.build_detection_train_loader`.
Overwrite it if you'd like a different data loader.
"""
return build_detection_train_loader(cfg)
@classmethod
def build_test_loader(cls, cfg, dataset_name):
"""
Returns:
iterable
It now calls :func:`detectron2.data.build_detection_test_loader`.
Overwrite it if you'd like a different data loader.
"""
return build_detection_test_loader(cfg, dataset_name)
@classmethod
def build_evaluator(cls, cfg, dataset_name):
"""
Returns:
DatasetEvaluator or None
It is not implemented by default.
"""
raise NotImplementedError(
"""
If you want DefaultTrainer to automatically run evaluation,
please implement `build_evaluator()` in subclasses (see train_net.py for example).
Alternatively, you can call evaluation functions yourself (see Colab balloon tutorial for example).
"""
)
@classmethod
def test(cls, cfg, model, evaluators=None):
"""
Evaluate the given model. The given model is expected to already contain
weights to evaluate.
Args:
cfg (CfgNode):
model (nn.Module):
evaluators (list[DatasetEvaluator] or None): if None, will call
:meth:`build_evaluator`. Otherwise, must have the same length as
``cfg.DATASETS.TEST``.
Returns:
dict: a dict of result metrics
"""
logger = logging.getLogger(__name__)
if isinstance(evaluators, DatasetEvaluator):
evaluators = [evaluators]
if evaluators is not None:
assert len(cfg.DATASETS.TEST) == len(evaluators), "{} != {}".format(
len(cfg.DATASETS.TEST), len(evaluators)
)
results = OrderedDict()
for idx, dataset_name in enumerate(cfg.DATASETS.TEST):
data_loader = cls.build_test_loader(cfg, dataset_name)
# When evaluators are passed in as arguments,
# implicitly assume that evaluators can be created before data_loader.
if evaluators is not None:
evaluator = evaluators[idx]
else:
try:
evaluator = cls.build_evaluator(cfg, dataset_name)
except NotImplementedError:
logger.warn(
"No evaluator found. Use `DefaultTrainer.test(evaluators=)`, "
"or implement its `build_evaluator` method."
)
results[dataset_name] = {}
continue
results_i = inference_on_dataset(model, data_loader, evaluator)
results[dataset_name] = results_i
if comm.is_main_process():
assert isinstance(
results_i, dict
), "Evaluator must return a dict on the main process. Got {} instead.".format(
results_i
)
logger.info("Evaluation results for {} in csv format:".format(dataset_name))
print_csv_format(results_i)
if len(results) == 1:
results = list(results.values())[0]
return results
@staticmethod
def auto_scale_workers(cfg, num_workers: int):
"""
When the config is defined for certain number of workers (according to
``cfg.SOLVER.REFERENCE_WORLD_SIZE``) that's different from the number of
workers currently in use, returns a new cfg where the total batch size
is scaled so that the per-GPU batch size stays the same as the
original ``IMS_PER_BATCH // REFERENCE_WORLD_SIZE``.
Other config options are also scaled accordingly:
* training steps and warmup steps are scaled inverse proportionally.
* learning rate are scaled proportionally, following :paper:`ImageNet in 1h`.
For example, with the original config like the following:
.. code-block:: yaml
IMS_PER_BATCH: 16
BASE_LR: 0.1
REFERENCE_WORLD_SIZE: 8
MAX_ITER: 5000
STEPS: (4000,)
CHECKPOINT_PERIOD: 1000
When this config is used on 16 GPUs instead of the reference number 8,
calling this method will return a new config with:
.. code-block:: yaml
IMS_PER_BATCH: 32
BASE_LR: 0.2
REFERENCE_WORLD_SIZE: 16
MAX_ITER: 2500
STEPS: (2000,)
CHECKPOINT_PERIOD: 500
Note that both the original config and this new config can be trained on 16 GPUs.
It's up to user whether to enable this feature (by setting ``REFERENCE_WORLD_SIZE``).
Returns:
CfgNode: a new config. Same as original if ``cfg.SOLVER.REFERENCE_WORLD_SIZE==0``.
"""
old_world_size = cfg.SOLVER.REFERENCE_WORLD_SIZE
if old_world_size == 0 or old_world_size == num_workers:
return cfg
cfg = cfg.clone()
frozen = cfg.is_frozen()
cfg.defrost()
assert (
cfg.SOLVER.IMS_PER_BATCH % old_world_size == 0
), "Invalid REFERENCE_WORLD_SIZE in config!"
scale = num_workers / old_world_size
bs = cfg.SOLVER.IMS_PER_BATCH = int(round(cfg.SOLVER.IMS_PER_BATCH * scale))
lr = cfg.SOLVER.BASE_LR = cfg.SOLVER.BASE_LR * scale
max_iter = cfg.SOLVER.MAX_ITER = int(round(cfg.SOLVER.MAX_ITER / scale))
warmup_iter = cfg.SOLVER.WARMUP_ITERS = int(round(cfg.SOLVER.WARMUP_ITERS / scale))
cfg.SOLVER.STEPS = tuple(int(round(s / scale)) for s in cfg.SOLVER.STEPS)
cfg.TEST.EVAL_PERIOD = int(round(cfg.TEST.EVAL_PERIOD / scale))
cfg.SOLVER.CHECKPOINT_PERIOD = int(round(cfg.SOLVER.CHECKPOINT_PERIOD / scale))
cfg.SOLVER.REFERENCE_WORLD_SIZE = num_workers # maintain invariant
logger = logging.getLogger(__name__)
logger.info(
f"Auto-scaling the config to batch_size={bs}, learning_rate={lr}, "
f"max_iter={max_iter}, warmup={warmup_iter}."
)
if frozen:
cfg.freeze()
return cfg
# Access basic attributes from the underlying trainer
for _attr in ["model", "data_loader", "optimizer"]:
setattr(
DefaultTrainer,
_attr,
property(
# getter
lambda self, x=_attr: getattr(self._trainer, x),
# setter
lambda self, value, x=_attr: setattr(self._trainer, x, value),
),
)
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.
import datetime
import itertools
import logging
import math
import operator
import os
import tempfile
import time
import warnings
from collections import Counter
import torch
from fvcore.common.checkpoint import Checkpointer
from fvcore.common.checkpoint import PeriodicCheckpointer as _PeriodicCheckpointer
from fvcore.common.param_scheduler import ParamScheduler
from fvcore.common.timer import Timer
from fvcore.nn.precise_bn import get_bn_modules, update_bn_stats
import detectron2.utils.comm as comm
from detectron2.evaluation.testing import flatten_results_dict
from detectron2.solver import LRMultiplier
from detectron2.solver import LRScheduler as _LRScheduler
from detectron2.utils.events import EventStorage, EventWriter
from detectron2.utils.file_io import PathManager
from .train_loop import HookBase
__all__ = [
"CallbackHook",
"IterationTimer",
"PeriodicWriter",
"PeriodicCheckpointer",
"BestCheckpointer",
"LRScheduler",
"AutogradProfiler",
"EvalHook",
"PreciseBN",
"TorchProfiler",
"TorchMemoryStats",
]
"""
Implement some common hooks.
"""
class CallbackHook(HookBase):
"""
Create a hook using callback functions provided by the user.
"""
def __init__(self, *, before_train=None, after_train=None, before_step=None, after_step=None):
"""
Each argument is a function that takes one argument: the trainer.
"""
self._before_train = before_train
self._before_step = before_step
self._after_step = after_step
self._after_train = after_train
def before_train(self):
if self._before_train:
self._before_train(self.trainer)
def after_train(self):
if self._after_train:
self._after_train(self.trainer)
# The functions may be closures that hold reference to the trainer
# Therefore, delete them to avoid circular reference.
del self._before_train, self._after_train
del self._before_step, self._after_step
def before_step(self):
if self._before_step:
self._before_step(self.trainer)
def after_step(self):
if self._after_step:
self._after_step(self.trainer)
class IterationTimer(HookBase):
"""
Track the time spent for each iteration (each run_step call in the trainer).
Print a summary in the end of training.
This hook uses the time between the call to its :meth:`before_step`
and :meth:`after_step` methods.
Under the convention that :meth:`before_step` of all hooks should only
take negligible amount of time, the :class:`IterationTimer` hook should be
placed at the beginning of the list of hooks to obtain accurate timing.
"""
def __init__(self, warmup_iter=3):
"""
Args:
warmup_iter (int): the number of iterations at the beginning to exclude
from timing.
"""
self._warmup_iter = warmup_iter
self._step_timer = Timer()
self._start_time = time.perf_counter()
self._total_timer = Timer()
def before_train(self):
self._start_time = time.perf_counter()
self._total_timer.reset()
self._total_timer.pause()
def after_train(self):
logger = logging.getLogger(__name__)
total_time = time.perf_counter() - self._start_time
total_time_minus_hooks = self._total_timer.seconds()
hook_time = total_time - total_time_minus_hooks
num_iter = self.trainer.storage.iter + 1 - self.trainer.start_iter - self._warmup_iter
if num_iter > 0 and total_time_minus_hooks > 0:
# Speed is meaningful only after warmup
# NOTE this format is parsed by grep in some scripts
logger.info(
"Overall training speed: {} iterations in {} ({:.4f} s / it)".format(
num_iter,
str(datetime.timedelta(seconds=int(total_time_minus_hooks))),
total_time_minus_hooks / num_iter,
)
)
logger.info(
"Total training time: {} ({} on hooks)".format(
str(datetime.timedelta(seconds=int(total_time))),
str(datetime.timedelta(seconds=int(hook_time))),
)
)
def before_step(self):
self._step_timer.reset()
self._total_timer.resume()
def after_step(self):
# +1 because we're in after_step, the current step is done
# but not yet counted
iter_done = self.trainer.storage.iter - self.trainer.start_iter + 1
if iter_done >= self._warmup_iter:
sec = self._step_timer.seconds()
self.trainer.storage.put_scalars(time=sec)
else:
self._start_time = time.perf_counter()
self._total_timer.reset()
self._total_timer.pause()
class PeriodicWriter(HookBase):
"""
Write events to EventStorage (by calling ``writer.write()``) periodically.
It is executed every ``period`` iterations and after the last iteration.
Note that ``period`` does not affect how data is smoothed by each writer.
"""
def __init__(self, writers, period=20):
"""
Args:
writers (list[EventWriter]): a list of EventWriter objects
period (int):
"""
self._writers = writers
for w in writers:
assert isinstance(w, EventWriter), w
self._period = period
def after_step(self):
if (self.trainer.iter + 1) % self._period == 0 or (
self.trainer.iter == self.trainer.max_iter - 1
):
for writer in self._writers:
writer.write()
def after_train(self):
for writer in self._writers:
# If any new data is found (e.g. produced by other after_train),
# write them before closing
writer.write()
writer.close()
class PeriodicCheckpointer(_PeriodicCheckpointer, HookBase):
"""
Same as :class:`detectron2.checkpoint.PeriodicCheckpointer`, but as a hook.
Note that when used as a hook,
it is unable to save additional data other than what's defined
by the given `checkpointer`.
It is executed every ``period`` iterations and after the last iteration.
"""
def before_train(self):
self.max_iter = self.trainer.max_iter
def after_step(self):
# No way to use **kwargs
self.step(self.trainer.iter)
class BestCheckpointer(HookBase):
"""
Checkpoints best weights based off given metric.
This hook should be used in conjunction to and executed after the hook
that produces the metric, e.g. `EvalHook`.
"""
def __init__(
self,
eval_period: int,
checkpointer: Checkpointer,
val_metric: str,
mode: str = "max",
file_prefix: str = "model_best",
) -> None:
"""
Args:
eval_period (int): the period `EvalHook` is set to run.
checkpointer: the checkpointer object used to save checkpoints.
val_metric (str): validation metric to track for best checkpoint, e.g. "bbox/AP50"
mode (str): one of {'max', 'min'}. controls whether the chosen val metric should be
maximized or minimized, e.g. for "bbox/AP50" it should be "max"
file_prefix (str): the prefix of checkpoint's filename, defaults to "model_best"
"""
self._logger = logging.getLogger(__name__)
self._period = eval_period
self._val_metric = val_metric
assert mode in [
"max",
"min",
], f'Mode "{mode}" to `BestCheckpointer` is unknown. It should be one of {"max", "min"}.'
if mode == "max":
self._compare = operator.gt
else:
self._compare = operator.lt
self._checkpointer = checkpointer
self._file_prefix = file_prefix
self.best_metric = None
self.best_iter = None
def _update_best(self, val, iteration):
if math.isnan(val) or math.isinf(val):
return False
self.best_metric = val
self.best_iter = iteration
return True
def _best_checking(self):
metric_tuple = self.trainer.storage.latest().get(self._val_metric)
if metric_tuple is None:
self._logger.warning(
f"Given val metric {self._val_metric} does not seem to be computed/stored."
"Will not be checkpointing based on it."
)
return
else:
latest_metric, metric_iter = metric_tuple
if self.best_metric is None:
if self._update_best(latest_metric, metric_iter):
additional_state = {"iteration": metric_iter}
self._checkpointer.save(f"{self._file_prefix}", **additional_state)
self._logger.info(
f"Saved first model at {self.best_metric:0.5f} @ {self.best_iter} steps"
)
elif self._compare(latest_metric, self.best_metric):
additional_state = {"iteration": metric_iter}
self._checkpointer.save(f"{self._file_prefix}", **additional_state)
self._logger.info(
f"Saved best model as latest eval score for {self._val_metric} is "
f"{latest_metric:0.5f}, better than last best score "
f"{self.best_metric:0.5f} @ iteration {self.best_iter}."
)
self._update_best(latest_metric, metric_iter)
else:
self._logger.info(
f"Not saving as latest eval score for {self._val_metric} is {latest_metric:0.5f}, "
f"not better than best score {self.best_metric:0.5f} @ iteration {self.best_iter}."
)
def after_step(self):
# same conditions as `EvalHook`
next_iter = self.trainer.iter + 1
if (
self._period > 0
and next_iter % self._period == 0
and next_iter != self.trainer.max_iter
):
self._best_checking()
def after_train(self):
# same conditions as `EvalHook`
if self.trainer.iter + 1 >= self.trainer.max_iter:
self._best_checking()
class LRScheduler(HookBase):
"""
A hook which executes a torch builtin LR scheduler and summarizes the LR.
It is executed after every iteration.
"""
def __init__(self, optimizer=None, scheduler=None):
"""
Args:
optimizer (torch.optim.Optimizer):
scheduler (torch.optim.LRScheduler or fvcore.common.param_scheduler.ParamScheduler):
if a :class:`ParamScheduler` object, it defines the multiplier over the base LR
in the optimizer.
If any argument is not given, will try to obtain it from the trainer.
"""
self._optimizer = optimizer
self._scheduler = scheduler
def before_train(self):
self._optimizer = self._optimizer or self.trainer.optimizer
if isinstance(self.scheduler, ParamScheduler):
self._scheduler = LRMultiplier(
self._optimizer,
self.scheduler,
self.trainer.max_iter,
last_iter=self.trainer.iter - 1,
)
self._best_param_group_id = LRScheduler.get_best_param_group_id(self._optimizer)
@staticmethod
def get_best_param_group_id(optimizer):
# NOTE: some heuristics on what LR to summarize
# summarize the param group with most parameters
largest_group = max(len(g["params"]) for g in optimizer.param_groups)
if largest_group == 1:
# If all groups have one parameter,
# then find the most common initial LR, and use it for summary
lr_count = Counter([g["lr"] for g in optimizer.param_groups])
lr = lr_count.most_common()[0][0]
for i, g in enumerate(optimizer.param_groups):
if g["lr"] == lr:
return i
else:
for i, g in enumerate(optimizer.param_groups):
if len(g["params"]) == largest_group:
return i
def after_step(self):
lr = self._optimizer.param_groups[self._best_param_group_id]["lr"]
self.trainer.storage.put_scalar("lr", lr, smoothing_hint=False)
self.scheduler.step()
@property
def scheduler(self):
return self._scheduler or self.trainer.scheduler
def state_dict(self):
if isinstance(self.scheduler, _LRScheduler):
return self.scheduler.state_dict()
return {}
def load_state_dict(self, state_dict):
if isinstance(self.scheduler, _LRScheduler):
logger = logging.getLogger(__name__)
logger.info("Loading scheduler from state_dict ...")
self.scheduler.load_state_dict(state_dict)
class TorchProfiler(HookBase):
"""
A hook which runs `torch.profiler.profile`.
Examples:
::
hooks.TorchProfiler(
lambda trainer: 10 < trainer.iter < 20, self.cfg.OUTPUT_DIR
)
The above example will run the profiler for iteration 10~20 and dump
results to ``OUTPUT_DIR``. We did not profile the first few iterations
because they are typically slower than the rest.
The result files can be loaded in the ``chrome://tracing`` page in chrome browser,
and the tensorboard visualizations can be visualized using
``tensorboard --logdir OUTPUT_DIR/log``
"""
def __init__(self, enable_predicate, output_dir, *, activities=None, save_tensorboard=True):
"""
Args:
enable_predicate (callable[trainer -> bool]): a function which takes a trainer,
and returns whether to enable the profiler.
It will be called once every step, and can be used to select which steps to profile.
output_dir (str): the output directory to dump tracing files.
activities (iterable): same as in `torch.profiler.profile`.
save_tensorboard (bool): whether to save tensorboard visualizations at (output_dir)/log/
"""
self._enable_predicate = enable_predicate
self._activities = activities
self._output_dir = output_dir
self._save_tensorboard = save_tensorboard
def before_step(self):
if self._enable_predicate(self.trainer):
if self._save_tensorboard:
on_trace_ready = torch.profiler.tensorboard_trace_handler(
os.path.join(
self._output_dir,
"log",
"profiler-tensorboard-iter{}".format(self.trainer.iter),
),
f"worker{comm.get_rank()}",
)
else:
on_trace_ready = None
self._profiler = torch.profiler.profile(
activities=self._activities,
on_trace_ready=on_trace_ready,
record_shapes=True,
profile_memory=True,
with_stack=True,
with_flops=True,
)
self._profiler.__enter__()
else:
self._profiler = None
def after_step(self):
if self._profiler is None:
return
self._profiler.__exit__(None, None, None)
if not self._save_tensorboard:
PathManager.mkdirs(self._output_dir)
out_file = os.path.join(
self._output_dir, "profiler-trace-iter{}.json".format(self.trainer.iter)
)
if "://" not in out_file:
self._profiler.export_chrome_trace(out_file)
else:
# Support non-posix filesystems
with tempfile.TemporaryDirectory(prefix="detectron2_profiler") as d:
tmp_file = os.path.join(d, "tmp.json")
self._profiler.export_chrome_trace(tmp_file)
with open(tmp_file) as f:
content = f.read()
with PathManager.open(out_file, "w") as f:
f.write(content)
class AutogradProfiler(TorchProfiler):
"""
A hook which runs `torch.autograd.profiler.profile`.
Examples:
::
hooks.AutogradProfiler(
lambda trainer: 10 < trainer.iter < 20, self.cfg.OUTPUT_DIR
)
The above example will run the profiler for iteration 10~20 and dump
results to ``OUTPUT_DIR``. We did not profile the first few iterations
because they are typically slower than the rest.
The result files can be loaded in the ``chrome://tracing`` page in chrome browser.
Note:
When used together with NCCL on older version of GPUs,
autograd profiler may cause deadlock because it unnecessarily allocates
memory on every device it sees. The memory management calls, if
interleaved with NCCL calls, lead to deadlock on GPUs that do not
support ``cudaLaunchCooperativeKernelMultiDevice``.
"""
def __init__(self, enable_predicate, output_dir, *, use_cuda=True):
"""
Args:
enable_predicate (callable[trainer -> bool]): a function which takes a trainer,
and returns whether to enable the profiler.
It will be called once every step, and can be used to select which steps to profile.
output_dir (str): the output directory to dump tracing files.
use_cuda (bool): same as in `torch.autograd.profiler.profile`.
"""
warnings.warn("AutogradProfiler has been deprecated in favor of TorchProfiler.")
self._enable_predicate = enable_predicate
self._use_cuda = use_cuda
self._output_dir = output_dir
def before_step(self):
if self._enable_predicate(self.trainer):
self._profiler = torch.autograd.profiler.profile(use_cuda=self._use_cuda)
self._profiler.__enter__()
else:
self._profiler = None
class EvalHook(HookBase):
"""
Run an evaluation function periodically, and at the end of training.
It is executed every ``eval_period`` iterations and after the last iteration.
"""
def __init__(self, eval_period, eval_function, eval_after_train=True):
"""
Args:
eval_period (int): the period to run `eval_function`. Set to 0 to
not evaluate periodically (but still evaluate after the last iteration
if `eval_after_train` is True).
eval_function (callable): a function which takes no arguments, and
returns a nested dict of evaluation metrics.
eval_after_train (bool): whether to evaluate after the last iteration
Note:
This hook must be enabled in all or none workers.
If you would like only certain workers to perform evaluation,
give other workers a no-op function (`eval_function=lambda: None`).
"""
self._period = eval_period
self._func = eval_function
self._eval_after_train = eval_after_train
def _do_eval(self):
results = self._func()
if results:
assert isinstance(
results, dict
), "Eval function must return a dict. Got {} instead.".format(results)
flattened_results = flatten_results_dict(results)
for k, v in flattened_results.items():
try:
v = float(v)
except Exception as e:
raise ValueError(
"[EvalHook] eval_function should return a nested dict of float. "
"Got '{}: {}' instead.".format(k, v)
) from e
self.trainer.storage.put_scalars(**flattened_results, smoothing_hint=False)
# Evaluation may take different time among workers.
# A barrier make them start the next iteration together.
comm.synchronize()
def after_step(self):
next_iter = self.trainer.iter + 1
if self._period > 0 and next_iter % self._period == 0:
# do the last eval in after_train
if next_iter != self.trainer.max_iter:
self._do_eval()
def after_train(self):
# This condition is to prevent the eval from running after a failed training
if self._eval_after_train and self.trainer.iter + 1 >= self.trainer.max_iter:
self._do_eval()
# func is likely a closure that holds reference to the trainer
# therefore we clean it to avoid circular reference in the end
del self._func
class PreciseBN(HookBase):
"""
The standard implementation of BatchNorm uses EMA in inference, which is
sometimes suboptimal.
This class computes the true average of statistics rather than the moving average,
and put true averages to every BN layer in the given model.
It is executed every ``period`` iterations and after the last iteration.
"""
def __init__(self, period, model, data_loader, num_iter):
"""
Args:
period (int): the period this hook is run, or 0 to not run during training.
The hook will always run in the end of training.
model (nn.Module): a module whose all BN layers in training mode will be
updated by precise BN.
Note that user is responsible for ensuring the BN layers to be
updated are in training mode when this hook is triggered.
data_loader (iterable): it will produce data to be run by `model(data)`.
num_iter (int): number of iterations used to compute the precise
statistics.
"""
self._logger = logging.getLogger(__name__)
if len(get_bn_modules(model)) == 0:
self._logger.info(
"PreciseBN is disabled because model does not contain BN layers in training mode."
)
self._disabled = True
return
self._model = model
self._data_loader = data_loader
self._num_iter = num_iter
self._period = period
self._disabled = False
self._data_iter = None
def after_step(self):
next_iter = self.trainer.iter + 1
is_final = next_iter == self.trainer.max_iter
if is_final or (self._period > 0 and next_iter % self._period == 0):
self.update_stats()
def update_stats(self):
"""
Update the model with precise statistics. Users can manually call this method.
"""
if self._disabled:
return
if self._data_iter is None:
self._data_iter = iter(self._data_loader)
def data_loader():
for num_iter in itertools.count(1):
if num_iter % 100 == 0:
self._logger.info(
"Running precise-BN ... {}/{} iterations.".format(num_iter, self._num_iter)
)
# This way we can reuse the same iterator
yield next(self._data_iter)
with EventStorage(): # capture events in a new storage to discard them
self._logger.info(
"Running precise-BN for {} iterations... ".format(self._num_iter)
+ "Note that this could produce different statistics every time."
)
update_bn_stats(self._model, data_loader(), self._num_iter)
class TorchMemoryStats(HookBase):
"""
Writes pytorch's cuda memory statistics periodically.
"""
def __init__(self, period=20, max_runs=10):
"""
Args:
period (int): Output stats each 'period' iterations
max_runs (int): Stop the logging after 'max_runs'
"""
self._logger = logging.getLogger(__name__)
self._period = period
self._max_runs = max_runs
self._runs = 0
def after_step(self):
if self._runs > self._max_runs:
return
if (self.trainer.iter + 1) % self._period == 0 or (
self.trainer.iter == self.trainer.max_iter - 1
):
if torch.cuda.is_available():
max_reserved_mb = torch.cuda.max_memory_reserved() / 1024.0 / 1024.0
reserved_mb = torch.cuda.memory_reserved() / 1024.0 / 1024.0
max_allocated_mb = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0
allocated_mb = torch.cuda.memory_allocated() / 1024.0 / 1024.0
self._logger.info(
(
" iter: {} "
" max_reserved_mem: {:.0f}MB "
" reserved_mem: {:.0f}MB "
" max_allocated_mem: {:.0f}MB "
" allocated_mem: {:.0f}MB "
).format(
self.trainer.iter,
max_reserved_mb,
reserved_mb,
max_allocated_mb,
allocated_mb,
)
)
self._runs += 1
if self._runs == self._max_runs:
mem_summary = torch.cuda.memory_summary()
self._logger.info("\n" + mem_summary)
torch.cuda.reset_peak_memory_stats()
# Copyright (c) Facebook, Inc. and its affiliates.
import logging
from datetime import timedelta
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from detectron2.utils import comm
__all__ = ["DEFAULT_TIMEOUT", "launch"]
DEFAULT_TIMEOUT = timedelta(minutes=30)
def _find_free_port():
import socket
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# Binding to port 0 will cause the OS to find an available port for us
sock.bind(("", 0))
port = sock.getsockname()[1]
sock.close()
# NOTE: there is still a chance the port could be taken by other processes.
return port
def launch(
main_func,
# Should be num_processes_per_machine, but kept for compatibility.
num_gpus_per_machine,
num_machines=1,
machine_rank=0,
dist_url=None,
args=(),
timeout=DEFAULT_TIMEOUT,
):
"""
Launch multi-process or distributed training.
This function must be called on all machines involved in the training.
It will spawn child processes (defined by ``num_gpus_per_machine``) on each machine.
Args:
main_func: a function that will be called by `main_func(*args)`
num_gpus_per_machine (int): number of processes per machine. When
using GPUs, this should be the number of GPUs.
num_machines (int): the total number of machines
machine_rank (int): the rank of this machine
dist_url (str): url to connect to for distributed jobs, including protocol
e.g. "tcp://127.0.0.1:8686".
Can be set to "auto" to automatically select a free port on localhost
timeout (timedelta): timeout of the distributed workers
args (tuple): arguments passed to main_func
"""
world_size = num_machines * num_gpus_per_machine
if world_size > 1:
# https://github.com/pytorch/pytorch/pull/14391
# TODO prctl in spawned processes
if dist_url == "auto":
assert num_machines == 1, "dist_url=auto not supported in multi-machine jobs."
port = _find_free_port()
dist_url = f"tcp://127.0.0.1:{port}"
if num_machines > 1 and dist_url.startswith("file://"):
logger = logging.getLogger(__name__)
logger.warning(
"file:// is not a reliable init_method in multi-machine jobs. Prefer tcp://"
)
mp.start_processes(
_distributed_worker,
nprocs=num_gpus_per_machine,
args=(
main_func,
world_size,
num_gpus_per_machine,
machine_rank,
dist_url,
args,
timeout,
),
daemon=False,
)
else:
main_func(*args)
def _distributed_worker(
local_rank,
main_func,
world_size,
num_gpus_per_machine,
machine_rank,
dist_url,
args,
timeout=DEFAULT_TIMEOUT,
):
has_gpu = torch.cuda.is_available()
if has_gpu:
assert num_gpus_per_machine <= torch.cuda.device_count()
global_rank = machine_rank * num_gpus_per_machine + local_rank
try:
dist.init_process_group(
backend="NCCL" if has_gpu else "GLOO",
init_method=dist_url,
world_size=world_size,
rank=global_rank,
timeout=timeout,
)
except Exception as e:
logger = logging.getLogger(__name__)
logger.error("Process group URL: {}".format(dist_url))
raise e
# Setup the local process group.
comm.create_local_process_group(num_gpus_per_machine)
if has_gpu:
torch.cuda.set_device(local_rank)
# synchronize is needed here to prevent a possible timeout after calling init_process_group
# See: https://github.com/facebookresearch/maskrcnn-benchmark/issues/172
comm.synchronize()
main_func(*args)
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.
import concurrent.futures
import logging
import numpy as np
import time
import weakref
from typing import List, Mapping, Optional
import torch
from torch.nn.parallel import DataParallel, DistributedDataParallel
import detectron2.utils.comm as comm
from detectron2.utils.events import EventStorage, get_event_storage
from detectron2.utils.logger import _log_api_usage
__all__ = ["HookBase", "TrainerBase", "SimpleTrainer", "AMPTrainer"]
class HookBase:
"""
Base class for hooks that can be registered with :class:`TrainerBase`.
Each hook can implement 4 methods. The way they are called is demonstrated
in the following snippet:
::
hook.before_train()
for iter in range(start_iter, max_iter):
hook.before_step()
trainer.run_step()
hook.after_step()
iter += 1
hook.after_train()
Notes:
1. In the hook method, users can access ``self.trainer`` to access more
properties about the context (e.g., model, current iteration, or config
if using :class:`DefaultTrainer`).
2. A hook that does something in :meth:`before_step` can often be
implemented equivalently in :meth:`after_step`.
If the hook takes non-trivial time, it is strongly recommended to
implement the hook in :meth:`after_step` instead of :meth:`before_step`.
The convention is that :meth:`before_step` should only take negligible time.
Following this convention will allow hooks that do care about the difference
between :meth:`before_step` and :meth:`after_step` (e.g., timer) to
function properly.
"""
trainer: "TrainerBase" = None
"""
A weak reference to the trainer object. Set by the trainer when the hook is registered.
"""
def before_train(self):
"""
Called before the first iteration.
"""
pass
def after_train(self):
"""
Called after the last iteration.
"""
pass
def before_step(self):
"""
Called before each iteration.
"""
pass
def after_backward(self):
"""
Called after the backward pass of each iteration.
"""
pass
def after_step(self):
"""
Called after each iteration.
"""
pass
def state_dict(self):
"""
Hooks are stateless by default, but can be made checkpointable by
implementing `state_dict` and `load_state_dict`.
"""
return {}
class TrainerBase:
"""
Base class for iterative trainer with hooks.
The only assumption we made here is: the training runs in a loop.
A subclass can implement what the loop is.
We made no assumptions about the existence of dataloader, optimizer, model, etc.
Attributes:
iter(int): the current iteration.
start_iter(int): The iteration to start with.
By convention the minimum possible value is 0.
max_iter(int): The iteration to end training.
storage(EventStorage): An EventStorage that's opened during the course of training.
"""
def __init__(self) -> None:
self._hooks: List[HookBase] = []
self.iter: int = 0
self.start_iter: int = 0
self.max_iter: int
self.storage: EventStorage
_log_api_usage("trainer." + self.__class__.__name__)
def register_hooks(self, hooks: List[Optional[HookBase]]) -> None:
"""
Register hooks to the trainer. The hooks are executed in the order
they are registered.
Args:
hooks (list[Optional[HookBase]]): list of hooks
"""
hooks = [h for h in hooks if h is not None]
for h in hooks:
assert isinstance(h, HookBase)
# To avoid circular reference, hooks and trainer cannot own each other.
# This normally does not matter, but will cause memory leak if the
# involved objects contain __del__:
# See http://engineering.hearsaysocial.com/2013/06/16/circular-references-in-python/
h.trainer = weakref.proxy(self)
self._hooks.extend(hooks)
def train(self, start_iter: int, max_iter: int):
"""
Args:
start_iter, max_iter (int): See docs above
"""
logger = logging.getLogger(__name__)
logger.info("Starting training from iteration {}".format(start_iter))
self.iter = self.start_iter = start_iter
self.max_iter = max_iter
with EventStorage(start_iter) as self.storage:
try:
self.before_train()
for self.iter in range(start_iter, max_iter):
self.before_step()
self.run_step()
self.after_step()
# self.iter == max_iter can be used by `after_train` to
# tell whether the training successfully finished or failed
# due to exceptions.
self.iter += 1
except Exception:
logger.exception("Exception during training:")
raise
finally:
self.after_train()
def before_train(self):
for h in self._hooks:
h.before_train()
def after_train(self):
self.storage.iter = self.iter
for h in self._hooks:
h.after_train()
def before_step(self):
# Maintain the invariant that storage.iter == trainer.iter
# for the entire execution of each step
self.storage.iter = self.iter
for h in self._hooks:
h.before_step()
def after_backward(self):
for h in self._hooks:
h.after_backward()
def after_step(self):
for h in self._hooks:
h.after_step()
def run_step(self):
raise NotImplementedError
def state_dict(self):
ret = {"iteration": self.iter}
hooks_state = {}
for h in self._hooks:
sd = h.state_dict()
if sd:
name = type(h).__qualname__
if name in hooks_state:
# TODO handle repetitive stateful hooks
continue
hooks_state[name] = sd
if hooks_state:
ret["hooks"] = hooks_state
return ret
def load_state_dict(self, state_dict):
logger = logging.getLogger(__name__)
self.iter = state_dict["iteration"]
for key, value in state_dict.get("hooks", {}).items():
for h in self._hooks:
try:
name = type(h).__qualname__
except AttributeError:
continue
if name == key:
h.load_state_dict(value)
break
else:
logger.warning(f"Cannot find the hook '{key}', its state_dict is ignored.")
class SimpleTrainer(TrainerBase):
"""
A simple trainer for the most common type of task:
single-cost single-optimizer single-data-source iterative optimization,
optionally using data-parallelism.
It assumes that every step, you:
1. Compute the loss with a data from the data_loader.
2. Compute the gradients with the above loss.
3. Update the model with the optimizer.
All other tasks during training (checkpointing, logging, evaluation, LR schedule)
are maintained by hooks, which can be registered by :meth:`TrainerBase.register_hooks`.
If you want to do anything fancier than this,
either subclass TrainerBase and implement your own `run_step`,
or write your own training loop.
"""
def __init__(
self,
model,
data_loader,
optimizer,
gather_metric_period=1,
zero_grad_before_forward=False,
async_write_metrics=False,
):
"""
Args:
model: a torch Module. Takes a data from data_loader and returns a
dict of losses.
data_loader: an iterable. Contains data to be used to call model.
optimizer: a torch optimizer.
gather_metric_period: an int. Every gather_metric_period iterations
the metrics are gathered from all the ranks to rank 0 and logged.
zero_grad_before_forward: whether to zero the gradients before the forward.
async_write_metrics: bool. If True, then write metrics asynchronously to improve
training speed
"""
super().__init__()
"""
We set the model to training mode in the trainer.
However it's valid to train a model that's in eval mode.
If you want your model (or a submodule of it) to behave
like evaluation during training, you can overwrite its train() method.
"""
model.train()
self.model = model
self.data_loader = data_loader
# to access the data loader iterator, call `self._data_loader_iter`
self._data_loader_iter_obj = None
self.optimizer = optimizer
self.gather_metric_period = gather_metric_period
self.zero_grad_before_forward = zero_grad_before_forward
self.async_write_metrics = async_write_metrics
# create a thread pool that can execute non critical logic in run_step asynchronically
# use only 1 worker so tasks will be executred in order of submitting.
self.concurrent_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
def run_step(self):
"""
Implement the standard training logic described above.
"""
assert self.model.training, "[SimpleTrainer] model was changed to eval mode!"
start = time.perf_counter()
"""
If you want to do something with the data, you can wrap the dataloader.
"""
data = next(self._data_loader_iter)
data_time = time.perf_counter() - start
if self.zero_grad_before_forward:
"""
If you need to accumulate gradients or do something similar, you can
wrap the optimizer with your custom `zero_grad()` method.
"""
self.optimizer.zero_grad()
"""
If you want to do something with the losses, you can wrap the model.
"""
loss_dict = self.model(data)
if isinstance(loss_dict, torch.Tensor):
losses = loss_dict
loss_dict = {"total_loss": loss_dict}
else:
losses = sum(loss_dict.values())
if not self.zero_grad_before_forward:
"""
If you need to accumulate gradients or do something similar, you can
wrap the optimizer with your custom `zero_grad()` method.
"""
self.optimizer.zero_grad()
losses.backward()
self.after_backward()
if self.async_write_metrics:
# write metrics asynchronically
self.concurrent_executor.submit(
self._write_metrics, loss_dict, data_time, iter=self.iter
)
else:
self._write_metrics(loss_dict, data_time)
"""
If you need gradient clipping/scaling or other processing, you can
wrap the optimizer with your custom `step()` method. But it is
suboptimal as explained in https://arxiv.org/abs/2006.15704 Sec 3.2.4
"""
self.optimizer.step()
@property
def _data_loader_iter(self):
# only create the data loader iterator when it is used
if self._data_loader_iter_obj is None:
self._data_loader_iter_obj = iter(self.data_loader)
return self._data_loader_iter_obj
def reset_data_loader(self, data_loader_builder):
"""
Delete and replace the current data loader with a new one, which will be created
by calling `data_loader_builder` (without argument).
"""
del self.data_loader
data_loader = data_loader_builder()
self.data_loader = data_loader
self._data_loader_iter_obj = None
def _write_metrics(
self,
loss_dict: Mapping[str, torch.Tensor],
data_time: float,
prefix: str = "",
iter: Optional[int] = None,
) -> None:
logger = logging.getLogger(__name__)
iter = self.iter if iter is None else iter
if (iter + 1) % self.gather_metric_period == 0:
try:
SimpleTrainer.write_metrics(loss_dict, data_time, iter, prefix)
except Exception:
logger.exception("Exception in writing metrics: ")
raise
@staticmethod
def write_metrics(
loss_dict: Mapping[str, torch.Tensor],
data_time: float,
cur_iter: int,
prefix: str = "",
) -> None:
"""
Args:
loss_dict (dict): dict of scalar losses
data_time (float): time taken by the dataloader iteration
prefix (str): prefix for logging keys
"""
metrics_dict = {k: v.detach().cpu().item() for k, v in loss_dict.items()}
metrics_dict["data_time"] = data_time
storage = get_event_storage()
# Keep track of data time per rank
storage.put_scalar("rank_data_time", data_time, cur_iter=cur_iter)
# Gather metrics among all workers for logging
# This assumes we do DDP-style training, which is currently the only
# supported method in detectron2.
all_metrics_dict = comm.gather(metrics_dict)
if comm.is_main_process():
# data_time among workers can have high variance. The actual latency
# caused by data_time is the maximum among workers.
data_time = np.max([x.pop("data_time") for x in all_metrics_dict])
storage.put_scalar("data_time", data_time, cur_iter=cur_iter)
# average the rest metrics
metrics_dict = {
k: np.mean([x[k] for x in all_metrics_dict]) for k in all_metrics_dict[0].keys()
}
total_losses_reduced = sum(metrics_dict.values())
if not np.isfinite(total_losses_reduced):
raise FloatingPointError(
f"Loss became infinite or NaN at iteration={cur_iter}!\n"
f"loss_dict = {metrics_dict}"
)
storage.put_scalar(
"{}total_loss".format(prefix), total_losses_reduced, cur_iter=cur_iter
)
if len(metrics_dict) > 1:
storage.put_scalars(cur_iter=cur_iter, **metrics_dict)
def state_dict(self):
ret = super().state_dict()
ret["optimizer"] = self.optimizer.state_dict()
return ret
def load_state_dict(self, state_dict):
super().load_state_dict(state_dict)
self.optimizer.load_state_dict(state_dict["optimizer"])
def after_train(self):
super().after_train()
self.concurrent_executor.shutdown(wait=True)
class AMPTrainer(SimpleTrainer):
"""
Like :class:`SimpleTrainer`, but uses PyTorch's native automatic mixed precision
in the training loop.
"""
def __init__(
self,
model,
data_loader,
optimizer,
gather_metric_period=1,
zero_grad_before_forward=False,
grad_scaler=None,
precision: torch.dtype = torch.float16,
log_grad_scaler: bool = False,
async_write_metrics=False,
):
"""
Args:
model, data_loader, optimizer, gather_metric_period, zero_grad_before_forward,
async_write_metrics: same as in :class:`SimpleTrainer`.
grad_scaler: torch GradScaler to automatically scale gradients.
precision: torch.dtype as the target precision to cast to in computations
"""
unsupported = "AMPTrainer does not support single-process multi-device training!"
if isinstance(model, DistributedDataParallel):
assert not (model.device_ids and len(model.device_ids) > 1), unsupported
assert not isinstance(model, DataParallel), unsupported
super().__init__(
model, data_loader, optimizer, gather_metric_period, zero_grad_before_forward
)
if grad_scaler is None:
from torch.cuda.amp import GradScaler
grad_scaler = GradScaler()
self.grad_scaler = grad_scaler
self.precision = precision
self.log_grad_scaler = log_grad_scaler
def run_step(self):
"""
Implement the AMP training logic.
"""
assert self.model.training, "[AMPTrainer] model was changed to eval mode!"
assert torch.cuda.is_available(), "[AMPTrainer] CUDA is required for AMP training!"
from torch.cuda.amp import autocast
start = time.perf_counter()
data = next(self._data_loader_iter)
data_time = time.perf_counter() - start
if self.zero_grad_before_forward:
self.optimizer.zero_grad()
with autocast(dtype=self.precision):
loss_dict = self.model(data)
if isinstance(loss_dict, torch.Tensor):
losses = loss_dict
loss_dict = {"total_loss": loss_dict}
else:
losses = sum(loss_dict.values())
if not self.zero_grad_before_forward:
self.optimizer.zero_grad()
self.grad_scaler.scale(losses).backward()
if self.log_grad_scaler:
storage = get_event_storage()
storage.put_scalar("[metric]grad_scaler", self.grad_scaler.get_scale())
self.after_backward()
if self.async_write_metrics:
# write metrics asynchronically
self.concurrent_executor.submit(
self._write_metrics, loss_dict, data_time, iter=self.iter
)
else:
self._write_metrics(loss_dict, data_time)
self.grad_scaler.step(self.optimizer)
self.grad_scaler.update()
def state_dict(self):
ret = super().state_dict()
ret["grad_scaler"] = self.grad_scaler.state_dict()
return ret
def load_state_dict(self, state_dict):
super().load_state_dict(state_dict)
self.grad_scaler.load_state_dict(state_dict["grad_scaler"])
# Copyright (c) Facebook, Inc. and its affiliates.
from .cityscapes_evaluation import CityscapesInstanceEvaluator, CityscapesSemSegEvaluator
from .coco_evaluation import COCOEvaluator
from .rotated_coco_evaluation import RotatedCOCOEvaluator
from .evaluator import DatasetEvaluator, DatasetEvaluators, inference_context, inference_on_dataset
from .lvis_evaluation import LVISEvaluator
from .panoptic_evaluation import COCOPanopticEvaluator
from .pascal_voc_evaluation import PascalVOCDetectionEvaluator
from .sem_seg_evaluation import SemSegEvaluator
from .testing import print_csv_format, verify_results
__all__ = [k for k in globals().keys() if not k.startswith("_")]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment