# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os
import os.path as osp
from collections import defaultdict
from datetime import datetime

import mmcv
import numpy as np
from mmcv.utils import print_log

from ..core.evaluation.ava_utils import ava_eval, read_labelmap, results2csv
from ..utils import get_root_logger
from .base import BaseDataset
from .builder import DATASETS


@DATASETS.register_module()
class AVADataset(BaseDataset):
    """AVA dataset for spatial temporal detection.

    Based on official AVA annotation files, the dataset loads raw frames,
    bounding boxes, proposals and applies specified transformations to
    return a dict containing the frame tensors and other information.

    This dataset can load information from the following files:

    .. code-block:: txt

        ann_file -> ava_{train, val}_{v2.1, v2.2}.csv
        exclude_file -> ava_{train, val}_excluded_timestamps_{v2.1, v2.2}.csv
        label_file -> ava_action_list_{v2.1, v2.2}.pbtxt /
                      ava_action_list_{v2.1, v2.2}_for_activitynet_2019.pbtxt
        proposal_file -> ava_dense_proposals_{train, val}.FAIR.recall_93.9.pkl

    Particularly, the proposal_file is a pickle file which contains
    ``img_key`` (in format of ``{video_id},{timestamp}``). Example of a
    pickle file:

    .. code-block:: JSON

        {
            ...
            '0f39OWEqJ24,0902':
                array([[0.011, 0.157, 0.655, 0.983, 0.998163]]),
            '0f39OWEqJ24,0912':
                array([[0.054, 0.088, 0.91 , 0.998, 0.068273],
                       [0.016, 0.161, 0.519, 0.974, 0.984025],
                       [0.493, 0.283, 0.981, 0.984, 0.983621]]),
            ...
        }

    Args:
        ann_file (str): Path to the annotation file like
            ``ava_{train, val}_{v2.1, v2.2}.csv``.
        exclude_file (str): Path to the excluded timestamp file like
            ``ava_{train, val}_excluded_timestamps_{v2.1, v2.2}.csv``.
        pipeline (list[dict | callable]): A sequence of data transforms.
        label_file (str): Path to the label file like
            ``ava_action_list_{v2.1, v2.2}.pbtxt`` or
            ``ava_action_list_{v2.1, v2.2}_for_activitynet_2019.pbtxt``.
            Default: None.
        filename_tmpl (str): Template for each filename.
            Default: 'img_{:05}.jpg'.
        start_index (int): Specify a start index for frames in consideration
            of different filename format. However, when taking videos as
            input, it should be set to 0, since frames loaded from videos
            count from 0. Default: 0.
        proposal_file (str): Path to the proposal file like
            ``ava_dense_proposals_{train, val}.FAIR.recall_93.9.pkl``.
            Default: None.
        person_det_score_thr (float): The threshold of person detection
            scores; bboxes with scores above the threshold will be used.
            Note that 0 <= person_det_score_thr <= 1. If no proposal has a
            detection score larger than the threshold, the one with the
            largest detection score will be used. Default: 0.9.
        num_classes (int): The number of classes of the dataset. Default: 81.
            (AVA has 80 action classes; one extra dim is added for potential
            usage.)
        custom_classes (list[int]): A subset of class ids from the original
            dataset. Please note that 0 should NOT be selected, and
            ``num_classes`` should be equal to ``len(custom_classes) + 1``.
        data_prefix (str): Path to a directory where videos are held.
            Default: None.
        test_mode (bool): Store True when building test or validation
            dataset. Default: False.
        modality (str): Modality of data. Support 'RGB', 'Flow'.
            Default: 'RGB'.
        num_max_proposals (int): Max number of proposals to store.
            Default: 1000.
        timestamp_start (int): The start point of included timestamps. The
            default value is taken from the official website. Default: 900.
        timestamp_end (int): The end point of included timestamps. The
            default value is taken from the official website. Default: 1800.
        fps (int): Overrides the default FPS for the dataset. Default: 30.
    """
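    # A minimal sketch of how `custom_classes` interacts with labels (the
    # class ids below are illustrative, not a recommended subset): with
    # custom_classes=[11, 12, 80], only those AVA action ids are kept when
    # loading annotations; each id is remapped to its 1-based position in
    # the whitelist (11 -> 1, 12 -> 2, 80 -> 3), since index 0 stays
    # reserved, so `num_classes` must be 4 (= len(custom_classes) + 1).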
    def __init__(self,
                 ann_file,
                 exclude_file,
                 pipeline,
                 label_file=None,
                 filename_tmpl='img_{:05}.jpg',
                 start_index=0,
                 proposal_file=None,
                 person_det_score_thr=0.9,
                 num_classes=81,
                 custom_classes=None,
                 data_prefix=None,
                 test_mode=False,
                 modality='RGB',
                 num_max_proposals=1000,
                 timestamp_start=900,
                 timestamp_end=1800,
                 fps=30):
        # Since this class inherits from `BaseDataset`, some arguments
        # should be assigned before performing `load_annotations()`.
        self._FPS = fps  # Keep this as standard
        self.custom_classes = custom_classes
        if custom_classes is not None:
            assert num_classes == len(custom_classes) + 1
            assert 0 not in custom_classes
            with open(label_file) as f:
                _, class_whitelist = read_labelmap(f)
            assert set(custom_classes).issubset(class_whitelist)

            self.custom_classes = tuple([0] + custom_classes)
        self.exclude_file = exclude_file
        self.label_file = label_file
        self.proposal_file = proposal_file
        assert 0 <= person_det_score_thr <= 1, (
            'The value of person_det_score_thr should be in [0, 1].')
        self.person_det_score_thr = person_det_score_thr
        self.num_classes = num_classes
        self.filename_tmpl = filename_tmpl
        self.num_max_proposals = num_max_proposals
        self.timestamp_start = timestamp_start
        self.timestamp_end = timestamp_end
        self.logger = get_root_logger()
        super().__init__(
            ann_file,
            pipeline,
            data_prefix,
            test_mode,
            start_index=start_index,
            modality=modality,
            num_classes=num_classes)

        if self.proposal_file is not None:
            self.proposals = mmcv.load(self.proposal_file)
        else:
            self.proposals = None

        if not test_mode:
            valid_indexes = self.filter_exclude_file()
            self.logger.info(
                f'{len(valid_indexes)} out of {len(self.video_infos)} '
                f'frames are valid.')
            self.video_infos = [self.video_infos[i] for i in valid_indexes]

    def parse_img_record(self, img_records):
        """Merge image records of the same entity at the same time.

        Args:
            img_records (list[dict]): List of img_records (lines in AVA
                annotations).

        Returns:
            tuple(list): A tuple consisting of lists of bboxes, action
                labels and entity_ids.
        """
        bboxes, labels, entity_ids = [], [], []
        while len(img_records) > 0:
            img_record = img_records[0]
            num_img_records = len(img_records)

            selected_records = [
                x for x in img_records
                if np.array_equal(x['entity_box'], img_record['entity_box'])
            ]
            num_selected_records = len(selected_records)
            img_records = [
                x for x in img_records if
                not np.array_equal(x['entity_box'], img_record['entity_box'])
            ]
            assert len(img_records) + num_selected_records == num_img_records

            bboxes.append(img_record['entity_box'])
            valid_labels = np.array([
                selected_record['label']
                for selected_record in selected_records
            ])

            # The format can be directly used by BCELossWithLogits
            label = np.zeros(self.num_classes, dtype=np.float32)
            label[valid_labels] = 1.

            labels.append(label)
            entity_ids.append(img_record['entity_id'])

        bboxes = np.stack(bboxes)
        labels = np.stack(labels)
        entity_ids = np.stack(entity_ids)
        return bboxes, labels, entity_ids
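    # For example (illustrative values): two annotation rows sharing one
    # entity_box, carrying labels 11 and 79 with num_classes=81, are merged
    # by `parse_img_record` into a single bbox whose label vector has
    # label[11] = label[79] = 1. and zeros elsewhere, i.e. a multi-hot
    # target directly consumable by BCELossWithLogits.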
    def filter_exclude_file(self):
        """Filter out records in the exclude_file."""
        valid_indexes = []
        if self.exclude_file is None:
            valid_indexes = list(range(len(self.video_infos)))
        else:
            with open(self.exclude_file) as fin:
                exclude_video_infos = [x.strip().split(',') for x in fin]
            for i, video_info in enumerate(self.video_infos):
                valid_indexes.append(i)
                for video_id, timestamp in exclude_video_infos:
                    if (video_info['video_id'] == video_id
                            and video_info['timestamp'] == int(timestamp)):
                        valid_indexes.pop()
                        break
        return valid_indexes

    def load_annotations(self):
        """Load AVA annotations."""
        video_infos = []
        records_dict_by_img = defaultdict(list)
        # The shot info is identical for every record: frame indices
        # covering the whole included timestamp range.
        shot_info = (0, (self.timestamp_end - self.timestamp_start) *
                     self._FPS)
        with open(self.ann_file, 'r') as fin:
            for line in fin:
                line_split = line.strip().split(',')

                label = int(line_split[6])
                if self.custom_classes is not None:
                    if label not in self.custom_classes:
                        continue
                    label = self.custom_classes.index(label)

                video_id = line_split[0]
                timestamp = int(line_split[1])
                img_key = f'{video_id},{timestamp:04d}'

                entity_box = np.array(list(map(float, line_split[2:6])))
                entity_id = int(line_split[7])

                video_info = dict(
                    video_id=video_id,
                    timestamp=timestamp,
                    entity_box=entity_box,
                    label=label,
                    entity_id=entity_id,
                    shot_info=shot_info)
                records_dict_by_img[img_key].append(video_info)

        for img_key in records_dict_by_img:
            video_id, timestamp = img_key.split(',')
            bboxes, labels, entity_ids = self.parse_img_record(
                records_dict_by_img[img_key])
            ann = dict(
                gt_bboxes=bboxes, gt_labels=labels, entity_ids=entity_ids)
            frame_dir = video_id
            if self.data_prefix is not None:
                frame_dir = osp.join(self.data_prefix, frame_dir)
            video_info = dict(
                frame_dir=frame_dir,
                video_id=video_id,
                timestamp=int(timestamp),
                img_key=img_key,
                shot_info=shot_info,
                fps=self._FPS,
                ann=ann)
            video_infos.append(video_info)

        return video_infos

    def prepare_train_frames(self, idx):
        """Prepare the frames for training given the index."""
        results = copy.deepcopy(self.video_infos[idx])
        img_key = results['img_key']

        results['filename_tmpl'] = self.filename_tmpl
        results['modality'] = self.modality
        results['start_index'] = self.start_index
        results['timestamp_start'] = self.timestamp_start
        results['timestamp_end'] = self.timestamp_end

        if self.proposals is not None:
            if img_key not in self.proposals:
                results['proposals'] = np.array([[0, 0, 1, 1]])
                results['scores'] = np.array([1])
            else:
                proposals = self.proposals[img_key]
                assert proposals.shape[-1] in [4, 5]
                if proposals.shape[-1] == 5:
                    thr = min(self.person_det_score_thr,
                              max(proposals[:, 4]))
                    positive_inds = (proposals[:, 4] >= thr)
                    proposals = proposals[positive_inds]
                    proposals = proposals[:self.num_max_proposals]
                    results['proposals'] = proposals[:, :4]
                    results['scores'] = proposals[:, 4]
                else:
                    proposals = proposals[:self.num_max_proposals]
                    results['proposals'] = proposals

        ann = results.pop('ann')
        results['gt_bboxes'] = ann['gt_bboxes']
        results['gt_labels'] = ann['gt_labels']
        results['entity_ids'] = ann['entity_ids']

        return self.pipeline(results)

    def prepare_test_frames(self, idx):
        """Prepare the frames for testing given the index."""
        results = copy.deepcopy(self.video_infos[idx])
        img_key = results['img_key']

        results['filename_tmpl'] = self.filename_tmpl
        results['modality'] = self.modality
        results['start_index'] = self.start_index
        results['timestamp_start'] = self.timestamp_start
        results['timestamp_end'] = self.timestamp_end

        if self.proposals is not None:
            if img_key not in self.proposals:
                results['proposals'] = np.array([[0, 0, 1, 1]])
                results['scores'] = np.array([1])
            else:
                proposals = self.proposals[img_key]
                assert proposals.shape[-1] in [4, 5]
                if proposals.shape[-1] == 5:
                    thr = min(self.person_det_score_thr,
                              max(proposals[:, 4]))
                    positive_inds = (proposals[:, 4] >= thr)
                    proposals = proposals[positive_inds]
                    proposals = proposals[:self.num_max_proposals]
                    results['proposals'] = proposals[:, :4]
                    results['scores'] = proposals[:, 4]
                else:
                    proposals = proposals[:self.num_max_proposals]
                    results['proposals'] = proposals

        ann = results.pop('ann')
        # Follow the mmdet variable naming style.
        results['gt_bboxes'] = ann['gt_bboxes']
        results['gt_labels'] = ann['gt_labels']
        results['entity_ids'] = ann['entity_ids']

        return self.pipeline(results)
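    # Note on the proposal filtering in the two prepare methods above:
    # `thr = min(person_det_score_thr, max(scores))` guarantees at least
    # one proposal survives. E.g. with person_det_score_thr=0.9 and scores
    # [0.3, 0.5] (illustrative values), the effective threshold drops to
    # 0.5 and the best-scoring box is kept; with scores [0.95, 0.4], only
    # the 0.95 box clears the 0.9 threshold.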
    def dump_results(self, results, out):
        """Dump predictions into a csv file."""
        assert out.endswith('csv')
        results2csv(self, results, out, self.custom_classes)

    def evaluate(self,
                 results,
                 metrics=('mAP', ),
                 metric_options=None,
                 logger=None):
        """Evaluate the prediction results and report mAP."""
        assert len(metrics) == 1 and metrics[0] == 'mAP', (
            'For evaluation on AVADataset, you need to use the metric '
            '"mAP". See https://github.com/open-mmlab/mmaction2/pull/567 '
            'for more info.')
        time_now = datetime.now().strftime('%Y%m%d_%H%M%S')
        temp_file = f'AVA_{time_now}_result.csv'
        results2csv(self, results, temp_file, self.custom_classes)

        ret = {}
        for metric in metrics:
            msg = f'Evaluating {metric} ...'
            if logger is None:
                msg = '\n' + msg
            print_log(msg, logger=logger)

            eval_result = ava_eval(
                temp_file,
                metric,
                self.label_file,
                self.ann_file,
                self.exclude_file,
                custom_classes=self.custom_classes)
            log_msg = []
            for k, v in eval_result.items():
                log_msg.append(f'\n{k}\t{v: .4f}')
            log_msg = ''.join(log_msg)
            print_log(log_msg, logger=logger)
            ret.update(eval_result)

        os.remove(temp_file)
        return ret
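# A minimal sketch of how this dataset is typically declared in an
# mmaction2 config file (paths and `train_pipeline` are illustrative
# placeholders, not shipped defaults):
#
#     train_dataset = dict(
#         type='AVADataset',
#         ann_file='data/ava/annotations/ava_train_v2.1.csv',
#         exclude_file='data/ava/annotations/'
#         'ava_train_excluded_timestamps_v2.1.csv',
#         label_file='data/ava/annotations/ava_action_list_v2.1.pbtxt',
#         proposal_file='data/ava/annotations/'
#         'ava_dense_proposals_train.FAIR.recall_93.9.pkl',
#         data_prefix='data/ava/rawframes',
#         pipeline=train_pipeline)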