Commit 3144257c by mashun1: catvton
# Copyright (c) Facebook, Inc. and its affiliates.
# pyre-unsafe
import logging
import os
from typing import Any, Dict, Iterable, List, Optional
from fvcore.common.timer import Timer
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.data.datasets.lvis import get_lvis_instances_meta
from detectron2.structures import BoxMode
from detectron2.utils.file_io import PathManager
from ..utils import maybe_prepend_base_path
from .coco import (
DENSEPOSE_ALL_POSSIBLE_KEYS,
DENSEPOSE_METADATA_URL_PREFIX,
CocoDatasetInfo,
get_metadata,
)
DATASETS = [
CocoDatasetInfo(
name="densepose_lvis_v1_ds1_train_v1",
images_root="coco_",
annotations_fpath="lvis/densepose_lvis_v1_ds1_train_v1.json",
),
CocoDatasetInfo(
name="densepose_lvis_v1_ds1_val_v1",
images_root="coco_",
annotations_fpath="lvis/densepose_lvis_v1_ds1_val_v1.json",
),
CocoDatasetInfo(
name="densepose_lvis_v1_ds2_train_v1",
images_root="coco_",
annotations_fpath="lvis/densepose_lvis_v1_ds2_train_v1.json",
),
CocoDatasetInfo(
name="densepose_lvis_v1_ds2_val_v1",
images_root="coco_",
annotations_fpath="lvis/densepose_lvis_v1_ds2_val_v1.json",
),
CocoDatasetInfo(
name="densepose_lvis_v1_ds1_val_animals_100",
images_root="coco_",
annotations_fpath="lvis/densepose_lvis_v1_val_animals_100_v2.json",
),
]
def _load_lvis_annotations(json_file: str):
"""
Load LVIS annotations from a JSON file
Args:
json_file: str
Path to the file to load annotations from
Returns:
Instance of `lvis.LVIS` that provides access to annotations data
"""
from lvis import LVIS
json_file = PathManager.get_local_path(json_file)
logger = logging.getLogger(__name__)
timer = Timer()
lvis_api = LVIS(json_file)
if timer.seconds() > 1:
logger.info("Loading {} takes {:.2f} seconds.".format(json_file, timer.seconds()))
return lvis_api
def _add_categories_metadata(dataset_name: str) -> None:
metadict = get_lvis_instances_meta(dataset_name)
categories = metadict["thing_classes"]
metadata = MetadataCatalog.get(dataset_name)
metadata.categories = {i + 1: categories[i] for i in range(len(categories))}
logger = logging.getLogger(__name__)
logger.info(f"Dataset {dataset_name} has {len(categories)} categories")
def _verify_annotations_have_unique_ids(json_file: str, anns: List[List[Dict[str, Any]]]) -> None:
ann_ids = [ann["id"] for anns_per_image in anns for ann in anns_per_image]
assert len(set(ann_ids)) == len(ann_ids), "Annotation ids in '{}' are not unique!".format(
json_file
)
def _maybe_add_bbox(obj: Dict[str, Any], ann_dict: Dict[str, Any]) -> None:
if "bbox" not in ann_dict:
return
obj["bbox"] = ann_dict["bbox"]
obj["bbox_mode"] = BoxMode.XYWH_ABS
def _maybe_add_segm(obj: Dict[str, Any], ann_dict: Dict[str, Any]) -> None:
if "segmentation" not in ann_dict:
return
segm = ann_dict["segmentation"]
if not isinstance(segm, dict):
# filter out invalid polygons (< 3 points)
segm = [poly for poly in segm if len(poly) % 2 == 0 and len(poly) >= 6]
if len(segm) == 0:
return
obj["segmentation"] = segm
def _maybe_add_keypoints(obj: Dict[str, Any], ann_dict: Dict[str, Any]) -> None:
if "keypoints" not in ann_dict:
return
keypts = ann_dict["keypoints"] # list[int]
for idx, v in enumerate(keypts):
if idx % 3 != 2:
# COCO's segmentation coordinates are floating points in [0, H or W],
# but keypoint coordinates are integers in [0, H-1 or W-1]
# Therefore we assume the coordinates are "pixel indices" and
# add 0.5 to convert to floating point coordinates.
keypts[idx] = v + 0.5
obj["keypoints"] = keypts
def _maybe_add_densepose(obj: Dict[str, Any], ann_dict: Dict[str, Any]) -> None:
for key in DENSEPOSE_ALL_POSSIBLE_KEYS:
if key in ann_dict:
obj[key] = ann_dict[key]
def _combine_images_with_annotations(
dataset_name: str,
image_root: str,
img_datas: Iterable[Dict[str, Any]],
ann_datas: Iterable[Iterable[Dict[str, Any]]],
):
dataset_dicts = []
def get_file_name(img_root, img_dict):
# Determine the path including the split folder ("train2017", "val2017", "test2017") from
# the coco_url field. Example:
# 'coco_url': 'http://images.cocodataset.org/train2017/000000155379.jpg'
split_folder, file_name = img_dict["coco_url"].split("/")[-2:]
return os.path.join(img_root + split_folder, file_name)
for img_dict, ann_dicts in zip(img_datas, ann_datas):
record = {}
record["file_name"] = get_file_name(image_root, img_dict)
record["height"] = img_dict["height"]
record["width"] = img_dict["width"]
record["not_exhaustive_category_ids"] = img_dict.get("not_exhaustive_category_ids", [])
record["neg_category_ids"] = img_dict.get("neg_category_ids", [])
record["image_id"] = img_dict["id"]
record["dataset"] = dataset_name
objs = []
for ann_dict in ann_dicts:
assert ann_dict["image_id"] == record["image_id"]
obj = {}
_maybe_add_bbox(obj, ann_dict)
obj["iscrowd"] = ann_dict.get("iscrowd", 0)
obj["category_id"] = ann_dict["category_id"]
_maybe_add_segm(obj, ann_dict)
_maybe_add_keypoints(obj, ann_dict)
_maybe_add_densepose(obj, ann_dict)
objs.append(obj)
record["annotations"] = objs
dataset_dicts.append(record)
return dataset_dicts
def load_lvis_json(annotations_json_file: str, image_root: str, dataset_name: str):
"""
Loads a JSON file with annotations in LVIS instances format.
Replaces `detectron2.data.datasets.lvis.load_lvis_json` to handle metadata
in a more flexible way. Postpones category mapping to a later stage to be
able to combine several datasets with different (but coherent) sets of
categories.
Args:
annotations_json_file: str
Path to the JSON file with annotations in COCO instances format.
image_root: str
directory that contains all the images
dataset_name: str
the name that identifies a dataset, e.g. "densepose_coco_2014_train"
"""
lvis_api = _load_lvis_annotations(PathManager.get_local_path(annotations_json_file))
_add_categories_metadata(dataset_name)
# sort indices for reproducible results
img_ids = sorted(lvis_api.imgs.keys())
# imgs is a list of dicts, each looks something like:
# {'license': 4,
# 'url': 'http://farm6.staticflickr.com/5454/9413846304_881d5e5c3b_z.jpg',
# 'file_name': 'COCO_val2014_000000001268.jpg',
# 'height': 427,
# 'width': 640,
# 'date_captured': '2013-11-17 05:57:24',
# 'id': 1268}
imgs = lvis_api.load_imgs(img_ids)
logger = logging.getLogger(__name__)
logger.info("Loaded {} images in LVIS format from {}".format(len(imgs), annotations_json_file))
# anns is a list[list[dict]], where each dict is an annotation
# record for an object. The inner list enumerates the objects in an image
# and the outer list enumerates over images.
anns = [lvis_api.img_ann_map[img_id] for img_id in img_ids]
_verify_annotations_have_unique_ids(annotations_json_file, anns)
dataset_records = _combine_images_with_annotations(dataset_name, image_root, imgs, anns)
return dataset_records
def register_dataset(dataset_data: CocoDatasetInfo, datasets_root: Optional[str] = None) -> None:
"""
Registers provided LVIS DensePose dataset
Args:
dataset_data: CocoDatasetInfo
Dataset data
datasets_root: Optional[str]
Datasets root folder (default: None)
"""
annotations_fpath = maybe_prepend_base_path(datasets_root, dataset_data.annotations_fpath)
images_root = maybe_prepend_base_path(datasets_root, dataset_data.images_root)
def load_annotations():
return load_lvis_json(
annotations_json_file=annotations_fpath,
image_root=images_root,
dataset_name=dataset_data.name,
)
DatasetCatalog.register(dataset_data.name, load_annotations)
MetadataCatalog.get(dataset_data.name).set(
json_file=annotations_fpath,
image_root=images_root,
evaluator_type="lvis",
**get_metadata(DENSEPOSE_METADATA_URL_PREFIX),
)
def register_datasets(
datasets_data: Iterable[CocoDatasetInfo], datasets_root: Optional[str] = None
) -> None:
"""
Registers provided LVIS DensePose datasets
Args:
datasets_data: Iterable[CocoDatasetInfo]
An iterable of dataset datas
datasets_root: Optional[str]
Datasets root folder (default: None)
"""
for dataset_data in datasets_data:
register_dataset(dataset_data, datasets_root)
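# --- Illustrative usage sketch (not part of the original commit) ---
# Registering the DensePose LVIS datasets listed in DATASETS above. The
# "datasets" root folder below is an assumption about the local layout.
if __name__ == "__main__":
    register_datasets(DATASETS, datasets_root="datasets")
    # once registered, records can be pulled from the Detectron2 catalog
    records = DatasetCatalog.get("densepose_lvis_v1_ds1_train_v1")
    print(f"{len(records)} records loaded")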
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.
# pyre-unsafe
import logging
import numpy as np
from typing import Any, Callable, Dict, List, Optional, Union
import torch
from torch.utils.data.dataset import Dataset
from detectron2.data.detection_utils import read_image
ImageTransform = Callable[[torch.Tensor], torch.Tensor]
class ImageListDataset(Dataset):
"""
Dataset that provides images from a list.
"""
_EMPTY_IMAGE = torch.empty((0, 3, 1, 1))
def __init__(
self,
image_list: List[str],
category_list: Union[str, List[str], None] = None,
transform: Optional[ImageTransform] = None,
):
"""
Args:
image_list (List[str]): list of paths to image files
category_list (Union[str, List[str], None]): list of animal categories for
each image. If it is a string, or None, this applies to all images
"""
if type(category_list) is list:
self.category_list = category_list
else:
self.category_list = [category_list] * len(image_list)
assert len(image_list) == len(
self.category_list
), "length of image and category lists must be equal"
self.image_list = image_list
self.transform = transform
def __getitem__(self, idx: int) -> Dict[str, Any]:
"""
Gets selected images from the list
Args:
idx (int): index of the image in the image list
Returns:
A dictionary containing two keys:
images (torch.Tensor): tensor of size [N, 3, H, W] (N = 1, or 0 for _EMPTY_IMAGE)
categories (List[str]): a single-element list with the image category (empty on read failure)
"""
categories = [self.category_list[idx]]
fpath = self.image_list[idx]
transform = self.transform
try:
image = torch.from_numpy(np.ascontiguousarray(read_image(fpath, format="BGR")))
image = image.permute(2, 0, 1).unsqueeze(0).float() # HWC -> NCHW
if transform is not None:
image = transform(image)
return {"images": image, "categories": categories}
except (OSError, RuntimeError) as e:
logger = logging.getLogger(__name__)
logger.warning(f"Error opening image file container {fpath}: {e}")
return {"images": self._EMPTY_IMAGE, "categories": []}
def __len__(self):
return len(self.image_list)
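# --- Illustrative usage sketch (not part of the original commit) ---
# Building an ImageListDataset over placeholder image paths with a resize
# transform; the file names and the import path of ImageResizeTransform are
# assumptions made for illustration.
if __name__ == "__main__":
    from densepose.data.transform import ImageResizeTransform  # assumed module path

    dataset = ImageListDataset(
        image_list=["images/cat_0001.jpg", "images/dog_0001.jpg"],  # placeholder paths
        category_list=["cat", "dog"],
        transform=ImageResizeTransform(min_size=800, max_size=1333),
    )
    sample = dataset[0]
    # sample["images"] is a [1, 3, H, W] float tensor, or the empty tensor on read failure
    print(sample["images"].shape, sample["categories"])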
# Copyright (c) Facebook, Inc. and its affiliates.
# pyre-unsafe
import random
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple
import torch
from torch import nn
SampledData = Any
ModelOutput = Any
def _grouper(iterable: Iterable[Any], n: int, fillvalue=None) -> Iterator[Tuple[Any]]:
"""
Group elements of an iterable by chunks of size `n`, e.g.
grouper(range(9), 4) ->
(0, 1, 2, 3), (4, 5, 6, 7), (8, None, None, None)
"""
it = iter(iterable)
while True:
values = []
for _ in range(n):
try:
value = next(it)
except StopIteration:
if values:
values.extend([fillvalue] * (n - len(values)))
yield tuple(values)
return
values.append(value)
yield tuple(values)
class ScoreBasedFilter:
"""
Filters entries in model output based on their scores
Discards all entries with score less than the specified minimum
"""
def __init__(self, min_score: float = 0.8):
self.min_score = min_score
def __call__(self, model_output: ModelOutput) -> ModelOutput:
for model_output_i in model_output:
instances = model_output_i["instances"]
if not instances.has("scores"):
continue
instances_filtered = instances[instances.scores >= self.min_score]
model_output_i["instances"] = instances_filtered
return model_output
class InferenceBasedLoader:
"""
Data loader based on results inferred by a model. Consists of:
- a data loader that provides batches of images
- a model that is used to infer the results
- a data sampler that converts inferred results to annotations
"""
def __init__(
self,
model: nn.Module,
data_loader: Iterable[List[Dict[str, Any]]],
data_sampler: Optional[Callable[[ModelOutput], List[SampledData]]] = None,
data_filter: Optional[Callable[[ModelOutput], ModelOutput]] = None,
shuffle: bool = True,
batch_size: int = 4,
inference_batch_size: int = 4,
drop_last: bool = False,
category_to_class_mapping: Optional[dict] = None,
):
"""
Constructor
Args:
model (torch.nn.Module): model used to produce data
data_loader (Iterable[List[Dict[str, Any]]]): iterable that provides
dictionaries with "images" and "categories" fields to perform inference on
data_sampler (Callable: ModelOutput -> SampledData): functor
that produces annotation data from inference results;
(optional, default: None)
data_filter (Callable: ModelOutput -> ModelOutput): filter
that selects model outputs for further processing
(optional, default: None)
shuffle (bool): if True, the input images get shuffled
batch_size (int): batch size for the produced annotation data
inference_batch_size (int): batch size for input images
drop_last (bool): if True, drop the last batch if it is undersized
category_to_class_mapping (dict): category to class mapping
"""
self.model = model
self.model.eval()
self.data_loader = data_loader
self.data_sampler = data_sampler
self.data_filter = data_filter
self.shuffle = shuffle
self.batch_size = batch_size
self.inference_batch_size = inference_batch_size
self.drop_last = drop_last
if category_to_class_mapping is not None:
self.category_to_class_mapping = category_to_class_mapping
else:
self.category_to_class_mapping = {}
def __iter__(self) -> Iterator[List[SampledData]]:
for batch in self.data_loader:
# batch : List[Dict[str: Tensor[N, C, H, W], str: Optional[str]]]
# images_batch : Tensor[N, C, H, W]
# image : Tensor[C, H, W]
images_and_categories = [
{"image": image, "category": category}
for element in batch
for image, category in zip(element["images"], element["categories"])
]
if not images_and_categories:
continue
if self.shuffle:
random.shuffle(images_and_categories)
yield from self._produce_data(images_and_categories) # pyre-ignore[6]
def _produce_data(
self, images_and_categories: List[Tuple[torch.Tensor, Optional[str]]]
) -> Iterator[List[SampledData]]:
"""
Produce batches of data from images
Args:
images_and_categories (List[Tuple[torch.Tensor, Optional[str]]]):
list of images and corresponding categories to process
Returns:
Iterator over batches of data sampled from model outputs
"""
data_batches: List[SampledData] = []
category_to_class_mapping = self.category_to_class_mapping
batched_images_and_categories = _grouper(images_and_categories, self.inference_batch_size)
for batch in batched_images_and_categories:
batch = [
{
"image": image_and_category["image"].to(self.model.device),
"category": image_and_category["category"],
}
for image_and_category in batch
if image_and_category is not None
]
if not batch:
continue
with torch.no_grad():
model_output = self.model(batch)
for model_output_i, batch_i in zip(model_output, batch):
assert len(batch_i["image"].shape) == 3
model_output_i["image"] = batch_i["image"]
instance_class = category_to_class_mapping.get(batch_i["category"], 0)
model_output_i["instances"].dataset_classes = torch.tensor(
[instance_class] * len(model_output_i["instances"])
)
model_output_filtered = (
model_output if self.data_filter is None else self.data_filter(model_output)
)
data = (
model_output_filtered
if self.data_sampler is None
else self.data_sampler(model_output_filtered)
)
for data_i in data:
if len(data_i["instances"]):
data_batches.append(data_i)
if len(data_batches) >= self.batch_size:
yield data_batches[: self.batch_size]
data_batches = data_batches[self.batch_size :]
if not self.drop_last and data_batches:
yield data_batches
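# --- Illustrative usage sketch (not part of the original commit) ---
# Wiring InferenceBasedLoader with a toy detection model so the data flow can be
# traced end to end; _MockModel and the random images below are assumptions that
# stand in for a trained detector and a real image data loader.
if __name__ == "__main__":
    from detectron2.structures import Boxes, Instances

    class _MockModel(nn.Module):
        """Returns a single high-score, image-sized box per input image."""

        def __init__(self):
            super().__init__()
            self.device = torch.device("cpu")

        def forward(self, batch):
            outputs = []
            for item in batch:
                _, h, w = item["image"].shape
                instances = Instances((h, w))
                instances.pred_boxes = Boxes(torch.tensor([[0.0, 0.0, float(w), float(h)]]))
                instances.scores = torch.tensor([0.9])
                outputs.append({"instances": instances})
            return outputs

    images = torch.rand(2, 3, 64, 64)
    image_loader = [[{"images": images, "categories": ["cat", "cat"]}]]
    loader = InferenceBasedLoader(
        _MockModel(),
        data_loader=image_loader,
        data_filter=ScoreBasedFilter(min_score=0.8),
        batch_size=2,
    )
    for sampled_batch in loader:
        print(len(sampled_batch), sampled_batch[0]["instances"])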
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# pyre-unsafe
from . import builtin
__all__ = [k for k in globals().keys() if "builtin" not in k and not k.startswith("_")]
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# pyre-unsafe
from .catalog import MeshInfo, register_meshes
DENSEPOSE_MESHES_DIR = "https://dl.fbaipublicfiles.com/densepose/meshes/"
MESHES = [
MeshInfo(
name="smpl_27554",
data="smpl_27554.pkl",
geodists="geodists/geodists_smpl_27554.pkl",
symmetry="symmetry/symmetry_smpl_27554.pkl",
texcoords="texcoords/texcoords_smpl_27554.pkl",
),
MeshInfo(
name="chimp_5029",
data="chimp_5029.pkl",
geodists="geodists/geodists_chimp_5029.pkl",
symmetry="symmetry/symmetry_chimp_5029.pkl",
texcoords="texcoords/texcoords_chimp_5029.pkl",
),
MeshInfo(
name="cat_5001",
data="cat_5001.pkl",
geodists="geodists/geodists_cat_5001.pkl",
symmetry="symmetry/symmetry_cat_5001.pkl",
texcoords="texcoords/texcoords_cat_5001.pkl",
),
MeshInfo(
name="cat_7466",
data="cat_7466.pkl",
geodists="geodists/geodists_cat_7466.pkl",
symmetry="symmetry/symmetry_cat_7466.pkl",
texcoords="texcoords/texcoords_cat_7466.pkl",
),
MeshInfo(
name="sheep_5004",
data="sheep_5004.pkl",
geodists="geodists/geodists_sheep_5004.pkl",
symmetry="symmetry/symmetry_sheep_5004.pkl",
texcoords="texcoords/texcoords_sheep_5004.pkl",
),
MeshInfo(
name="zebra_5002",
data="zebra_5002.pkl",
geodists="geodists/geodists_zebra_5002.pkl",
symmetry="symmetry/symmetry_zebra_5002.pkl",
texcoords="texcoords/texcoords_zebra_5002.pkl",
),
MeshInfo(
name="horse_5004",
data="horse_5004.pkl",
geodists="geodists/geodists_horse_5004.pkl",
symmetry="symmetry/symmetry_horse_5004.pkl",
texcoords="texcoords/texcoords_zebra_5002.pkl",
),
MeshInfo(
name="giraffe_5002",
data="giraffe_5002.pkl",
geodists="geodists/geodists_giraffe_5002.pkl",
symmetry="symmetry/symmetry_giraffe_5002.pkl",
texcoords="texcoords/texcoords_giraffe_5002.pkl",
),
MeshInfo(
name="elephant_5002",
data="elephant_5002.pkl",
geodists="geodists/geodists_elephant_5002.pkl",
symmetry="symmetry/symmetry_elephant_5002.pkl",
texcoords="texcoords/texcoords_elephant_5002.pkl",
),
MeshInfo(
name="dog_5002",
data="dog_5002.pkl",
geodists="geodists/geodists_dog_5002.pkl",
symmetry="symmetry/symmetry_dog_5002.pkl",
texcoords="texcoords/texcoords_dog_5002.pkl",
),
MeshInfo(
name="dog_7466",
data="dog_7466.pkl",
geodists="geodists/geodists_dog_7466.pkl",
symmetry="symmetry/symmetry_dog_7466.pkl",
texcoords="texcoords/texcoords_dog_7466.pkl",
),
MeshInfo(
name="cow_5002",
data="cow_5002.pkl",
geodists="geodists/geodists_cow_5002.pkl",
symmetry="symmetry/symmetry_cow_5002.pkl",
texcoords="texcoords/texcoords_cow_5002.pkl",
),
MeshInfo(
name="bear_4936",
data="bear_4936.pkl",
geodists="geodists/geodists_bear_4936.pkl",
symmetry="symmetry/symmetry_bear_4936.pkl",
texcoords="texcoords/texcoords_bear_4936.pkl",
),
]
register_meshes(MESHES, DENSEPOSE_MESHES_DIR)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# pyre-unsafe
import logging
from collections import UserDict
from dataclasses import dataclass
from typing import Iterable, Optional
from ..utils import maybe_prepend_base_path
@dataclass
class MeshInfo:
name: str
data: str
geodists: Optional[str] = None
symmetry: Optional[str] = None
texcoords: Optional[str] = None
class _MeshCatalog(UserDict):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.mesh_ids = {}
self.mesh_names = {}
self.max_mesh_id = -1
def __setitem__(self, key, value):
if key in self:
logger = logging.getLogger(__name__)
logger.warning(
f"Overwriting mesh catalog entry '{key}': old value {self[key]}"
f", new value {value}"
)
mesh_id = self.mesh_ids[key]
else:
self.max_mesh_id += 1
mesh_id = self.max_mesh_id
super().__setitem__(key, value)
self.mesh_ids[key] = mesh_id
self.mesh_names[mesh_id] = key
def get_mesh_id(self, shape_name: str) -> int:
return self.mesh_ids[shape_name]
def get_mesh_name(self, mesh_id: int) -> str:
return self.mesh_names[mesh_id]
MeshCatalog = _MeshCatalog()
def register_mesh(mesh_info: MeshInfo, base_path: Optional[str]) -> None:
geodists, symmetry, texcoords = mesh_info.geodists, mesh_info.symmetry, mesh_info.texcoords
if geodists:
geodists = maybe_prepend_base_path(base_path, geodists)
if symmetry:
symmetry = maybe_prepend_base_path(base_path, symmetry)
if texcoords:
texcoords = maybe_prepend_base_path(base_path, texcoords)
MeshCatalog[mesh_info.name] = MeshInfo(
name=mesh_info.name,
data=maybe_prepend_base_path(base_path, mesh_info.data),
geodists=geodists,
symmetry=symmetry,
texcoords=texcoords,
)
def register_meshes(mesh_infos: Iterable[MeshInfo], base_path: Optional[str]) -> None:
for mesh_info in mesh_infos:
register_mesh(mesh_info, base_path)
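# --- Illustrative usage sketch (not part of the original commit) ---
# Resolving a registered mesh by name; this assumes the builtin meshes shown
# earlier (register_meshes(MESHES, DENSEPOSE_MESHES_DIR)) have already been registered.
if __name__ == "__main__":
    mesh_id = MeshCatalog.get_mesh_id("smpl_27554")
    mesh_info = MeshCatalog["smpl_27554"]
    assert MeshCatalog.get_mesh_name(mesh_id) == "smpl_27554"
    print(mesh_id, mesh_info.data)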
# Copyright (c) Facebook, Inc. and its affiliates.
# pyre-unsafe
from .densepose_uniform import DensePoseUniformSampler
from .densepose_confidence_based import DensePoseConfidenceBasedSampler
from .densepose_cse_uniform import DensePoseCSEUniformSampler
from .densepose_cse_confidence_based import DensePoseCSEConfidenceBasedSampler
from .mask_from_densepose import MaskFromDensePoseSampler
from .prediction_to_gt import PredictionToGroundTruthSampler
# Copyright (c) Facebook, Inc. and its affiliates.
# pyre-unsafe
from typing import Any, Dict, List, Tuple
import torch
from torch.nn import functional as F
from detectron2.structures import BoxMode, Instances
from densepose.converters import ToChartResultConverter
from densepose.converters.base import IntTupleBox, make_int_box
from densepose.structures import DensePoseDataRelative, DensePoseList
class DensePoseBaseSampler:
"""
Base DensePose sampler to produce DensePose data from DensePose predictions.
Samples for each class are drawn according to some distribution over all pixels estimated
to belong to that class.
"""
def __init__(self, count_per_class: int = 8):
"""
Constructor
Args:
count_per_class (int): the sampler produces at most `count_per_class`
samples for each category
"""
self.count_per_class = count_per_class
def __call__(self, instances: Instances) -> DensePoseList:
"""
Convert DensePose predictions (an instance of `DensePoseChartPredictorOutput`)
into DensePose annotations data (an instance of `DensePoseList`)
"""
boxes_xyxy_abs = instances.pred_boxes.tensor.clone().cpu()
boxes_xywh_abs = BoxMode.convert(boxes_xyxy_abs, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
dp_datas = []
for i in range(len(boxes_xywh_abs)):
annotation_i = self._sample(instances[i], make_int_box(boxes_xywh_abs[i]))
annotation_i[DensePoseDataRelative.S_KEY] = self._resample_mask( # pyre-ignore[6]
instances[i].pred_densepose
)
dp_datas.append(DensePoseDataRelative(annotation_i))
# create densepose annotations on CPU
dp_list = DensePoseList(dp_datas, boxes_xyxy_abs, instances.image_size)
return dp_list
def _sample(self, instance: Instances, bbox_xywh: IntTupleBox) -> Dict[str, List[Any]]:
"""
Sample DensePoseDataRelative from estimation results
"""
labels, dp_result = self._produce_labels_and_results(instance)
annotation = {
DensePoseDataRelative.X_KEY: [],
DensePoseDataRelative.Y_KEY: [],
DensePoseDataRelative.U_KEY: [],
DensePoseDataRelative.V_KEY: [],
DensePoseDataRelative.I_KEY: [],
}
n, h, w = dp_result.shape
for part_id in range(1, DensePoseDataRelative.N_PART_LABELS + 1):
# indices - tuple of 3 1D tensors of size k
# 0: index along the first dimension N
# 1: index along H dimension
# 2: index along W dimension
indices = torch.nonzero(labels.expand(n, h, w) == part_id, as_tuple=True)
# values - an array of size [n, k]
# n: number of channels (U, V, confidences)
# k: number of points labeled with part_id
values = dp_result[indices].view(n, -1)
k = values.shape[1]
count = min(self.count_per_class, k)
if count <= 0:
continue
index_sample = self._produce_index_sample(values, count)
sampled_values = values[:, index_sample]
sampled_y = indices[1][index_sample] + 0.5
sampled_x = indices[2][index_sample] + 0.5
# prepare / normalize data
x = (sampled_x / w * 256.0).cpu().tolist()
y = (sampled_y / h * 256.0).cpu().tolist()
u = sampled_values[0].clamp(0, 1).cpu().tolist()
v = sampled_values[1].clamp(0, 1).cpu().tolist()
fine_segm_labels = [part_id] * count
# extend annotations
annotation[DensePoseDataRelative.X_KEY].extend(x)
annotation[DensePoseDataRelative.Y_KEY].extend(y)
annotation[DensePoseDataRelative.U_KEY].extend(u)
annotation[DensePoseDataRelative.V_KEY].extend(v)
annotation[DensePoseDataRelative.I_KEY].extend(fine_segm_labels)
return annotation
def _produce_index_sample(self, values: torch.Tensor, count: int):
"""
Abstract method to produce a sample of indices to select data
To be implemented in descendants
Args:
values (torch.Tensor): an array of size [n, k] that contains
estimated values (U, V, confidences);
n: number of channels (U, V, confidences)
k: number of points labeled with part_id
count (int): number of samples to produce, should be positive and <= k
Return:
list(int): indices of values (along axis 1) selected as a sample
"""
raise NotImplementedError
def _produce_labels_and_results(self, instance: Instances) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Method to get labels and DensePose results from an instance
Args:
instance (Instances): an instance of `DensePoseChartPredictorOutput`
Return:
labels (torch.Tensor): shape [H, W], DensePose segmentation labels
dp_result (torch.Tensor): shape [2, H, W], stacked DensePose results u and v
"""
converter = ToChartResultConverter
chart_result = converter.convert(instance.pred_densepose, instance.pred_boxes)
labels, dp_result = chart_result.labels.cpu(), chart_result.uv.cpu()
return labels, dp_result
def _resample_mask(self, output: Any) -> torch.Tensor:
"""
Convert DensePose predictor output to segmentation annotation - tensors of size
(256, 256) and type `int64`.
Args:
output: DensePose predictor output with the following attributes:
- coarse_segm: tensor of size [N, D, H, W] with unnormalized coarse
segmentation scores
- fine_segm: tensor of size [N, C, H, W] with unnormalized fine
segmentation scores
Return:
Tensor of size (S, S) and type `int64` with coarse segmentation annotations,
where S = DensePoseDataRelative.MASK_SIZE
"""
sz = DensePoseDataRelative.MASK_SIZE
S = (
F.interpolate(output.coarse_segm, (sz, sz), mode="bilinear", align_corners=False)
.argmax(dim=1)
.long()
)
I = (
(
F.interpolate(
output.fine_segm,
(sz, sz),
mode="bilinear",
align_corners=False,
).argmax(dim=1)
* (S > 0).long()
)
.squeeze()
.cpu()
)
# Map fine segmentation results to coarse segmentation ground truth
# TODO: extract this into separate classes
# coarse segmentation: 1 = Torso, 2 = Right Hand, 3 = Left Hand,
# 4 = Left Foot, 5 = Right Foot, 6 = Upper Leg Right, 7 = Upper Leg Left,
# 8 = Lower Leg Right, 9 = Lower Leg Left, 10 = Upper Arm Left,
# 11 = Upper Arm Right, 12 = Lower Arm Left, 13 = Lower Arm Right,
# 14 = Head
# fine segmentation: 1, 2 = Torso, 3 = Right Hand, 4 = Left Hand,
# 5 = Left Foot, 6 = Right Foot, 7, 9 = Upper Leg Right,
# 8, 10 = Upper Leg Left, 11, 13 = Lower Leg Right,
# 12, 14 = Lower Leg Left, 15, 17 = Upper Arm Left,
# 16, 18 = Upper Arm Right, 19, 21 = Lower Arm Left,
# 20, 22 = Lower Arm Right, 23, 24 = Head
FINE_TO_COARSE_SEGMENTATION = {
1: 1,
2: 1,
3: 2,
4: 3,
5: 4,
6: 5,
7: 6,
8: 7,
9: 6,
10: 7,
11: 8,
12: 9,
13: 8,
14: 9,
15: 10,
16: 11,
17: 10,
18: 11,
19: 12,
20: 13,
21: 12,
22: 13,
23: 14,
24: 14,
}
mask = torch.zeros((sz, sz), dtype=torch.int64, device=torch.device("cpu"))
for i in range(DensePoseDataRelative.N_PART_LABELS):
mask[I == i + 1] = FINE_TO_COARSE_SEGMENTATION[i + 1]
return mask
# Copyright (c) Facebook, Inc. and its affiliates.
# pyre-unsafe
import random
from typing import Optional, Tuple
import torch
from densepose.converters import ToChartResultConverterWithConfidences
from .densepose_base import DensePoseBaseSampler
class DensePoseConfidenceBasedSampler(DensePoseBaseSampler):
"""
Samples DensePose data from DensePose predictions.
Samples for each class are drawn using confidence value estimates.
"""
def __init__(
self,
confidence_channel: str,
count_per_class: int = 8,
search_count_multiplier: Optional[float] = None,
search_proportion: Optional[float] = None,
):
"""
Constructor
Args:
confidence_channel (str): confidence channel to use for sampling;
possible values:
"sigma_2": confidences for UV values
"fine_segm_confidence": confidences for fine segmentation
"coarse_segm_confidence": confidences for coarse segmentation
(default: "sigma_2")
count_per_class (int): the sampler produces at most `count_per_class`
samples for each category (default: 8)
search_count_multiplier (float or None): if not None, the total number
of the most confident estimates of a given class to consider is
defined as `min(search_count_multiplier * count_per_class, N)`,
where `N` is the total number of estimates of the class; cannot be
specified together with `search_proportion` (default: None)
search_proportion (float or None): if not None, the total number of the
most confident estimates of a given class to consider is
defined as `min(max(search_proportion * N, count_per_class), N)`,
where `N` is the total number of estimates of the class; cannot be
specified together with `search_count_multiplier` (default: None)
"""
super().__init__(count_per_class)
self.confidence_channel = confidence_channel
self.search_count_multiplier = search_count_multiplier
self.search_proportion = search_proportion
assert (search_count_multiplier is None) or (search_proportion is None), (
f"Cannot specify both search_count_multiplier (={search_count_multiplier})"
f"and search_proportion (={search_proportion})"
)
def _produce_index_sample(self, values: torch.Tensor, count: int):
"""
Produce a sample of indices to select data based on confidences
Args:
values (torch.Tensor): an array of size [n, k] that contains
estimated values (U, V, confidences);
n: number of channels (U, V, confidences)
k: number of points labeled with part_id
count (int): number of samples to produce, should be positive and <= k
Return:
list(int): indices of values (along axis 1) selected as a sample
"""
k = values.shape[1]
if k == count:
index_sample = list(range(k))
else:
# take the best count * search_count_multiplier pixels,
# sample from them uniformly
# (here best = smallest variance)
_, sorted_confidence_indices = torch.sort(values[2])
if self.search_count_multiplier is not None:
search_count = min(int(count * self.search_count_multiplier), k)
elif self.search_proportion is not None:
search_count = min(max(int(k * self.search_proportion), count), k)
else:
search_count = min(count, k)
sample_from_top = random.sample(range(search_count), count)
index_sample = sorted_confidence_indices[:search_count][sample_from_top]
return index_sample
def _produce_labels_and_results(self, instance) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Method to get labels and DensePose results from an instance, with confidences
Args:
instance (Instances): an instance of `DensePoseChartPredictorOutputWithConfidences`
Return:
labels (torch.Tensor): shape [H, W], DensePose segmentation labels
dp_result (torch.Tensor): shape [3, H, W], DensePose results u and v
stacked with the confidence channel
"""
converter = ToChartResultConverterWithConfidences
chart_result = converter.convert(instance.pred_densepose, instance.pred_boxes)
labels, dp_result = chart_result.labels.cpu(), chart_result.uv.cpu()
dp_result = torch.cat(
(dp_result, getattr(chart_result, self.confidence_channel)[None].cpu())
)
return labels, dp_result
# Copyright (c) Facebook, Inc. and its affiliates.
# pyre-unsafe
from typing import Any, Dict, List, Tuple
import torch
from torch.nn import functional as F
from detectron2.config import CfgNode
from detectron2.structures import Instances
from densepose.converters.base import IntTupleBox
from densepose.data.utils import get_class_to_mesh_name_mapping
from densepose.modeling.cse.utils import squared_euclidean_distance_matrix
from densepose.structures import DensePoseDataRelative
from .densepose_base import DensePoseBaseSampler
class DensePoseCSEBaseSampler(DensePoseBaseSampler):
"""
Base DensePose sampler to produce DensePose data from DensePose predictions.
Samples for each class are drawn according to some distribution over all pixels estimated
to belong to that class.
"""
def __init__(
self,
cfg: CfgNode,
use_gt_categories: bool,
embedder: torch.nn.Module,
count_per_class: int = 8,
):
"""
Constructor
Args:
cfg (CfgNode): the config of the model
embedder (torch.nn.Module): necessary to compute mesh vertex embeddings
count_per_class (int): the sampler produces at most `count_per_class`
samples for each category
"""
super().__init__(count_per_class)
self.embedder = embedder
self.class_to_mesh_name = get_class_to_mesh_name_mapping(cfg)
self.use_gt_categories = use_gt_categories
def _sample(self, instance: Instances, bbox_xywh: IntTupleBox) -> Dict[str, List[Any]]:
"""
Sample DensePoseDataRelative from estimation results
"""
if self.use_gt_categories:
instance_class = instance.dataset_classes.tolist()[0]
else:
instance_class = instance.pred_classes.tolist()[0]
mesh_name = self.class_to_mesh_name[instance_class]
annotation = {
DensePoseDataRelative.X_KEY: [],
DensePoseDataRelative.Y_KEY: [],
DensePoseDataRelative.VERTEX_IDS_KEY: [],
DensePoseDataRelative.MESH_NAME_KEY: mesh_name,
}
mask, embeddings, other_values = self._produce_mask_and_results(instance, bbox_xywh)
indices = torch.nonzero(mask, as_tuple=True)
selected_embeddings = embeddings.permute(1, 2, 0)[indices].cpu()
values = other_values[:, indices[0], indices[1]]
k = values.shape[1]
count = min(self.count_per_class, k)
if count <= 0:
return annotation
index_sample = self._produce_index_sample(values, count)
closest_vertices = squared_euclidean_distance_matrix(
selected_embeddings[index_sample], self.embedder(mesh_name)
)
closest_vertices = torch.argmin(closest_vertices, dim=1)
sampled_y = indices[0][index_sample] + 0.5
sampled_x = indices[1][index_sample] + 0.5
# prepare / normalize data
_, _, w, h = bbox_xywh
x = (sampled_x / w * 256.0).cpu().tolist()
y = (sampled_y / h * 256.0).cpu().tolist()
# extend annotations
annotation[DensePoseDataRelative.X_KEY].extend(x)
annotation[DensePoseDataRelative.Y_KEY].extend(y)
annotation[DensePoseDataRelative.VERTEX_IDS_KEY].extend(closest_vertices.cpu().tolist())
return annotation
def _produce_mask_and_results(
self, instance: Instances, bbox_xywh: IntTupleBox
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Method to get labels and DensePose results from an instance
Args:
instance (Instances): an instance of `DensePoseEmbeddingPredictorOutput`
bbox_xywh (IntTupleBox): the corresponding bounding box
Return:
mask (torch.Tensor): shape [H, W], DensePose segmentation mask
embeddings (torch.Tensor): a tensor of shape [D, H, W],
DensePose CSE embeddings
other_values (torch.Tensor): a tensor of shape [0, H, W],
for potential other values
"""
densepose_output = instance.pred_densepose
S = densepose_output.coarse_segm
E = densepose_output.embedding
_, _, w, h = bbox_xywh
embeddings = F.interpolate(E, size=(h, w), mode="bilinear")[0]
coarse_segm_resized = F.interpolate(S, size=(h, w), mode="bilinear")[0]
mask = coarse_segm_resized.argmax(0) > 0
other_values = torch.empty((0, h, w), device=E.device)
return mask, embeddings, other_values
def _resample_mask(self, output: Any) -> torch.Tensor:
"""
Convert DensePose predictor output to segmentation annotation - tensors of size
(256, 256) and type `int64`.
Args:
output: DensePose predictor output with the following attributes:
- coarse_segm: tensor of size [N, D, H, W] with unnormalized coarse
segmentation scores
Return:
Tensor of size (S, S) and type `int64` with coarse segmentation annotations,
where S = DensePoseDataRelative.MASK_SIZE
"""
sz = DensePoseDataRelative.MASK_SIZE
mask = (
F.interpolate(output.coarse_segm, (sz, sz), mode="bilinear", align_corners=False)
.argmax(dim=1)
.long()
.squeeze()
.cpu()
)
return mask
# Copyright (c) Facebook, Inc. and its affiliates.
# pyre-unsafe
import random
from typing import Optional, Tuple
import torch
from torch.nn import functional as F
from detectron2.config import CfgNode
from detectron2.structures import Instances
from densepose.converters.base import IntTupleBox
from .densepose_cse_base import DensePoseCSEBaseSampler
class DensePoseCSEConfidenceBasedSampler(DensePoseCSEBaseSampler):
"""
Samples DensePose data from DensePose predictions.
Samples for each class are drawn using confidence value estimates.
"""
def __init__(
self,
cfg: CfgNode,
use_gt_categories: bool,
embedder: torch.nn.Module,
confidence_channel: str,
count_per_class: int = 8,
search_count_multiplier: Optional[float] = None,
search_proportion: Optional[float] = None,
):
"""
Constructor
Args:
cfg (CfgNode): the config of the model
embedder (torch.nn.Module): necessary to compute mesh vertex embeddings
confidence_channel (str): confidence channel to use for sampling;
possible values:
"coarse_segm_confidence": confidences for coarse segmentation
(default: "coarse_segm_confidence")
count_per_class (int): the sampler produces at most `count_per_class`
samples for each category (default: 8)
search_count_multiplier (float or None): if not None, the total number
of the most confident estimates of a given class to consider is
defined as `min(search_count_multiplier * count_per_class, N)`,
where `N` is the total number of estimates of the class; cannot be
specified together with `search_proportion` (default: None)
search_proportion (float or None): if not None, the total number of the
most confident estimates of a given class to consider is
defined as `min(max(search_proportion * N, count_per_class), N)`,
where `N` is the total number of estimates of the class; cannot be
specified together with `search_count_multiplier` (default: None)
"""
super().__init__(cfg, use_gt_categories, embedder, count_per_class)
self.confidence_channel = confidence_channel
self.search_count_multiplier = search_count_multiplier
self.search_proportion = search_proportion
assert (search_count_multiplier is None) or (search_proportion is None), (
f"Cannot specify both search_count_multiplier (={search_count_multiplier})"
f"and search_proportion (={search_proportion})"
)
def _produce_index_sample(self, values: torch.Tensor, count: int):
"""
Produce a sample of indices to select data based on confidences
Args:
values (torch.Tensor): a tensor of size [1, k] that contains confidences
k: number of points labeled with part_id
count (int): number of samples to produce, should be positive and <= k
Return:
list(int): indices of values (along axis 1) selected as a sample
"""
k = values.shape[1]
if k == count:
index_sample = list(range(k))
else:
# take the best count * search_count_multiplier pixels,
# sample from them uniformly
# (here best = highest confidence)
_, sorted_confidence_indices = torch.sort(values[0])
if self.search_count_multiplier is not None:
search_count = min(int(count * self.search_count_multiplier), k)
elif self.search_proportion is not None:
search_count = min(max(int(k * self.search_proportion), count), k)
else:
search_count = min(count, k)
sample_from_top = random.sample(range(search_count), count)
index_sample = sorted_confidence_indices[-search_count:][sample_from_top]
return index_sample
def _produce_mask_and_results(
self, instance: Instances, bbox_xywh: IntTupleBox
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Method to get labels and DensePose results from an instance
Args:
instance (Instances): an instance of
`DensePoseEmbeddingPredictorOutputWithConfidences`
bbox_xywh (IntTupleBox): the corresponding bounding box
Return:
mask (torch.Tensor): shape [H, W], DensePose segmentation mask
embeddings (torch.Tensor): a tensor of shape [D, H, W],
DensePose CSE embeddings
other_values (torch.Tensor): a tensor of shape [1, H, W], DensePose CSE confidences
"""
_, _, w, h = bbox_xywh
densepose_output = instance.pred_densepose
mask, embeddings, _ = super()._produce_mask_and_results(instance, bbox_xywh)
other_values = F.interpolate(
getattr(densepose_output, self.confidence_channel),
size=(h, w),
mode="bilinear",
)[0].cpu()
return mask, embeddings, other_values
# Copyright (c) Facebook, Inc. and its affiliates.
# pyre-unsafe
from .densepose_cse_base import DensePoseCSEBaseSampler
from .densepose_uniform import DensePoseUniformSampler
class DensePoseCSEUniformSampler(DensePoseCSEBaseSampler, DensePoseUniformSampler):
"""
Uniform Sampler for CSE
"""
pass
# Copyright (c) Facebook, Inc. and its affiliates.
# pyre-unsafe
import random
import torch
from .densepose_base import DensePoseBaseSampler
class DensePoseUniformSampler(DensePoseBaseSampler):
"""
Samples DensePose data from DensePose predictions.
Samples for each class are drawn uniformly over all pixels estimated
to belong to that class.
"""
def __init__(self, count_per_class: int = 8):
"""
Constructor
Args:
count_per_class (int): the sampler produces at most `count_per_class`
samples for each category
"""
super().__init__(count_per_class)
def _produce_index_sample(self, values: torch.Tensor, count: int):
"""
Produce a uniform sample of indices to select data
Args:
values (torch.Tensor): an array of size [n, k] that contains
estimated values (U, V, confidences);
n: number of channels (U, V, confidences)
k: number of points labeled with part_id
count (int): number of samples to produce, should be positive and <= k
Return:
list(int): indices of values (along axis 1) selected as a sample
"""
k = values.shape[1]
return random.sample(range(k), count)
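# --- Illustrative usage sketch (not part of the original commit) ---
# Demonstrating the uniform index sampling on a toy [n, k] tensor of values;
# the tensor below is an assumption, not data produced by a model.
if __name__ == "__main__":
    sampler = DensePoseUniformSampler(count_per_class=8)
    values = torch.rand(2, 100)  # 2 channels (U, V), 100 candidate pixels
    index_sample = sampler._produce_index_sample(values, count=8)
    print(sorted(index_sample))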
# Copyright (c) Facebook, Inc. and its affiliates.
# pyre-unsafe
from detectron2.structures import BitMasks, Instances
from densepose.converters import ToMaskConverter
class MaskFromDensePoseSampler:
"""
Produce mask GT from DensePose predictions
This sampler simply converts DensePose predictions to BitMasks
that contain a bool tensor of the size of the input image
"""
def __call__(self, instances: Instances) -> BitMasks:
"""
Converts predicted data from `instances` into the GT mask data
Args:
instances (Instances): predicted results, expected to have `pred_densepose` field
Returns:
BitMasks wrapping a boolean tensor of the size of the input image with
non-zero values at pixels that are estimated to belong to the detected object
"""
return ToMaskConverter.convert(
instances.pred_densepose, instances.pred_boxes, instances.image_size
)
# Copyright (c) Facebook, Inc. and its affiliates.
# pyre-unsafe
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional
from detectron2.structures import Instances
ModelOutput = Dict[str, Any]
SampledData = Dict[str, Any]
@dataclass
class _Sampler:
"""
Sampler registry entry that contains:
- src (str): source field to sample from (deleted after sampling)
- dst (Optional[str]): destination field to sample to, if not None
- func (Optional[Callable: Any -> Any]): function that performs sampling,
if None, reference copy is performed
"""
src: str
dst: Optional[str]
func: Optional[Callable[[Any], Any]]
class PredictionToGroundTruthSampler:
"""
Sampler implementation that converts predictions to GT using registered
samplers for different fields of `Instances`.
"""
def __init__(self, dataset_name: str = ""):
self.dataset_name = dataset_name
self._samplers = {}
self.register_sampler("pred_boxes", "gt_boxes", None)
self.register_sampler("pred_classes", "gt_classes", None)
# delete scores
self.register_sampler("scores")
def __call__(self, model_output: List[ModelOutput]) -> List[SampledData]:
"""
Transform model output into ground truth data through sampling
Args:
model_output (List[ModelOutput]): model output
Returns:
List[SampledData]: sampled data
"""
for model_output_i in model_output:
instances: Instances = model_output_i["instances"]
# transform data in each field
for _, sampler in self._samplers.items():
if not instances.has(sampler.src) or sampler.dst is None:
continue
if sampler.func is None:
instances.set(sampler.dst, instances.get(sampler.src))
else:
instances.set(sampler.dst, sampler.func(instances))
# delete model output data that was transformed
for _, sampler in self._samplers.items():
if sampler.src != sampler.dst and instances.has(sampler.src):
instances.remove(sampler.src)
model_output_i["dataset"] = self.dataset_name
return model_output
def register_sampler(
self,
prediction_attr: str,
gt_attr: Optional[str] = None,
func: Optional[Callable[[Any], Any]] = None,
):
"""
Register sampler for a field
Args:
prediction_attr (str): field to replace with a sampled value
gt_attr (Optional[str]): field to store the sampled value to, if not None
func (Optional[Callable: Any -> Any]): sampler function
"""
self._samplers[(prediction_attr, gt_attr)] = _Sampler(
src=prediction_attr, dst=gt_attr, func=func
)
def remove_sampler(
self,
prediction_attr: str,
gt_attr: Optional[str] = None,
):
"""
Remove sampler for a field
Args:
prediction_attr (str): field to replace with a sampled value
gt_attr (Optional[str]): field to store the sampled value to, if not None
"""
assert (prediction_attr, gt_attr) in self._samplers
del self._samplers[(prediction_attr, gt_attr)]
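# --- Illustrative usage sketch (not part of the original commit) ---
# Registering additional samplers for DensePose fields, following the same
# pattern as the defaults in the constructor. The import path and the choice of
# samplers are assumptions about how a caller would wire this up.
if __name__ == "__main__":
    from densepose.data.samplers import (  # assumed import path
        DensePoseUniformSampler,
        MaskFromDensePoseSampler,
    )

    sampler = PredictionToGroundTruthSampler(dataset_name="densepose_coco_2014_train")
    sampler.register_sampler("pred_densepose", "gt_densepose", DensePoseUniformSampler())
    sampler.register_sampler("pred_densepose", "gt_masks", MaskFromDensePoseSampler())
    # `model_output` would be a list of {"instances": Instances} dicts from inference:
    # gt_data = sampler(model_output)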
# Copyright (c) Facebook, Inc. and its affiliates.
# pyre-unsafe
from .image import ImageResizeTransform
# Copyright (c) Facebook, Inc. and its affiliates.
# pyre-unsafe
import torch
class ImageResizeTransform:
"""
Transform that resizes images loaded from a dataset
(BGR data in NCHW channel order, typically uint8) to a format ready to be
consumed by DensePose training (BGR float32 data in NCHW channel order)
"""
def __init__(self, min_size: int = 800, max_size: int = 1333):
self.min_size = min_size
self.max_size = max_size
def __call__(self, images: torch.Tensor) -> torch.Tensor:
"""
Args:
images (torch.Tensor): tensor of size [N, 3, H, W] that contains
BGR data (typically in uint8)
Returns:
images (torch.Tensor): tensor of size [N, 3, H1, W1] where
H1 and W1 are chosen to respect the specified min and max sizes
and preserve the original aspect ratio, the data channels
follow BGR order and the data type is `torch.float32`
"""
# resize with min size
images = images.float()
min_size = min(images.shape[-2:])
max_size = max(images.shape[-2:])
scale = min(self.min_size / min_size, self.max_size / max_size)
images = torch.nn.functional.interpolate(
images,
scale_factor=scale,
mode="bilinear",
align_corners=False,
)
return images
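# --- Illustrative usage sketch (not part of the original commit) ---
# Resizing a batch so the shorter side approaches min_size while the longer side
# stays within max_size; the input resolution below is an arbitrary assumption.
if __name__ == "__main__":
    transform = ImageResizeTransform(min_size=800, max_size=1333)
    images = torch.randint(0, 256, (2, 3, 480, 640), dtype=torch.uint8)
    resized = transform(images)
    # scale = min(800 / 480, 1333 / 640) = 5 / 3, so roughly [2, 3, 800, 1066]
    print(resized.shape, resized.dtype)  # torch.float32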
# Copyright (c) Facebook, Inc. and its affiliates.
# pyre-unsafe
import os
from typing import Dict, Optional
from detectron2.config import CfgNode
def is_relative_local_path(path: str) -> bool:
path_str = os.fsdecode(path)
return ("://" not in path_str) and not os.path.isabs(path)
def maybe_prepend_base_path(base_path: Optional[str], path: str):
"""
Prepends the provided path with a base path prefix if:
1) base path is not None;
2) path is a local path
"""
if base_path is None:
return path
if is_relative_local_path(path):
return os.path.join(base_path, path)
return path
def get_class_to_mesh_name_mapping(cfg: CfgNode) -> Dict[int, str]:
return {
int(class_id): mesh_name
for class_id, mesh_name in cfg.DATASETS.CLASS_TO_MESH_NAME_MAPPING.items()
}
def get_category_to_class_mapping(dataset_cfg: CfgNode) -> Dict[str, int]:
return {
category: int(class_id)
for category, class_id in dataset_cfg.CATEGORY_TO_CLASS_MAPPING.items()
}
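# --- Illustrative usage sketch (not part of the original commit) ---
# Relative local paths get joined with the base path; absolute paths, URIs and a
# None base path are passed through unchanged. The paths below are placeholders.
if __name__ == "__main__":
    print(maybe_prepend_base_path("datasets", "lvis/annotations.json"))  # datasets/lvis/annotations.json
    print(maybe_prepend_base_path("datasets", "https://example.org/a.json"))  # unchanged
    print(maybe_prepend_base_path(None, "lvis/annotations.json"))  # unchanged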
# Copyright (c) Facebook, Inc. and its affiliates.
# pyre-unsafe
from .frame_selector import (
FrameSelectionStrategy,
RandomKFramesSelector,
FirstKFramesSelector,
LastKFramesSelector,
FrameTsList,
FrameSelector,
)
from .video_keyframe_dataset import (
VideoKeyframeDataset,
video_list_from_file,
list_keyframes,
read_keyframes,
)
# Copyright (c) Facebook, Inc. and its affiliates.
# pyre-unsafe
import random
from collections.abc import Callable
from enum import Enum
from typing import Callable as TCallable
from typing import List
FrameTsList = List[int]
FrameSelector = TCallable[[FrameTsList], FrameTsList]
class FrameSelectionStrategy(Enum):
"""
Frame selection strategy used with videos:
- "random_k": select k random frames
- "first_k": select k first frames
- "last_k": select k last frames
- "all": select all frames
"""
# fmt: off
RANDOM_K = "random_k"
FIRST_K = "first_k"
LAST_K = "last_k"
ALL = "all"
# fmt: on
class RandomKFramesSelector(Callable): # pyre-ignore[39]
"""
Selector that retains at most `k` random frames
"""
def __init__(self, k: int):
self.k = k
def __call__(self, frame_tss: FrameTsList) -> FrameTsList:
"""
Select `k` random frames
Args:
frame_tss (List[int]): timestamps of input frames
Returns:
List[int]: timestamps of selected frames
"""
return random.sample(frame_tss, min(self.k, len(frame_tss)))
class FirstKFramesSelector(Callable): # pyre-ignore[39]
"""
Selector that retains at most `k` first frames
"""
def __init__(self, k: int):
self.k = k
def __call__(self, frame_tss: FrameTsList) -> FrameTsList:
"""
Select `k` first frames
Args:
frame_tss (List[int]): timestamps of input frames
Returns:
List[int]: timestamps of selected frames
"""
return frame_tss[: self.k]
class LastKFramesSelector(Callable): # pyre-ignore[39]
"""
Selector that retains at most `k` last frames from video data
"""
def __init__(self, k: int):
self.k = k
def __call__(self, frame_tss: FrameTsList) -> FrameTsList:
"""
Select `k` last frames
Args:
frame_tss (List[int]): timestamps of input frames
Returns:
List[int]: timestamps of selected frames
"""
return frame_tss[-self.k :]
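# --- Illustrative usage sketch (not part of the original commit) ---
# Dispatching from a FrameSelectionStrategy value to the matching selector; the
# mapping and the example timestamps are assumptions about how a caller uses it.
if __name__ == "__main__":
    k = 3
    selectors = {
        FrameSelectionStrategy.RANDOM_K: RandomKFramesSelector(k),
        FrameSelectionStrategy.FIRST_K: FirstKFramesSelector(k),
        FrameSelectionStrategy.LAST_K: LastKFramesSelector(k),
        FrameSelectionStrategy.ALL: lambda frame_tss: frame_tss,
    }
    frame_timestamps = [0, 40, 80, 120, 160]  # example frame timestamps
    for strategy, selector in selectors.items():
        print(strategy.value, selector(frame_timestamps))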