Commit afe88104 authored by lishj6

init0905

parent a48c4071
import prettytable
from typing import Dict, List, Optional
from time import time
from copy import deepcopy
from multiprocessing import Pool
from logging import Logger
from functools import partial, cached_property
import numpy as np
from numpy.typing import NDArray
from shapely.geometry import LineString
import mmcv
from mmcv import Config
from mmdet.datasets import build_dataset, build_dataloader
from .AP import instance_match, average_precision
INTERP_NUM = 200 # number of points to interpolate during evaluation
THRESHOLDS = [0.5, 1.0, 1.5] # AP thresholds
N_WORKERS = 16 # number of parallel workers
class VectorEvaluate(object):
"""Evaluator for vectorized map.
Args:
dataset_cfg (Config): dataset cfg for gt
n_workers (int): number of parallel workers
"""
def __init__(self, dataset_cfg: Config, n_workers: int=N_WORKERS) -> None:
self.dataset = build_dataset(dataset_cfg)
self.dataloader = build_dataloader(
self.dataset, samples_per_gpu=1, workers_per_gpu=n_workers, shuffle=False, dist=False)
classes = self.dataset.MAP_CLASSES
self.cat2id = {cls: i for i, cls in enumerate(classes)}
self.id2cat = {v: k for k, v in self.cat2id.items()}
self.n_workers = n_workers
self.thresholds = list(THRESHOLDS)
@cached_property
def gts(self) -> Dict[str, Dict[int, List[NDArray]]]:
print('collecting gts...')
gts = {}
pbar = mmcv.ProgressBar(len(self.dataloader))
for data in self.dataloader:
token = deepcopy(data['img_metas'].data[0][0]['token'])
gt = deepcopy(data['vectors'].data[0][0])
gts[token] = gt
pbar.update()
del data # avoid dataloader memory crash
return gts
def interp_fixed_num(self,
vector: NDArray,
num_pts: int) -> NDArray:
''' Interpolate a polyline.
Args:
vector (array): line coordinates, shape (M, 2)
num_pts (int): number of points to sample
Returns:
sampled_points (array): interpolated coordinates, shape (num_pts, 2)
'''
line = LineString(vector)
distances = np.linspace(0, line.length, num_pts)
sampled_points = np.array([list(line.interpolate(distance).coords)
for distance in distances]).squeeze()
return sampled_points
def interp_fixed_dist(self,
vector: NDArray,
sample_dist: float) -> NDArray:
''' Interpolate a line at fixed interval.
Args:
vector (array): line coordinates, shape (M, 2)
sample_dist (float): sample interval
Returns:
points (array): interpolated points, shape (N, 2)
'''
line = LineString(vector)
distances = list(np.arange(sample_dist, line.length, sample_dist))
# make sure to sample at least two points when sample_dist > line.length
distances = [0,] + distances + [line.length,]
sampled_points = np.array([list(line.interpolate(distance).coords)
for distance in distances]).squeeze()
return sampled_points
def _evaluate_single(self,
pred_vectors: List,
scores: List,
groundtruth: List,
thresholds: List,
metric: str='chamfer') -> Dict[int, NDArray]:
''' Do single-frame matching for one class.
Args:
pred_vectors (List): List[vector(ndarray) (different length)],
scores (List): List[score(float)]
groundtruth (List): List of vectors
thresholds (List): List of thresholds
Returns:
tp_fp_score_by_thr (Dict): matching results at different thresholds,
e.g. {0.5: (M, 3), 1.0: (M, 3), 1.5: (M, 3)} with columns (tp, fp, score)
'''
pred_lines = []
# interpolate predictions
for vector in pred_vectors:
vector = np.array(vector)
vector_interp = self.interp_fixed_num(vector, INTERP_NUM)
pred_lines.append(vector_interp)
if pred_lines:
pred_lines = np.stack(pred_lines)
else:
pred_lines = np.zeros((0, INTERP_NUM, 2))
# interpolate groundtruth
gt_lines = []
for vector in groundtruth:
vector_interp = self.interp_fixed_num(vector, INTERP_NUM)
gt_lines.append(vector_interp)
if gt_lines:
gt_lines = np.stack(gt_lines)
else:
gt_lines = np.zeros((0, INTERP_NUM, 2))
scores = np.array(scores)
tp_fp_list = instance_match(pred_lines, scores, gt_lines, thresholds, metric) # list of (tp, fp) pairs, one per threshold
tp_fp_score_by_thr = {}
for i, thr in enumerate(thresholds):
tp, fp = tp_fp_list[i]
tp_fp_score = np.hstack([tp[:, None], fp[:, None], scores[:, None]])
tp_fp_score_by_thr[thr] = tp_fp_score
return tp_fp_score_by_thr # {0.5: (M, 3), 1.0: (M, 3), 1.5: (M, 3)}
def evaluate(self,
result_path: str,
metric: str='chamfer',
logger: Optional[Logger]=None) -> Dict[str, float]:
''' Evaluate a submission file and print evaluation results to `logger` if specified.
The submission is aligned with the ground truth by sample token before evaluation; multiple workers are used to speed this up.
Args:
result_path (str): path to submission file
metric (str): distance metric. Default: 'chamfer'
logger (Logger): logger to print evaluation result, Default: None
Returns:
new_result_dict (Dict): evaluation results. AP by categories.
'''
results = mmcv.load(result_path)
results = results['results']
# re-group samples and gt by label
samples_by_cls = {label: [] for label in self.id2cat.keys()}
num_gts = {label: 0 for label in self.id2cat.keys()}
num_preds = {label: 0 for label in self.id2cat.keys()}
# align by token
for token, gt in self.gts.items():
if token in results.keys():
pred = results[token]
else:
pred = {'vectors': [], 'scores': [], 'labels': []}
# for every sample
vectors_by_cls = {label: [] for label in self.id2cat.keys()}
scores_by_cls = {label: [] for label in self.id2cat.keys()}
for i in range(len(pred['labels'])):
# i-th pred line in sample
label = pred['labels'][i]
vector = pred['vectors'][i]
score = pred['scores'][i]
vectors_by_cls[label].append(vector)
scores_by_cls[label].append(score)
for label in self.id2cat.keys():
new_sample = (vectors_by_cls[label], scores_by_cls[label], gt[label])
num_gts[label] += len(gt[label])
num_preds[label] += len(scores_by_cls[label])
samples_by_cls[label].append(new_sample)
result_dict = {}
print(f'\nevaluating {len(self.id2cat)} categories...')
start = time()
if self.n_workers > 0:
pool = Pool(self.n_workers)
sum_mAP = 0
pbar = mmcv.ProgressBar(len(self.id2cat))
for label in self.id2cat.keys():
samples = samples_by_cls[label] # List[(pred_lines, scores, gts)]
result_dict[self.id2cat[label]] = {
'num_gts': num_gts[label],
'num_preds': num_preds[label]
}
sum_AP = 0
fn = partial(self._evaluate_single, thresholds=self.thresholds, metric=metric)
if self.n_workers > 0 and len(samples) > 81:
tpfp_score_list = pool.starmap(fn, samples)
else:
tpfp_score_list = []
for sample in samples:
tpfp_score_list.append(fn(*sample))
for thr in self.thresholds:
tp_fp_score = [i[thr] for i in tpfp_score_list]
tp_fp_score = np.vstack(tp_fp_score) # (num_dets, 3)
sort_inds = np.argsort(-tp_fp_score[:, -1])
tp = tp_fp_score[sort_inds, 0] # (num_dets,)
fp = tp_fp_score[sort_inds, 1] # (num_dets,)
tp = np.cumsum(tp, axis=0)
fp = np.cumsum(fp, axis=0)
eps = np.finfo(np.float32).eps
recalls = tp / np.maximum(num_gts[label], eps)
precisions = tp / np.maximum((tp + fp), eps)
AP = average_precision(recalls, precisions, 'area')
sum_AP += AP
result_dict[self.id2cat[label]].update({f'AP@{thr}': AP})
pbar.update()
AP = sum_AP / len(self.thresholds)
sum_mAP += AP
result_dict[self.id2cat[label]].update({f'AP': AP})
if self.n_workers > 0:
pool.close()
mAP = sum_mAP / len(self.id2cat.keys())
result_dict.update({'mAP': mAP})
print(f"finished in {time() - start:.2f}s")
# print results
table = prettytable.PrettyTable(['category', 'num_preds', 'num_gts'] +
[f'AP@{thr}' for thr in self.thresholds] + ['AP'])
for label in self.id2cat.keys():
table.add_row([
self.id2cat[label],
result_dict[self.id2cat[label]]['num_preds'],
result_dict[self.id2cat[label]]['num_gts'],
*[round(result_dict[self.id2cat[label]][f'AP@{thr}'], 4) for thr in self.thresholds],
round(result_dict[self.id2cat[label]]['AP'], 4),
])
from mmcv.utils import print_log
print_log('\n'+str(table), logger=logger)
mAP_normal = 0
for label in self.id2cat.keys():
for thr in self.thresholds:
mAP_normal += result_dict[self.id2cat[label]][f'AP@{thr}']
mAP_normal = mAP_normal / (len(self.id2cat) * len(self.thresholds))
print_log(f'mAP_normal = {mAP_normal:.4f}\n', logger=logger)
# print_log(f'mAP_hard = {mAP_easy:.4f}\n', logger=logger)
new_result_dict = {}
for name in self.cat2id:
new_result_dict[name] = result_dict[name]['AP']
new_result_dict['mAP_normal'] = mAP_normal
return new_result_dict
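
# --------------------------------------------------------------------------- #
# Usage sketch (illustration only, not part of the original module): driving
# VectorEvaluate end to end.  The config and submission paths are hypothetical;
# the only assumptions are the ones encoded above: the dataset built from
# `dataset_cfg` exposes MAP_CLASSES, and the submission pickle maps sample
# tokens to dicts with 'vectors', 'scores' and 'labels'.
# --------------------------------------------------------------------------- #
def _vector_evaluate_example():
    cfg = Config.fromfile('configs/map_eval_config.py')    # hypothetical config path
    evaluator = VectorEvaluate(cfg.data.val, n_workers=0)  # n_workers=0 skips the Pool
    # evaluate() loads the submission, aligns it with GT by token and prints an AP table
    ap_dict = evaluator.evaluate('work_dirs/submission.pkl', metric='chamfer')
    return ap_dict  # {'<map class>': AP, ..., 'mAP_normal': ...}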
# nuScenes dev-kit.
# Code written by Holger Caesar & Oscar Beijbom, 2018.
import argparse
import json
import os
import random
import time
import tqdm
from typing import Tuple, Dict, Any
import numpy as np
from nuscenes import NuScenes
from nuscenes.eval.common.config import config_factory
from nuscenes.eval.common.data_classes import EvalBoxes
from nuscenes.eval.common.loaders import add_center_dist, filter_eval_boxes
from nuscenes.eval.detection.algo import accumulate, calc_ap, calc_tp
from nuscenes.eval.detection.constants import DETECTION_NAMES, ATTRIBUTE_NAMES, TP_METRICS
from nuscenes.eval.detection.data_classes import DetectionConfig, DetectionMetrics, DetectionBox, \
DetectionMetricDataList, DetectionMetricData
from nuscenes.eval.detection.render import summary_plot, class_pr_curve, class_tp_curve, dist_pr_curve, visualize_sample
from nuscenes.prediction import PredictHelper, convert_local_coords_to_global
from nuscenes.utils.splits import create_splits_scenes
from nuscenes.eval.detection.utils import category_to_detection_name
from nuscenes.eval.common.utils import quaternion_yaw, Quaternion
from nuscenes.eval.common.utils import center_distance, scale_iou, yaw_diff, velocity_l2, attr_acc, cummean
from .motion_utils import MotionBox, load_prediction, load_gt, accumulate
MOTION_TP_METRICS = ['min_ade_err', 'min_fde_err', 'miss_rate_err']
class MotionEval:
"""
Motion forecasting evaluation, adapted from the official nuScenes detection evaluation code.
Results are written to the provided output_dir.
Predictions are matched to ground truth with the nuScenes center-distance criterion; on top of the matches this class reports:
- End-to-end Prediction Accuracy (EPA) per class.
- Motion TP metrics: minADE, minFDE and miss-rate errors.
Here is an overview of the functions in this class:
- init: Loads GT annotations and predictions stored in JSON format and filters the boxes.
- evaluate: Accumulates matches and computes the metrics above.
- render: Renders various plots and dumps to disk.
We assume that:
- Every sample_token is given in the results, although there may be no predictions for that sample.
Please see https://www.nuscenes.org/object-detection for more details.
"""
def __init__(self,
nusc: NuScenes,
config: DetectionConfig,
result_path: str,
eval_set: str,
output_dir: str = None,
verbose: bool = True,
seconds: int = 12):
"""
Initialize a MotionEval object.
:param nusc: A NuScenes object.
:param config: A DetectionConfig object.
:param result_path: Path of the nuScenes JSON result file.
:param eval_set: The dataset split to evaluate on, e.g. train, val or test.
:param output_dir: Folder to save plots and results to.
:param verbose: Whether to print to stdout.
:param seconds: Number of future seconds of trajectory to load for each ground-truth box.
"""
self.nusc = nusc
self.result_path = result_path
self.eval_set = eval_set
self.output_dir = output_dir
self.verbose = verbose
self.cfg = config
# Check result file exists.
# assert os.path.exists(result_path), 'Error: The result file does not exist!'
# Make dirs.
self.plot_dir = os.path.join(self.output_dir, 'plots')
if not os.path.isdir(self.output_dir):
os.makedirs(self.output_dir)
if not os.path.isdir(self.plot_dir):
os.makedirs(self.plot_dir)
# Load data.
if verbose:
print('Initializing nuScenes detection evaluation')
self.pred_boxes, self.meta = load_prediction(self.result_path, self.cfg.max_boxes_per_sample, MotionBox,
verbose=verbose)
self.gt_boxes = load_gt(self.nusc, self.eval_set, MotionBox, verbose=verbose, seconds=seconds)
assert set(self.pred_boxes.sample_tokens) == set(self.gt_boxes.sample_tokens), \
"Samples in split doesn't match samples in predictions."
# Add center distances.
self.pred_boxes = add_center_dist(nusc, self.pred_boxes)
self.gt_boxes = add_center_dist(nusc, self.gt_boxes)
# Filter boxes (distance, points per box, etc.).
if verbose:
print('Filtering predictions')
self.pred_boxes = filter_eval_boxes(nusc, self.pred_boxes, self.cfg.class_range, verbose=verbose)
if verbose:
print('Filtering ground truth annotations')
self.gt_boxes = filter_eval_boxes(nusc, self.gt_boxes, self.cfg.class_range, verbose=verbose)
self.sample_tokens = self.gt_boxes.sample_tokens
def evaluate(self) -> Tuple[DetectionMetrics, DetectionMetricDataList]:
"""
Performs the actual evaluation.
:return: A tuple of high-level and the raw metric data.
"""
start_time = time.time()
# restrict motion evaluation to cars and pedestrians with a single 2 m matching threshold
self.cfg.class_names = ['car', 'pedestrian']
self.cfg.dist_ths = [2.0]
# -----------------------------------
# Step 1: Accumulate metric data for all classes and distance thresholds.
# -----------------------------------
if self.verbose:
print('Accumulating metric data...')
metric_data_list = DetectionMetricDataList()
metrics = {}
for class_name in self.cfg.class_names:
for dist_th in self.cfg.dist_ths:
md, EPA, EPA_ = accumulate(self.gt_boxes, self.pred_boxes, class_name, self.cfg.dist_fcn_callable, dist_th)
metric_data_list.set(class_name, dist_th, md)
metrics[f'{class_name}_EPA'] = EPA_
# -----------------------------------
# Step 2: Calculate metrics from the data.
# -----------------------------------
if self.verbose:
print('Calculating metrics...')
for class_name in self.cfg.class_names:
# Compute TP metrics.
for metric_name in MOTION_TP_METRICS:
metric_data = metric_data_list[(class_name, self.cfg.dist_th_tp)]
tp = calc_tp(metric_data, self.cfg.min_recall, metric_name)
metrics[f'{class_name}_{metric_name}'] = tp
return metrics, metric_data_list
def render(self, metrics: DetectionMetrics, md_list: DetectionMetricDataList) -> None:
"""
Renders various PR and TP curves.
:param metrics: DetectionMetrics instance.
:param md_list: DetectionMetricDataList instance.
"""
if self.verbose:
print('Rendering PR and TP curves')
def savepath(name):
return os.path.join(self.plot_dir, name + '.pdf')
summary_plot(md_list, metrics, min_precision=self.cfg.min_precision, min_recall=self.cfg.min_recall,
dist_th_tp=self.cfg.dist_th_tp, savepath=savepath('summary'))
for detection_name in self.cfg.class_names:
class_pr_curve(md_list, metrics, detection_name, self.cfg.min_precision, self.cfg.min_recall,
savepath=savepath(detection_name + '_pr'))
class_tp_curve(md_list, metrics, detection_name, self.cfg.min_recall, self.cfg.dist_th_tp,
savepath=savepath(detection_name + '_tp'))
for dist_th in self.cfg.dist_ths:
dist_pr_curve(md_list, metrics, dist_th, self.cfg.min_precision, self.cfg.min_recall,
savepath=savepath('dist_pr_' + str(dist_th)))
def main(self,
plot_examples: int = 0,
render_curves: bool = True) -> Dict[str, Any]:
"""
Main function that loads the evaluation code, visualizes samples, runs the evaluation and renders stat plots.
:param plot_examples: How many example visualizations to write to disk.
:param render_curves: Whether to render PR and TP curves to disk.
:return: A dict that stores the high-level metrics and meta data.
"""
if plot_examples > 0:
# Select a random but fixed subset to plot.
random.seed(42)
sample_tokens = list(self.sample_tokens)
random.shuffle(sample_tokens)
sample_tokens = sample_tokens[:plot_examples]
# Visualize samples.
example_dir = os.path.join(self.output_dir, 'examples')
if not os.path.isdir(example_dir):
os.mkdir(example_dir)
for sample_token in sample_tokens:
visualize_sample(self.nusc,
sample_token,
self.gt_boxes if self.eval_set != 'test' else EvalBoxes(),
# Don't render test GT.
self.pred_boxes,
eval_range=max(self.cfg.class_range.values()),
savepath=os.path.join(example_dir, '{}.png'.format(sample_token)))
# Run evaluation.
metrics, metric_data_list = self.evaluate()
return metrics
class NuScenesEval(MotionEval):
"""
Dummy class for backward-compatibility. Same as MotionEval.
"""
if __name__ == "__main__":
# Settings.
parser = argparse.ArgumentParser(description='Evaluate nuScenes detection results.',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('result_path', type=str, help='The submission as a JSON file.')
parser.add_argument('--output_dir', type=str, default='~/nuscenes-metrics',
help='Folder to store result metrics, graphs and example visualizations.')
parser.add_argument('--eval_set', type=str, default='val',
help='Which dataset split to evaluate on, train, val or test.')
parser.add_argument('--dataroot', type=str, default='/data/sets/nuscenes',
help='Default nuScenes data directory.')
parser.add_argument('--version', type=str, default='v1.0-trainval',
help='Which version of the nuScenes dataset to evaluate on, e.g. v1.0-trainval.')
parser.add_argument('--config_path', type=str, default='',
help='Path to the configuration file.'
'If no path given, the CVPR 2019 configuration will be used.')
parser.add_argument('--plot_examples', type=int, default=10,
help='How many example visualizations to write to disk.')
parser.add_argument('--render_curves', type=int, default=1,
help='Whether to render PR and TP curves to disk.')
parser.add_argument('--verbose', type=int, default=1,
help='Whether to print to stdout.')
args = parser.parse_args()
result_path_ = os.path.expanduser(args.result_path)
output_dir_ = os.path.expanduser(args.output_dir)
eval_set_ = args.eval_set
dataroot_ = args.dataroot
version_ = args.version
config_path = args.config_path
plot_examples_ = args.plot_examples
render_curves_ = bool(args.render_curves)
verbose_ = bool(args.verbose)
if config_path == '':
cfg_ = config_factory('detection_cvpr_2019')
else:
with open(config_path, 'r') as _f:
cfg_ = DetectionConfig.deserialize(json.load(_f))
nusc_ = NuScenes(version=version_, verbose=verbose_, dataroot=dataroot_)
nusc_eval = MotionEval(nusc_, config=cfg_, result_path=result_path_, eval_set=eval_set_,
output_dir=output_dir_, verbose=verbose_)
nusc_eval.main(plot_examples=plot_examples_, render_curves=render_curves_)
# nuScenes dev-kit.
# Code written by Holger Caesar & Oscar Beijbom, 2018.
import argparse
import json
import os
import random
import time
import tqdm
from typing import Tuple, Dict, Any, Callable
import numpy as np
from nuscenes import NuScenes
from nuscenes.eval.common.config import config_factory
from nuscenes.eval.common.data_classes import EvalBoxes
from nuscenes.eval.common.loaders import add_center_dist, filter_eval_boxes
from nuscenes.eval.detection.algo import calc_ap, calc_tp
from nuscenes.eval.detection.constants import DETECTION_NAMES, ATTRIBUTE_NAMES, TP_METRICS
from nuscenes.eval.detection.data_classes import DetectionConfig, DetectionMetrics, DetectionBox, \
DetectionMetricDataList, DetectionMetricData
from nuscenes.eval.detection.render import summary_plot, class_pr_curve, class_tp_curve, dist_pr_curve, visualize_sample
from nuscenes.prediction import PredictHelper, convert_local_coords_to_global
from nuscenes.utils.splits import create_splits_scenes
from nuscenes.eval.detection.utils import category_to_detection_name
from nuscenes.eval.common.utils import quaternion_yaw, Quaternion
from nuscenes.eval.common.utils import center_distance, scale_iou, yaw_diff, velocity_l2, attr_acc, cummean
motion_name_mapping = {
'car': 'car',
'truck': 'car',
'construction_vehicle': 'car',
'bus': 'car',
'trailer': 'car',
'motorcycle': 'car',
'bicycle': 'car',
'pedestrian': 'pedestrian',
'traffic_cone': 'barrier',
'barrier': 'barrier',
}
class MotionBox(DetectionBox):
""" Data class used during detection evaluation. Can be a prediction or ground truth."""
def __init__(self,
sample_token: str = "",
translation: Tuple[float, float, float] = (0, 0, 0),
size: Tuple[float, float, float] = (0, 0, 0),
rotation: Tuple[float, float, float, float] = (0, 0, 0, 0),
velocity: Tuple[float, float] = (0, 0),
ego_translation: Tuple[float, float, float] = (0, 0, 0), # Translation to ego vehicle in meters.
num_pts: int = -1, # Number of LIDAR or RADAR points inside the box. Only for GT boxes.
detection_name: str = 'car', # The class name used in the detection challenge.
detection_score: float = -1.0, # GT samples do not have a score.
attribute_name: str = '', # Box attribute. Each box can have at most 1 attribute.
traj=None):
super().__init__(sample_token, translation, size, rotation, velocity, ego_translation, num_pts)
assert detection_name is not None, 'Error: detection_name cannot be empty!'
assert detection_name in DETECTION_NAMES, 'Error: Unknown detection_name %s' % detection_name
assert attribute_name in ATTRIBUTE_NAMES or attribute_name == '', \
'Error: Unknown attribute_name %s' % attribute_name
assert type(detection_score) == float, 'Error: detection_score must be a float!'
assert not np.any(np.isnan(detection_score)), 'Error: detection_score may not be NaN!'
# Assign.
self.detection_name = detection_name
self.detection_score = detection_score
self.attribute_name = attribute_name
self.traj = traj
def __eq__(self, other):
return (self.sample_token == other.sample_token and
self.translation == other.translation and
self.size == other.size and
self.rotation == other.rotation and
self.velocity == other.velocity and
self.ego_translation == other.ego_translation and
self.num_pts == other.num_pts and
self.detection_name == other.detection_name and
self.detection_score == other.detection_score and
self.attribute_name == other.attribute_name and
np.all(self.traj == other.traj))
def serialize(self) -> dict:
""" Serialize instance into json-friendly format. """
return {
'sample_token': self.sample_token,
'translation': self.translation,
'size': self.size,
'rotation': self.rotation,
'velocity': self.velocity,
'ego_translation': self.ego_translation,
'num_pts': self.num_pts,
'detection_name': self.detection_name,
'detection_score': self.detection_score,
'attribute_name': self.attribute_name,
'traj': self.traj,
}
@classmethod
def deserialize(cls, content: dict):
""" Initialize from serialized content. """
return cls(sample_token=content['sample_token'],
translation=tuple(content['translation']),
size=tuple(content['size']),
rotation=tuple(content['rotation']),
velocity=tuple(content['velocity']),
ego_translation=(0.0, 0.0, 0.0) if 'ego_translation' not in content
else tuple(content['ego_translation']),
num_pts=-1 if 'num_pts' not in content else int(content['num_pts']),
detection_name=content['detection_name'],
detection_score=-1.0 if 'detection_score' not in content else float(content['detection_score']),
attribute_name=content['attribute_name'],
traj=content['trajs'],)
def load_prediction(result_path: str, max_boxes_per_sample: int, box_cls, verbose: bool = False) \
-> Tuple[EvalBoxes, Dict]:
"""
Loads object predictions from file.
:param result_path: Path to the .json result file provided by the user.
:param max_boxes_per_sample: Maximum number of boxes allowed per sample.
:param box_cls: Type of box to load, e.g. DetectionBox or TrackingBox.
:param verbose: Whether to print messages to stdout.
:return: The deserialized results and meta data.
"""
# Load from file and check that the format is correct.
# with open(result_path) as f:
# data = json.load(f)
data = result_path  # result_path is already the loaded results dict; the file loading above is bypassed
assert 'results' in data, 'Error: No field `results` in result file. Please note that the result format changed. ' \
'See https://www.nuscenes.org/object-detection for more information.'
# motion name mapping
for key in data['results'].keys():
for i in range(len(data['results'][key])):
cls_name = data['results'][key][i]['detection_name']
if cls_name in motion_name_mapping:
cls_name = motion_name_mapping[cls_name]
data['results'][key][i]['detection_name'] = cls_name
# Deserialize results and get meta data.
all_results = EvalBoxes.deserialize(data['results'], box_cls)
meta = data['meta']
if verbose:
print("Loaded results from {}. Found detections for {} samples."
.format(result_path, len(all_results.sample_tokens)))
# Check that each sample has no more than x predicted boxes.
for sample_token in all_results.sample_tokens:
assert len(all_results.boxes[sample_token]) <= max_boxes_per_sample, \
"Error: Only <= %d boxes per sample allowed!" % max_boxes_per_sample
return all_results, meta
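
# --------------------------------------------------------------------------- #
# Format sketch (illustration only): the structure load_prediction expects.  In
# this adapted version `result_path` is already the loaded results dict rather
# than a file path (the json.load above is bypassed).  Tokens and values below
# are made up; 'trajs' holds the predicted future trajectory modes with shape
# (num_modes, num_steps, 2).
# --------------------------------------------------------------------------- #
def _example_results_dict():
    return {
        'meta': {'use_camera': True, 'use_lidar': False},
        'results': {
            'hypothetical_sample_token': [
                {
                    'sample_token': 'hypothetical_sample_token',
                    'translation': [10.0, 5.0, 0.5],
                    'size': [2.0, 4.5, 1.6],
                    'rotation': [1.0, 0.0, 0.0, 0.0],
                    'velocity': [1.0, 0.0],
                    'detection_name': 'car',
                    'detection_score': 0.9,
                    'attribute_name': 'vehicle.moving',
                    'trajs': np.zeros((6, 12, 2)).tolist(),
                },
            ],
        },
    }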
def load_gt(nusc: NuScenes, eval_split: str, box_cls, verbose: bool = False, seconds: int = 12) -> EvalBoxes:
"""
Loads ground truth boxes from DB.
:param nusc: A NuScenes instance.
:param eval_split: The evaluation split for which we load GT boxes.
:param box_cls: Type of box to load, e.g. DetectionBox or TrackingBox.
:param verbose: Whether to print messages to stdout.
:param seconds: Number of future seconds of trajectory to attach to each GT box.
:return: The GT boxes.
"""
predict_helper = PredictHelper(nusc)
# Init.
if box_cls == MotionBox:
attribute_map = {a['token']: a['name'] for a in nusc.attribute}
if verbose:
print('Loading annotations for {} split from nuScenes version: {}'.format(eval_split, nusc.version))
# Read out all sample_tokens in DB.
sample_tokens_all = [s['token'] for s in nusc.sample]
assert len(sample_tokens_all) > 0, "Error: Database has no samples!"
# Only keep samples from this split.
splits = create_splits_scenes()
# Check compatibility of split with nusc_version.
version = nusc.version
if eval_split in {'train', 'val', 'train_detect', 'train_track'}:
assert version.endswith('trainval'), \
'Error: Requested split {} which is not compatible with NuScenes version {}'.format(eval_split, version)
elif eval_split in {'mini_train', 'mini_val'}:
assert version.endswith('mini'), \
'Error: Requested split {} which is not compatible with NuScenes version {}'.format(eval_split, version)
elif eval_split == 'test':
assert version.endswith('test'), \
'Error: Requested split {} which is not compatible with NuScenes version {}'.format(eval_split, version)
else:
raise ValueError('Error: Requested split {} which this function cannot map to the correct NuScenes version.'
.format(eval_split))
if eval_split == 'test':
# Check that you aren't trying to cheat :).
assert len(nusc.sample_annotation) > 0, \
'Error: You are trying to evaluate on the test set but you do not have the annotations!'
sample_tokens = []
for sample_token in sample_tokens_all:
scene_token = nusc.get('sample', sample_token)['scene_token']
scene_record = nusc.get('scene', scene_token)
if scene_record['name'] in splits[eval_split]:
sample_tokens.append(sample_token)
all_annotations = EvalBoxes()
# Load annotations and filter predictions and annotations.
tracking_id_set = set()
for sample_token in tqdm.tqdm(sample_tokens, leave=verbose):
sample = nusc.get('sample', sample_token)
sample_annotation_tokens = sample['anns']
sample_boxes = []
for sample_annotation_token in sample_annotation_tokens:
sample_annotation = nusc.get('sample_annotation', sample_annotation_token)
if box_cls == MotionBox:
# Get label name in detection task and filter unused labels.
detection_name = category_to_detection_name(sample_annotation['category_name'])
# motion name mapping
if detection_name in motion_name_mapping:
detection_name = motion_name_mapping[detection_name]
if detection_name is None:
continue
# Get attribute_name.
attr_tokens = sample_annotation['attribute_tokens']
attr_count = len(attr_tokens)
if attr_count == 0:
attribute_name = ''
elif attr_count == 1:
attribute_name = attribute_map[attr_tokens[0]]
else:
raise Exception('Error: GT annotations must not have more than one attribute!')
# get future trajs
instance_token = nusc.get('sample_annotation', sample_annotation['token'])['instance_token']
fut_traj_local = predict_helper.get_future_for_agent(
instance_token,
sample_token,
seconds=seconds,
in_agent_frame=True
)
if fut_traj_local.shape[0] > 0:
_, boxes, _ = nusc.get_sample_data(sample['data']['LIDAR_TOP'], selected_anntokens=[sample_annotation['token']])
box = boxes[0]
trans = box.center
rot = Quaternion(matrix=box.rotation_matrix)
fut_traj_scene_centric = convert_local_coords_to_global(fut_traj_local, trans, rot)
else:
fut_traj_scene_centric = np.zeros((0,))
sample_boxes.append(
box_cls(
sample_token=sample_token,
translation=sample_annotation['translation'],
size=sample_annotation['size'],
rotation=sample_annotation['rotation'],
velocity=nusc.box_velocity(sample_annotation['token'])[:2],
num_pts=sample_annotation['num_lidar_pts'] + sample_annotation['num_radar_pts'],
detection_name=detection_name,
detection_score=-1.0, # GT samples do not have a score.
attribute_name=attribute_name,
traj=fut_traj_scene_centric
)
)
elif box_cls == TrackingBox:
# Use nuScenes token as tracking id.
tracking_id = sample_annotation['instance_token']
tracking_id_set.add(tracking_id)
# Get label name in detection task and filter unused labels.
# Import locally to avoid errors when motmetrics package is not installed.
from nuscenes.eval.tracking.utils import category_to_tracking_name
tracking_name = category_to_tracking_name(sample_annotation['category_name'])
if tracking_name is None:
continue
sample_boxes.append(
box_cls(
sample_token=sample_token,
translation=sample_annotation['translation'],
size=sample_annotation['size'],
rotation=sample_annotation['rotation'],
velocity=nusc.box_velocity(sample_annotation['token'])[:2],
num_pts=sample_annotation['num_lidar_pts'] + sample_annotation['num_radar_pts'],
tracking_id=tracking_id,
tracking_name=tracking_name,
tracking_score=-1.0 # GT samples do not have a score.
)
)
else:
raise NotImplementedError('Error: Invalid box_cls %s!' % box_cls)
all_annotations.add_boxes(sample_token, sample_boxes)
if verbose:
print("Loaded ground truth annotations for {} samples.".format(len(all_annotations.sample_tokens)))
return all_annotations
def accumulate(gt_boxes: EvalBoxes,
pred_boxes: EvalBoxes,
class_name: str,
dist_fcn: Callable,
dist_th: float,
verbose: bool = False) -> DetectionMetricData:
"""
Average Precision over predefined different recall thresholds for a single distance threshold.
The recall/conf thresholds and other raw metrics will be used in secondary metrics.
:param gt_boxes: Maps every sample_token to a list of its sample_annotations.
:param pred_boxes: Maps every sample_token to a list of its sample_results.
:param class_name: Class to compute AP on.
:param dist_fcn: Distance function used to match detections and ground truths.
:param dist_th: Distance threshold for a match.
:param verbose: If true, print debug messages.
:return: (md, EPA, EPA_). Accumulated metric data plus the end-to-end prediction accuracy from detection-based and trajectory-based matching.
"""
# ---------------------------------------------
# Organize input and initialize accumulators.
# ---------------------------------------------
# Count the positives.
npos = len([1 for gt_box in gt_boxes.all if gt_box.detection_name == class_name])
if verbose:
print("Found {} GT of class {} out of {} total across {} samples.".
format(npos, class_name, len(gt_boxes.all), len(gt_boxes.sample_tokens)))
# For missing classes in the GT, return a data structure corresponding to no predictions.
if npos == 0:
return MotionMetricData.no_predictions(), 0, 0  # no GT of this class: empty metric data, zero EPA
# Organize the predictions in a single list.
pred_boxes_list = [box for box in pred_boxes.all if box.detection_name == class_name]
pred_confs = [box.detection_score for box in pred_boxes_list]
if verbose:
print("Found {} PRED of class {} out of {} total across {} samples.".
format(len(pred_confs), class_name, len(pred_boxes.all), len(pred_boxes.sample_tokens)))
# Sort by confidence.
sortind = [i for (v, i) in sorted((v, i) for (i, v) in enumerate(pred_confs))][::-1]
# Do the actual matching.
tp = [] # Accumulator of true positives.
fp = [] # Accumulator of false positives.
conf = [] # Accumulator of confidences.
hit = 0 # Accumulator of matched and hit
# match_data holds the extra metrics we calculate for each match.
match_data = {'conf': [],
'min_ade': [],
'min_fde': [],
'miss_rate': []}
# ---------------------------------------------
# Match and accumulate match data.
# ---------------------------------------------
taken = set() # Initially no gt bounding box is matched.
for ind in sortind:
pred_box = pred_boxes_list[ind]
min_dist = np.inf
match_gt_idx = None
for gt_idx, gt_box in enumerate(gt_boxes[pred_box.sample_token]):
# Find closest match among ground truth boxes
if gt_box.detection_name == class_name and not (pred_box.sample_token, gt_idx) in taken:
this_distance = dist_fcn(gt_box, pred_box)
if this_distance < min_dist:
min_dist = this_distance
match_gt_idx = gt_idx
# If the closest match is close enough according to threshold we have a match!
is_match = min_dist < dist_th
if is_match:
taken.add((pred_box.sample_token, match_gt_idx))
# Update tp, fp and confs.
tp.append(1)
fp.append(0)
conf.append(pred_box.detection_score)
# Since it is a match, update match data also.
gt_box_match = gt_boxes[pred_box.sample_token][match_gt_idx]
match_data['conf'].append(pred_box.detection_score)
minade, minfde, mr = prediction_metrics(gt_box_match, pred_box)
match_data['min_ade'].append(minade)
match_data['min_fde'].append(minfde)
match_data['miss_rate'].append(mr)
if minfde < 2.0:
hit += 1
else:
# No match. Mark this as a false positive.
tp.append(0)
fp.append(1)
conf.append(pred_box.detection_score)
# Check if we have any matches. If not, just return a "no predictions" array.
if len(match_data['min_ade']) == 0:
return MotionMetricData.no_predictions(), 0, 0  # keep the (md, EPA, EPA_) return contract; no matches means zero EPA
# Accumulate.
N_tp = np.sum(tp)
N_fp = np.sum(fp)
tp = np.cumsum(tp).astype(float)
fp = np.cumsum(fp).astype(float)
conf = np.array(conf)
# Calculate precision and recall.
prec = tp / (fp + tp)
rec = tp / float(npos)
rec_interp = np.linspace(0, 1, DetectionMetricData.nelem) # 101 steps, from 0% to 100% recall.
prec = np.interp(rec_interp, rec, prec, right=0)
conf = np.interp(rec_interp, rec, conf, right=0)
rec = rec_interp
# ---------------------------------------------
# Re-sample the match-data to match, prec, recall and conf.
# ---------------------------------------------
for key in match_data.keys():
if key == "conf":
continue # Confidence is used as reference to align with fp and tp. So skip in this step.
else:
# For each match_data, we first calculate the accumulated mean.
tmp = cummean(np.array(match_data[key]))
# Then interpolate based on the confidences. (Note reversing since np.interp needs increasing arrays)
match_data[key] = np.interp(conf[::-1], match_data['conf'][::-1], tmp[::-1])[::-1]
EPA = (hit - 0.5 * N_fp) / npos
## match based on traj
traj_matched = 0
taken = set() # Initially no gt bounding box is matched.
for ind in sortind:
pred_box = pred_boxes_list[ind]
min_dist = np.inf
match_gt_idx = None
for gt_idx, gt_box in enumerate(gt_boxes[pred_box.sample_token]):
# Find closest match among ground truth boxes
if gt_box.detection_name == class_name and not (pred_box.sample_token, gt_idx) in taken:
this_distance = dist_fcn(gt_box, pred_box)
if this_distance < min_dist:
min_dist = this_distance
match_gt_idx = gt_idx
fde_distance = traj_fde(gt_box, pred_box, final_step=12)
# If the closest match is close enough according to threshold we have a match!
is_match = min_dist < dist_th and fde_distance < 2.0
if is_match:
taken.add((pred_box.sample_token, match_gt_idx))
traj_matched += 1
EPA_ = (traj_matched - 0.5 * N_fp) / npos ## same as UniAD
# ---------------------------------------------
# Done. Instantiate MetricData and return
# ---------------------------------------------
return MotionMetricData(recall=rec,
precision=prec,
confidence=conf,
min_ade_err=match_data['min_ade'],
min_fde_err=match_data['min_fde'],
miss_rate_err=match_data['miss_rate']), EPA, EPA_
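
# Worked example for the EPA formula above (numbers are made up for illustration):
# with npos = 10 ground-truth boxes of the class, N_fp = 4 unmatched predictions and
# hit = 6 matched predictions whose minFDE is below 2.0 m,
#     EPA = (hit - 0.5 * N_fp) / npos = (6 - 0.5 * 4) / 10 = 0.4
# EPA_ follows the same formula but counts a prediction only when it also matches the
# GT trajectory endpoint within 2.0 m (traj_fde), as in UniAD.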
def prediction_metrics(gt_box_match, pred_box, miss_thresh=2):
gt_traj = np.array(gt_box_match.traj)
pred_traj = np.array(pred_box.traj)
valid_step = gt_traj.shape[0]
if valid_step <= 0:
return 0, 0, 0
pred_traj_valid = pred_traj[:, :valid_step, :]
dist = np.linalg.norm(pred_traj_valid - gt_traj[np.newaxis], axis=2)
minade = dist.mean(axis=1).min()
minfde = dist[:, -1].min()
mr = dist.max(axis=1).min() > miss_thresh
return minade, minfde, mr
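
# --------------------------------------------------------------------------- #
# Shape sketch (illustration only): prediction_metrics expects the GT trajectory
# with shape (valid_steps, 2) and the prediction with shape (num_modes, num_steps, 2);
# minADE/minFDE are taken over modes.  SimpleNamespace is used here only to mimic
# boxes carrying a `.traj` attribute.
# --------------------------------------------------------------------------- #
def _prediction_metrics_example():
    from types import SimpleNamespace
    gt = SimpleNamespace(traj=np.array([[1.0, 0.0], [2.0, 0.0], [3.0, 0.0]]))
    pred = SimpleNamespace(traj=np.array([
        [[1.0, 0.0], [2.0, 0.0], [3.0, 0.0]],   # mode 0: perfect -> ADE = FDE = 0
        [[1.0, 1.0], [2.0, 1.0], [3.0, 1.0]],   # mode 1: constant 1 m offset
    ]))
    minade, minfde, mr = prediction_metrics(gt, pred)
    # minade == 0.0, minfde == 0.0, mr == False (the best mode stays within 2 m)
    return minade, minfde, mr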
def traj_fde(gt_box, pred_box, final_step):
if gt_box.traj.shape[0] <= 0:
return np.inf
final_step = min(gt_box.traj.shape[0], final_step)
gt_final = gt_box.traj[None, final_step-1]
pred_final = np.array(pred_box.traj)[:,final_step-1,:]
err = np.sqrt(np.sum(np.square(gt_final - pred_final), axis=-1))
return np.min(err)
class MotionMetricDataList(DetectionMetricDataList):
""" This stores a set of MetricData in a dict indexed by (name, match-distance). """
@classmethod
def deserialize(cls, content: dict):
mdl = cls()
for key, md in content.items():
name, distance = key.split(':')
mdl.set(name, float(distance), MotionMetricData.deserialize(md))
return mdl
class MotionMetricData(DetectionMetricData):
""" This class holds accumulated and interpolated data required to calculate the detection metrics. """
nelem = 101
def __init__(self,
recall: np.array,
precision: np.array,
confidence: np.array,
min_ade_err: np.array,
min_fde_err: np.array,
miss_rate_err: np.array):
# Assert lengths.
assert len(recall) == self.nelem
assert len(precision) == self.nelem
assert len(confidence) == self.nelem
assert len(min_ade_err) == self.nelem
assert len(min_fde_err) == self.nelem
assert len(miss_rate_err) == self.nelem
# Assert ordering.
assert all(confidence == sorted(confidence, reverse=True)) # Confidences should be descending.
assert all(recall == sorted(recall)) # Recalls should be ascending.
# Set attributes explicitly to help IDEs figure out what is going on.
self.recall = recall
self.precision = precision
self.confidence = confidence
self.min_ade_err = min_ade_err
self.min_fde_err = min_fde_err
self.miss_rate_err = miss_rate_err
def __eq__(self, other):
eq = True
for key in self.serialize().keys():
eq = eq and np.array_equal(getattr(self, key), getattr(other, key))
return eq
@property
def max_recall_ind(self):
""" Returns index of max recall achieved. """
# Last instance of confidence > 0 is index of max achieved recall.
non_zero = np.nonzero(self.confidence)[0]
if len(non_zero) == 0: # If there are no matches, all the confidence values will be zero.
max_recall_ind = 0
else:
max_recall_ind = non_zero[-1]
return max_recall_ind
@property
def max_recall(self):
""" Returns max recall achieved. """
return self.recall[self.max_recall_ind]
def serialize(self):
""" Serialize instance into json-friendly format. """
return {
'recall': self.recall.tolist(),
'precision': self.precision.tolist(),
'confidence': self.confidence.tolist(),
'min_ade_err': self.min_ade_err.tolist(),
'min_fde_err': self.min_fde_err.tolist(),
'miss_rate_err': self.miss_rate_err.tolist(),
}
@classmethod
def deserialize(cls, content: dict):
""" Initialize from serialized content. """
return cls(recall=np.array(content['recall']),
precision=np.array(content['precision']),
confidence=np.array(content['confidence']),
min_ade_err=np.array(content['min_ade_err']),
min_fde_err=np.array(content['min_fde_err']),
miss_rate_err=np.array(content['miss_rate_err']))
@classmethod
def no_predictions(cls):
""" Returns a md instance corresponding to having no predictions. """
return cls(recall=np.linspace(0, 1, cls.nelem),
precision=np.zeros(cls.nelem),
confidence=np.zeros(cls.nelem),
min_ade_err=np.ones(cls.nelem),
min_fde_err=np.ones(cls.nelem),
miss_rate_err=np.ones(cls.nelem))
@classmethod
def random_md(cls):
""" Returns an md instance corresponding to a random results. """
return cls(recall=np.linspace(0, 1, cls.nelem),
precision=np.random.random(cls.nelem),
confidence=np.linspace(0, 1, cls.nelem)[::-1],
min_ade_err=np.random.random(cls.nelem),
min_fde_err=np.random.random(cls.nelem),
miss_rate_err=np.random.random(cls.nelem))
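
# --------------------------------------------------------------------------- #
# Round-trip sketch (illustration only): MotionMetricData serializes to a
# json-friendly dict and can be rebuilt from it, which is how metric data would
# typically be dumped to disk.  random_md() is used only to get a well-formed
# instance.
# --------------------------------------------------------------------------- #
def _motion_metric_data_roundtrip_example():
    md = MotionMetricData.random_md()
    restored = MotionMetricData.deserialize(md.serialize())
    assert md == restored  # __eq__ compares every serialized field
    return restored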
from tqdm import tqdm
import torch
import torch.nn as nn
import numpy as np
from shapely.geometry import Polygon
from mmcv.utils import print_log
from mmdet.datasets import build_dataset, build_dataloader
from projects.mmdet3d_plugin.datasets.utils import box3d_to_corners
def check_collision(ego_box, boxes):
'''
ego_box: tensor with shape [7], [x, y, z, w, l, h, yaw]
boxes: tensor with shape [N, 7]
'''
if boxes.shape[0] == 0:
return False
# follow uniad, add a 0.5m offset
ego_box[0] += 0.5 * torch.cos(ego_box[6])
ego_box[1] += 0.5 * torch.sin(ego_box[6])
ego_corners_box = box3d_to_corners(ego_box.unsqueeze(0))[0, [0, 3, 7, 4], :2]
corners_box = box3d_to_corners(boxes)[:, [0, 3, 7, 4], :2]
ego_poly = Polygon([(point[0], point[1]) for point in ego_corners_box])
for i in range(len(corners_box)):
box_poly = Polygon([(point[0], point[1]) for point in corners_box[i]])
collision = ego_poly.intersects(box_poly)
if collision:
return True
return False
def get_yaw(traj):
start = traj[0]
end = traj[-1]
dist = torch.linalg.norm(end - start, dim=-1)
if dist < 0.5:
return traj.new_ones(traj.shape[0]) * np.pi / 2
zeros = traj.new_zeros((1, 2))
traj_cat = torch.cat([zeros, traj], dim=0)
yaw = traj.new_zeros(traj.shape[0]+1)
yaw[..., 1:-1] = torch.atan2(
traj_cat[..., 2:, 1] - traj_cat[..., :-2, 1],
traj_cat[..., 2:, 0] - traj_cat[..., :-2, 0],
)
yaw[..., -1] = torch.atan2(
traj_cat[..., -1, 1] - traj_cat[..., -2, 1],
traj_cat[..., -1, 0] - traj_cat[..., -2, 0],
)
return yaw[1:]
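
# --------------------------------------------------------------------------- #
# Behaviour sketch (illustration only): get_yaw estimates a per-step heading from
# the planned positions by finite differences.  A trajectory driving straight
# along +y yields pi/2 for every step, which is also the fallback heading used
# for near-static trajectories (total displacement < 0.5 m).
# --------------------------------------------------------------------------- #
def _get_yaw_example():
    traj = torch.tensor([[0.0, 1.0], [0.0, 2.0], [0.0, 3.0]])
    yaw = get_yaw(traj)
    # yaw == tensor([pi/2, pi/2, pi/2])
    return yaw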
class PlanningMetric():
def __init__(
self,
n_future=6,
compute_on_step: bool = False,
):
self.W = 1.85
self.H = 4.084
self.n_future = n_future
self.reset()
def reset(self):
self.obj_col = torch.zeros(self.n_future)
self.obj_box_col = torch.zeros(self.n_future)
self.L2 = torch.zeros(self.n_future)
self.total = torch.tensor(0)
def evaluate_single_coll(self, traj, fut_boxes):
n_future = traj.shape[0]
yaw = get_yaw(traj)
ego_box = traj.new_zeros((n_future, 7))
ego_box[:, :2] = traj
ego_box[:, 3:6] = ego_box.new_tensor([self.H, self.W, 1.56])
ego_box[:, 6] = yaw
collision = torch.zeros(n_future, dtype=torch.bool)
for t in range(n_future):
ego_box_t = ego_box[t].clone()
boxes = fut_boxes[t][0].clone()
collision[t] = check_collision(ego_box_t, boxes)
return collision
def evaluate_coll(self, trajs, gt_trajs, fut_boxes):
B, n_future, _ = trajs.shape
trajs = trajs * torch.tensor([-1, 1], device=trajs.device)
gt_trajs = gt_trajs * torch.tensor([-1, 1], device=gt_trajs.device)
obj_coll_sum = torch.zeros(n_future, device=trajs.device)
obj_box_coll_sum = torch.zeros(n_future, device=trajs.device)
assert B == 1, 'only support batch size 1'
for i in range(B):
gt_box_coll = self.evaluate_single_coll(gt_trajs[i], fut_boxes)
box_coll = self.evaluate_single_coll(trajs[i], fut_boxes)
box_coll = torch.logical_and(box_coll, torch.logical_not(gt_box_coll))
obj_coll_sum += gt_box_coll.long()
obj_box_coll_sum += box_coll.long()
return obj_coll_sum, obj_box_coll_sum
def compute_L2(self, trajs, gt_trajs, gt_trajs_mask):
'''
trajs: torch.Tensor (B, n_future, 2) or (B, n_future, 3); only x, y are used
gt_trajs: torch.Tensor, same shape as trajs
gt_trajs_mask: torch.Tensor (B, n_future, 2), 1 for valid future steps
'''
return torch.sqrt((((trajs[:, :, :2] - gt_trajs[:, :, :2]) ** 2) * gt_trajs_mask).sum(dim=-1))
def update(self, trajs, gt_trajs, gt_trajs_mask, fut_boxes):
assert trajs.shape == gt_trajs.shape
trajs[..., 0] = - trajs[..., 0]
gt_trajs[..., 0] = - gt_trajs[..., 0]
L2 = self.compute_L2(trajs, gt_trajs, gt_trajs_mask)
obj_coll_sum, obj_box_coll_sum = self.evaluate_coll(trajs[:,:,:2], gt_trajs[:,:,:2], fut_boxes)
self.obj_col += obj_coll_sum
self.obj_box_col += obj_box_coll_sum
self.L2 += L2.sum(dim=0)
self.total +=len(trajs)
def compute(self):
return {
'obj_col': self.obj_col / self.total,
'obj_box_col': self.obj_box_col / self.total,
'L2' : self.L2 / self.total
}
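
# --------------------------------------------------------------------------- #
# Usage sketch (illustration only): feeding a single-sample batch through
# PlanningMetric.  The tensors are made up; fut_boxes holds, for every future
# step, a list with one (N, 7) tensor of agent boxes (empty here, so no
# collision can occur).  Note that update() flips the x axis in place, matching
# how planning_eval below prepares its inputs.
# --------------------------------------------------------------------------- #
def _planning_metric_example(n_future=6):
    metric = PlanningMetric(n_future=n_future)
    pred = torch.zeros(1, n_future, 2)
    pred[..., 1] = torch.arange(1, n_future + 1, dtype=torch.float32)  # straight line along +y
    gt = pred.clone()
    mask = torch.ones(1, n_future, 2)
    fut_boxes = [[torch.zeros(0, 7)] for _ in range(n_future)]         # no surrounding agents
    metric.update(pred, gt, mask, fut_boxes)
    return metric.compute()  # {'obj_col': ..., 'obj_box_col': ..., 'L2': ...}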
def planning_eval(results, eval_config, logger):
dataset = build_dataset(eval_config)
dataloader = build_dataloader(
dataset, samples_per_gpu=1, workers_per_gpu=1, shuffle=False, dist=False)
planning_metrics = PlanningMetric()
for i, data in enumerate(tqdm(dataloader)):
sdc_planning = data['gt_ego_fut_trajs'].cumsum(dim=-2).unsqueeze(1)
sdc_planning_mask = data['gt_ego_fut_masks'].unsqueeze(-1).repeat(1, 1, 2).unsqueeze(1)
command = data['gt_ego_fut_cmd'].argmax(dim=-1).item()
fut_boxes = data['fut_boxes']
if not sdc_planning_mask.all(): ## for incomplete gt, we do not count this sample
continue
res = results[i]
pred_sdc_traj = res['img_bbox']['final_planning'].unsqueeze(0)
planning_metrics.update(pred_sdc_traj[:, :6, :2], sdc_planning[0,:, :6, :2], sdc_planning_mask[0,:, :6, :2], fut_boxes)
planning_results = planning_metrics.compute()
planning_metrics.reset()
from prettytable import PrettyTable
planning_tab = PrettyTable()
metric_dict = {}
planning_tab.field_names = [
"metrics", "0.5s", "1.0s", "1.5s", "2.0s", "2.5s", "3.0s", "avg"]
for key in planning_results.keys():
value = planning_results[key].tolist()
new_values = []
for i in range(len(value)):
new_values.append(np.array(value[:i+1]).mean())
value = new_values
avg = [value[1], value[3], value[5]]
avg = sum(avg) / len(avg)
value.append(avg)
metric_dict[key] = avg
row_value = []
row_value.append(key)
for i in range(len(value)):
if 'col' in key:
row_value.append('%.3f' % float(value[i]*100) + '%')
else:
row_value.append('%.4f' % float(value[i]))
planning_tab.add_row(row_value)
print_log('\n'+str(planning_tab), logger=logger)
return metric_dict
from shapely.geometry import LineString, box, Polygon
from shapely import ops, strtree
import numpy as np
from nuscenes.map_expansion.map_api import NuScenesMap, NuScenesMapExplorer
from nuscenes.eval.common.utils import quaternion_yaw
from pyquaternion import Quaternion
from .utils import split_collections, get_drivable_area_contour, \
get_ped_crossing_contour
from numpy.typing import NDArray
from typing import Dict, List, Tuple, Union
class NuscMapExtractor(object):
"""NuScenes map ground-truth extractor.
Args:
data_root (str): path to nuScenes dataset
roi_size (tuple or list): bev range
"""
def __init__(self, data_root: str, roi_size: Union[List, Tuple]) -> None:
self.roi_size = roi_size
self.MAPS = ['boston-seaport', 'singapore-hollandvillage',
'singapore-onenorth', 'singapore-queenstown']
self.nusc_maps = {}
self.map_explorer = {}
for loc in self.MAPS:
self.nusc_maps[loc] = NuScenesMap(
dataroot=data_root, map_name=loc)
self.map_explorer[loc] = NuScenesMapExplorer(self.nusc_maps[loc])
# local patch in nuScenes format
self.local_patch = box(-roi_size[0] / 2, -roi_size[1] / 2,
roi_size[0] / 2, roi_size[1] / 2)
def _union_ped(self, ped_geoms: List[Polygon]) -> List[Polygon]:
''' merge close ped crossings.
Args:
ped_geoms (list): list of Polygon
Returns:
union_ped_geoms (List): merged ped crossings
'''
def get_rec_direction(geom):
rect = geom.minimum_rotated_rectangle
rect_v_p = np.array(rect.exterior.coords)[:3]
rect_v = rect_v_p[1:]-rect_v_p[:-1]
v_len = np.linalg.norm(rect_v, axis=-1)
longest_v_i = v_len.argmax()
return rect_v[longest_v_i], v_len[longest_v_i]
tree = strtree.STRtree(ped_geoms)
index_by_id = dict((id(pt), i) for i, pt in enumerate(ped_geoms))
final_pgeom = []
remain_idx = [i for i in range(len(ped_geoms))]
for i, pgeom in enumerate(ped_geoms):
if i not in remain_idx:
continue
# update
remain_idx.pop(remain_idx.index(i))
pgeom_v, pgeom_v_norm = get_rec_direction(pgeom)
final_pgeom.append(pgeom)
for o in tree.query(pgeom):
o_idx = index_by_id[id(o)]
if o_idx not in remain_idx:
continue
o_v, o_v_norm = get_rec_direction(o)
cos = pgeom_v.dot(o_v)/(pgeom_v_norm*o_v_norm)
if 1 - np.abs(cos) < 0.01: # theta < 8 degrees.
final_pgeom[-1] =\
final_pgeom[-1].union(o)
# update
remain_idx.pop(remain_idx.index(o_idx))
results = []
for p in final_pgeom:
results.extend(split_collections(p))
return results
def get_map_geom(self,
location: str,
translation: Union[List, NDArray],
rotation: Union[List, NDArray]) -> Dict[str, List[Union[LineString, Polygon]]]:
''' Extract geometries given `location` and the self pose (lidar or ego frame).
Args:
location (str): city name
translation (array): self2global translation, shape (3,)
rotation (array): self2global quaternion, shape (4, )
Returns:
geometries (Dict): extracted geometries by category.
'''
# (center_x, center_y, len_y, len_x) in nuscenes format
patch_box = (translation[0], translation[1],
self.roi_size[1], self.roi_size[0])
rotation = Quaternion(rotation)
yaw = quaternion_yaw(rotation) / np.pi * 180
# get dividers
lane_dividers = self.map_explorer[location]._get_layer_line(
patch_box, yaw, 'lane_divider')
road_dividers = self.map_explorer[location]._get_layer_line(
patch_box, yaw, 'road_divider')
all_dividers = []
for line in lane_dividers + road_dividers:
all_dividers += split_collections(line)
# get ped crossings
ped_crossings = []
ped = self.map_explorer[location]._get_layer_polygon(
patch_box, yaw, 'ped_crossing')
for p in ped:
ped_crossings += split_collections(p)
# some ped crossings are split into several small parts
# we need to merge them
ped_crossings = self._union_ped(ped_crossings)
ped_crossing_lines = []
for p in ped_crossings:
# extract exteriors to get a closed polyline
line = get_ped_crossing_contour(p, self.local_patch)
if line is not None:
ped_crossing_lines.append(line)
# get boundaries
# we take the union of road segments and lanes as drivable areas
# we don't take drivable area layer in nuScenes since its definition may be ambiguous
road_segments = self.map_explorer[location]._get_layer_polygon(
patch_box, yaw, 'road_segment')
lanes = self.map_explorer[location]._get_layer_polygon(
patch_box, yaw, 'lane')
union_roads = ops.unary_union(road_segments)
union_lanes = ops.unary_union(lanes)
drivable_areas = ops.unary_union([union_roads, union_lanes])
drivable_areas = split_collections(drivable_areas)
# boundaries are defined as the contour of drivable areas
boundaries = get_drivable_area_contour(drivable_areas, self.roi_size)
return dict(
divider=all_dividers, # List[LineString]
ped_crossing=ped_crossing_lines, # List[LineString]
boundary=boundaries, # List[LineString]
drivable_area=drivable_areas, # List[Polygon],
)
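
# --------------------------------------------------------------------------- #
# Usage sketch (illustration only): extracting map geometries around one pose.
# The data root is hypothetical; translation/rotation form the self2global pose
# (e.g. taken from a sample's lidar or ego pose record), with the rotation given
# as a (w, x, y, z) quaternion.
# --------------------------------------------------------------------------- #
def _nusc_map_extractor_example():
    extractor = NuscMapExtractor('data/nuscenes', roi_size=(60, 30))  # hypothetical data root
    geoms = extractor.get_map_geom(
        location='boston-seaport',
        translation=[600.0, 1600.0, 0.0],
        rotation=[1.0, 0.0, 0.0, 0.0])
    # geoms['divider'], geoms['ped_crossing'], geoms['boundary'] are lists of LineString
    return geoms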
from shapely.geometry import LineString, box, Polygon, LinearRing
from shapely.geometry.base import BaseGeometry
from shapely import ops
import numpy as np
from scipy.spatial import distance
from typing import List, Optional, Tuple
from numpy.typing import NDArray
def split_collections(geom: BaseGeometry) -> List[Optional[BaseGeometry]]:
''' Split multi-part geometries into a list, keeping only valid, non-empty geometries.
Args:
geom (BaseGeometry): geoms to be split or validate.
Returns:
geometries (List): list of geometries.
'''
assert geom.geom_type in ['MultiLineString', 'LineString', 'MultiPolygon',
'Polygon', 'GeometryCollection'], f"got geom type {geom.geom_type}"
if 'Multi' in geom.geom_type:
outs = []
for g in geom.geoms:
if g.is_valid and not g.is_empty:
outs.append(g)
return outs
else:
if geom.is_valid and not geom.is_empty:
return [geom,]
else:
return []
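
# --------------------------------------------------------------------------- #
# Behaviour sketch (illustration only): a MultiLineString is split into its
# valid, non-empty parts, while a plain LineString comes back as a one-element
# list.
# --------------------------------------------------------------------------- #
def _split_collections_example():
    from shapely.geometry import MultiLineString
    multi = MultiLineString([[(0, 0), (1, 0)], [(0, 1), (1, 1)]])
    parts = split_collections(multi)                           # two LineStrings
    single = split_collections(LineString([(0, 0), (1, 0)]))   # [LineString]
    return parts, single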
def get_drivable_area_contour(drivable_areas: List[Polygon],
roi_size: Tuple) -> List[LineString]:
''' Extract drivable area contours to get list of boundaries.
Args:
drivable_areas (list): list of drivable areas.
roi_size (tuple): bev range size
Returns:
boundaries (List): list of boundaries.
'''
max_x = roi_size[0] / 2
max_y = roi_size[1] / 2
# a bit smaller than roi to avoid unexpected boundaries on edges
local_patch = box(-max_x + 0.2, -max_y + 0.2, max_x - 0.2, max_y - 0.2)
exteriors = []
interiors = []
for poly in drivable_areas:
exteriors.append(poly.exterior)
for inter in poly.interiors:
interiors.append(inter)
results = []
for ext in exteriors:
# NOTE: we make sure all exteriors are clock-wise
# such that each boundary's right-hand-side is drivable area
# and left-hand-side is walk way
if ext.is_ccw:
ext = LinearRing(list(ext.coords)[::-1])
lines = ext.intersection(local_patch)
if lines.geom_type == 'MultiLineString':
lines = ops.linemerge(lines)
assert lines.geom_type in ['MultiLineString', 'LineString']
results.extend(split_collections(lines))
for inter in interiors:
# NOTE: we make sure all interiors are counter-clock-wise
if not inter.is_ccw:
inter = LinearRing(list(inter.coords)[::-1])
lines = inter.intersection(local_patch)
if lines.geom_type == 'MultiLineString':
lines = ops.linemerge(lines)
assert lines.geom_type in ['MultiLineString', 'LineString']
results.extend(split_collections(lines))
return results
def get_ped_crossing_contour(polygon: Polygon,
local_patch: box) -> Optional[LineString]:
''' Extract ped crossing contours to get a closed polyline.
Different from `get_drivable_area_contour`, this function ensures a closed polyline.
Args:
polygon (Polygon): ped crossing polygon to be extracted.
local_patch (Polygon): local patch to clip the contour against
Returns:
line (LineString): a closed line
'''
ext = polygon.exterior
if not ext.is_ccw:
ext = LinearRing(list(ext.coords)[::-1])
lines = ext.intersection(local_patch)
if lines.geom_type != 'LineString':
# remove points in intersection results
lines = [l for l in lines.geoms if l.geom_type != 'Point']
lines = ops.linemerge(lines)
# the same instance may still be split into disconnected pieces; concatenate them into one polyline
if lines.geom_type != 'LineString':
ls = []
for l in lines.geoms:
ls.append(np.array(l.coords))
lines = np.concatenate(ls, axis=0)
lines = LineString(lines)
if not lines.is_empty:
return lines
return None
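
# --------------------------------------------------------------------------- #
# Behaviour sketch (illustration only): a ped-crossing polygon that lies fully
# inside the local patch is returned as one closed polyline tracing its exterior.
# --------------------------------------------------------------------------- #
def _ped_crossing_contour_example():
    crossing = Polygon([(-2, -1), (2, -1), (2, 1), (-2, 1)])
    local_patch = box(-15, -15, 15, 15)
    line = get_ped_crossing_contour(crossing, local_patch)
    # line is a closed LineString covering the crossing boundary
    return line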
import random
import math
import os
from os import path as osp
import cv2
import tempfile
import copy
import prettytable
import numpy as np
import torch
from torch.utils.data import Dataset
import pyquaternion
from shapely.geometry import LineString
from nuscenes.utils.data_classes import Box as NuScenesBox
from nuscenes.eval.detection.config import config_factory as det_configs
from nuscenes.eval.common.config import config_factory as track_configs
import mmcv
from mmcv.utils import print_log
from mmdet.datasets import DATASETS
from mmdet.datasets.pipelines import Compose
from .utils import (
draw_lidar_bbox3d_on_img,
draw_lidar_bbox3d_on_bev,
)
@DATASETS.register_module()
class NuScenes3DDataset(Dataset):
DefaultAttribute = {
"car": "vehicle.parked",
"pedestrian": "pedestrian.moving",
"trailer": "vehicle.parked",
"truck": "vehicle.parked",
"bus": "vehicle.moving",
"motorcycle": "cycle.without_rider",
"construction_vehicle": "vehicle.parked",
"bicycle": "cycle.without_rider",
"barrier": "",
"traffic_cone": "",
}
ErrNameMapping = {
"trans_err": "mATE",
"scale_err": "mASE",
"orient_err": "mAOE",
"vel_err": "mAVE",
"attr_err": "mAAE",
}
CLASSES = (
"car",
"truck",
"trailer",
"bus",
"construction_vehicle",
"bicycle",
"motorcycle",
"pedestrian",
"traffic_cone",
"barrier",
)
MAP_CLASSES = (
'ped_crossing',
'divider',
'boundary',
)
ID_COLOR_MAP = [
(59, 59, 238),
(0, 255, 0),
(0, 0, 255),
(255, 255, 0),
(0, 255, 255),
(255, 0, 255),
(255, 255, 255),
(0, 127, 255),
(71, 130, 255),
(127, 127, 0),
]
def __init__(
self,
ann_file,
pipeline=None,
data_root=None,
classes=None,
map_classes=None,
load_interval=1,
with_velocity=True,
modality=None,
test_mode=False,
det3d_eval_version="detection_cvpr_2019",
track3d_eval_version="tracking_nips_2019",
version="v1.0-trainval",
use_valid_flag=False,
vis_score_threshold=0.25,
data_aug_conf=None,
sequences_split_num=1,
with_seq_flag=False,
keep_consistent_seq_aug=True,
work_dir=None,
eval_config=None,
):
self.version = version
self.load_interval = load_interval
self.use_valid_flag = use_valid_flag
super().__init__()
self.data_root = data_root
self.ann_file = ann_file
self.test_mode = test_mode
self.modality = modality
self.box_mode_3d = 0
if classes is not None:
self.CLASSES = classes
if map_classes is not None:
self.MAP_CLASSES = map_classes
self.cat2id = {name: i for i, name in enumerate(self.CLASSES)}
self.data_infos = self.load_annotations(self.ann_file)
if pipeline is not None:
self.pipeline = Compose(pipeline)
self.with_velocity = with_velocity
self.det3d_eval_version = det3d_eval_version
self.det3d_eval_configs = det_configs(self.det3d_eval_version)
self.det3d_eval_configs.class_names = list(self.det3d_eval_configs.class_range.keys())
self.track3d_eval_version = track3d_eval_version
self.track3d_eval_configs = track_configs(self.track3d_eval_version)
self.track3d_eval_configs.class_names = list(self.track3d_eval_configs.class_range.keys())
if self.modality is None:
self.modality = dict(
use_camera=False,
use_lidar=True,
use_radar=False,
use_map=False,
use_external=False,
)
self.vis_score_threshold = vis_score_threshold
self.data_aug_conf = data_aug_conf
self.sequences_split_num = sequences_split_num
self.keep_consistent_seq_aug = keep_consistent_seq_aug
if with_seq_flag:
self._set_sequence_group_flag()
self.work_dir = work_dir
self.eval_config = eval_config
def __len__(self):
return len(self.data_infos)
def _set_sequence_group_flag(self):
"""
Set each sequence to be a different group
"""
if self.sequences_split_num == -1:
self.flag = np.arange(len(self.data_infos))
return
res = []
curr_sequence = 0
for idx in range(len(self.data_infos)):
if idx != 0 and len(self.data_infos[idx]["sweeps"]) == 0:
# Not first frame and # of sweeps is 0 -> new sequence
curr_sequence += 1
res.append(curr_sequence)
self.flag = np.array(res, dtype=np.int64)
if self.sequences_split_num != 1:
if self.sequences_split_num == "all":
self.flag = np.array(
range(len(self.data_infos)), dtype=np.int64
)
else:
bin_counts = np.bincount(self.flag)
new_flags = []
curr_new_flag = 0
for curr_flag in range(len(bin_counts)):
curr_sequence_length = np.array(
list(
range(
0,
bin_counts[curr_flag],
math.ceil(
bin_counts[curr_flag]
/ self.sequences_split_num
),
)
)
+ [bin_counts[curr_flag]]
)
for sub_seq_idx in (
curr_sequence_length[1:] - curr_sequence_length[:-1]
):
for _ in range(sub_seq_idx):
new_flags.append(curr_new_flag)
curr_new_flag += 1
assert len(new_flags) == len(self.flag)
assert (
len(np.bincount(new_flags))
== len(np.bincount(self.flag)) * self.sequences_split_num
)
self.flag = np.array(new_flags, dtype=np.int64)
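    # Grouping note (a reading of the code above, not from external docs):
    # sequences_split_num == 1 keeps one group per scene, an integer N > 1 splits
    # every scene into N roughly equal consecutive sub-sequences, and both -1 and
    # "all" degenerate to one group per frame.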
def get_augmentation(self):
if self.data_aug_conf is None:
return None
H, W = self.data_aug_conf["H"], self.data_aug_conf["W"]
fH, fW = self.data_aug_conf["final_dim"]
if not self.test_mode:
resize = np.random.uniform(*self.data_aug_conf["resize_lim"])
resize_dims = (int(W * resize), int(H * resize))
newW, newH = resize_dims
crop_h = (
int(
(1 - np.random.uniform(*self.data_aug_conf["bot_pct_lim"]))
* newH
)
- fH
)
crop_w = int(np.random.uniform(0, max(0, newW - fW)))
crop = (crop_w, crop_h, crop_w + fW, crop_h + fH)
flip = False
if self.data_aug_conf["rand_flip"] and np.random.choice([0, 1]):
flip = True
rotate = np.random.uniform(*self.data_aug_conf["rot_lim"])
rotate_3d = np.random.uniform(*self.data_aug_conf["rot3d_range"])
else:
resize = max(fH / H, fW / W)
resize_dims = (int(W * resize), int(H * resize))
newW, newH = resize_dims
crop_h = (
int((1 - np.mean(self.data_aug_conf["bot_pct_lim"])) * newH)
- fH
)
crop_w = int(max(0, newW - fW) / 2)
crop = (crop_w, crop_h, crop_w + fW, crop_h + fH)
flip = False
rotate = 0
rotate_3d = 0
aug_config = {
"resize": resize,
"resize_dims": resize_dims,
"crop": crop,
"flip": flip,
"rotate": rotate,
"rotate_3d": rotate_3d,
}
return aug_config
def __getitem__(self, idx):
if isinstance(idx, dict):
aug_config = idx["aug_config"]
idx = idx["idx"]
else:
aug_config = self.get_augmentation()
data = self.get_data_info(idx)
data["aug_config"] = aug_config
data = self.pipeline(data)
return data
def get_cat_ids(self, idx):
info = self.data_infos[idx]
if self.use_valid_flag:
mask = info["valid_flag"]
gt_names = set(info["gt_names"][mask])
else:
gt_names = set(info["gt_names"])
cat_ids = []
for name in gt_names:
if name in self.CLASSES:
cat_ids.append(self.cat2id[name])
return cat_ids
def load_annotations(self, ann_file):
data = mmcv.load(ann_file, file_format="pkl")
data_infos = list(sorted(data["infos"], key=lambda e: e["timestamp"]))
data_infos = data_infos[:: self.load_interval]
self.metadata = data["metadata"]
self.version = self.metadata["version"]
print(self.metadata)
return data_infos
def anno2geom(self, annos):
map_geoms = {}
for label, anno_list in annos.items():
map_geoms[label] = []
for anno in anno_list:
geom = LineString(anno)
map_geoms[label].append(geom)
return map_geoms
def get_data_info(self, index):
info = self.data_infos[index]
input_dict = dict(
token=info["token"],
map_location=info["map_location"],
pts_filename=info["lidar_path"],
sweeps=info["sweeps"],
timestamp=info["timestamp"] / 1e6,
lidar2ego_translation=info["lidar2ego_translation"],
lidar2ego_rotation=info["lidar2ego_rotation"],
ego2global_translation=info["ego2global_translation"],
ego2global_rotation=info["ego2global_rotation"],
ego_status=info['ego_status'].astype(np.float32),
map_infos=info["map_annos"],
)
lidar2ego = np.eye(4)
lidar2ego[:3, :3] = pyquaternion.Quaternion(
info["lidar2ego_rotation"]
).rotation_matrix
lidar2ego[:3, 3] = np.array(info["lidar2ego_translation"])
ego2global = np.eye(4)
ego2global[:3, :3] = pyquaternion.Quaternion(
info["ego2global_rotation"]
).rotation_matrix
ego2global[:3, 3] = np.array(info["ego2global_translation"])
input_dict["lidar2global"] = ego2global @ lidar2ego
map_geoms = self.anno2geom(info["map_annos"])
input_dict["map_geoms"] = map_geoms
if self.modality["use_camera"]:
image_paths = []
lidar2img_rts = []
lidar2cam_rts = []
cam_intrinsic = []
for cam_type, cam_info in info["cams"].items():
image_paths.append(cam_info["data_path"])
# obtain lidar to image transformation matrix
lidar2cam_r = np.linalg.inv(cam_info["sensor2lidar_rotation"])
lidar2cam_t = (
cam_info["sensor2lidar_translation"] @ lidar2cam_r.T
)
lidar2cam_rt = np.eye(4)
lidar2cam_rt[:3, :3] = lidar2cam_r.T
lidar2cam_rt[3, :3] = -lidar2cam_t
intrinsic = copy.deepcopy(cam_info["cam_intrinsic"])
cam_intrinsic.append(intrinsic)
viewpad = np.eye(4)
viewpad[: intrinsic.shape[0], : intrinsic.shape[1]] = intrinsic
lidar2img_rt = viewpad @ lidar2cam_rt.T
lidar2img_rts.append(lidar2img_rt)
lidar2cam_rts.append(lidar2cam_rt)
input_dict.update(
dict(
img_filename=image_paths,
lidar2img=lidar2img_rts,
lidar2cam=lidar2cam_rts,
cam_intrinsic=cam_intrinsic,
)
)
annos = self.get_ann_info(index)
input_dict.update(annos)
return input_dict
def get_ann_info(self, index):
info = self.data_infos[index]
if self.use_valid_flag:
mask = info["valid_flag"]
else:
mask = info["num_lidar_pts"] > 0
gt_bboxes_3d = info["gt_boxes"][mask]
gt_names_3d = info["gt_names"][mask]
gt_labels_3d = []
for cat in gt_names_3d:
if cat in self.CLASSES:
gt_labels_3d.append(self.CLASSES.index(cat))
else:
gt_labels_3d.append(-1)
gt_labels_3d = np.array(gt_labels_3d)
if self.with_velocity:
gt_velocity = info["gt_velocity"][mask]
nan_mask = np.isnan(gt_velocity[:, 0])
gt_velocity[nan_mask] = [0.0, 0.0]
gt_bboxes_3d = np.concatenate([gt_bboxes_3d, gt_velocity], axis=-1)
anns_results = dict(
gt_bboxes_3d=gt_bboxes_3d,
gt_labels_3d=gt_labels_3d,
gt_names=gt_names_3d,
)
if "instance_inds" in info:
            instance_inds = np.array(info["instance_inds"], dtype=np.int64)[mask]
anns_results["instance_inds"] = instance_inds
if 'gt_agent_fut_trajs' in info:
anns_results['gt_agent_fut_trajs'] = info['gt_agent_fut_trajs'][mask]
anns_results['gt_agent_fut_masks'] = info['gt_agent_fut_masks'][mask]
if 'gt_ego_fut_trajs' in info:
anns_results['gt_ego_fut_trajs'] = info['gt_ego_fut_trajs']
anns_results['gt_ego_fut_masks'] = info['gt_ego_fut_masks']
anns_results['gt_ego_fut_cmd'] = info['gt_ego_fut_cmd']
## get future box for planning eval
fut_ts = int(info['gt_ego_fut_masks'].sum())
fut_boxes = []
cur_scene_token = info["scene_token"]
cur_T_global = get_T_global(info)
for i in range(1, fut_ts + 1):
fut_info = self.data_infos[index + i]
fut_scene_token = fut_info["scene_token"]
if cur_scene_token != fut_scene_token:
break
if self.use_valid_flag:
mask = fut_info["valid_flag"]
else:
mask = fut_info["num_lidar_pts"] > 0
fut_gt_bboxes_3d = fut_info["gt_boxes"][mask]
fut_T_global = get_T_global(fut_info)
T_fut2cur = np.linalg.inv(cur_T_global) @ fut_T_global
center = fut_gt_bboxes_3d[:, :3] @ T_fut2cur[:3, :3].T + T_fut2cur[:3, 3]
yaw = np.stack([np.cos(fut_gt_bboxes_3d[:, 6]), np.sin(fut_gt_bboxes_3d[:, 6])], axis=-1)
yaw = yaw @ T_fut2cur[:2, :2].T
yaw = np.arctan2(yaw[..., 1], yaw[..., 0])
fut_gt_bboxes_3d[:, :3] = center
fut_gt_bboxes_3d[:, 6] = yaw
fut_boxes.append(fut_gt_bboxes_3d)
anns_results['fut_boxes'] = fut_boxes
return anns_results
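    # Note on the future-box block above (derived from the code, not external docs):
    # future GT boxes are brought into the current lidar frame via
    # T_fut2cur = inv(cur_T_global) @ fut_T_global; the yaw is rotated as a
    # (cos, sin) vector and recovered with arctan2 so angle wrap-around is handled
    # implicitly. The loop stops at scene boundaries, so fut_boxes may be shorter
    # than gt_ego_fut_masks.sum().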
def _format_bbox(self, results, jsonfile_prefix=None, tracking=False):
nusc_annos = {}
mapped_class_names = self.CLASSES
print("Start to convert detection format...")
for sample_id, det in enumerate(mmcv.track_iter_progress(results)):
annos = []
boxes = output_to_nusc_box(
det, threshold=self.tracking_threshold if tracking else None
)
sample_token = self.data_infos[sample_id]["token"]
boxes = lidar_nusc_box_to_global(
self.data_infos[sample_id],
boxes,
mapped_class_names,
self.det3d_eval_configs,
self.det3d_eval_version,
)
for i, box in enumerate(boxes):
name = mapped_class_names[box.label]
if tracking and name in [
"barrier",
"traffic_cone",
"construction_vehicle",
]:
continue
if np.sqrt(box.velocity[0] ** 2 + box.velocity[1] ** 2) > 0.2:
if name in [
"car",
"construction_vehicle",
"bus",
"truck",
"trailer",
]:
attr = "vehicle.moving"
elif name in ["bicycle", "motorcycle"]:
attr = "cycle.with_rider"
else:
attr = NuScenes3DDataset.DefaultAttribute[name]
else:
if name in ["pedestrian"]:
attr = "pedestrian.standing"
elif name in ["bus"]:
attr = "vehicle.stopped"
else:
attr = NuScenes3DDataset.DefaultAttribute[name]
nusc_anno = dict(
sample_token=sample_token,
translation=box.center.tolist(),
size=box.wlh.tolist(),
rotation=box.orientation.elements.tolist(),
velocity=box.velocity[:2].tolist(),
)
if not tracking:
nusc_anno.update(
dict(
detection_name=name,
detection_score=box.score,
attribute_name=attr,
)
)
else:
nusc_anno.update(
dict(
tracking_name=name,
tracking_score=box.score,
tracking_id=str(box.token),
)
)
annos.append(nusc_anno)
nusc_annos[sample_token] = annos
nusc_submissions = {
"meta": self.modality,
"results": nusc_annos,
}
mmcv.mkdir_or_exist(jsonfile_prefix)
res_path = osp.join(jsonfile_prefix, "results_nusc.json")
print("Results writes to", res_path)
mmcv.dump(nusc_submissions, res_path)
return res_path
def _evaluate_single(
self, result_path, logger=None, result_name="img_bbox", tracking=False
):
from nuscenes import NuScenes
output_dir = osp.join(*osp.split(result_path)[:-1])
nusc = NuScenes(
version=self.version, dataroot=self.data_root, verbose=False
)
eval_set_map = {
"v1.0-mini": "mini_val",
"v1.0-trainval": "val",
}
if not tracking:
from nuscenes.eval.detection.evaluate import NuScenesEval
nusc_eval = NuScenesEval(
nusc,
config=self.det3d_eval_configs,
result_path=result_path,
eval_set=eval_set_map[self.version],
output_dir=output_dir,
verbose=True,
)
nusc_eval.main(render_curves=False)
# record metrics
metrics = mmcv.load(osp.join(output_dir, "metrics_summary.json"))
detail = dict()
metric_prefix = f"{result_name}_NuScenes"
for name in self.CLASSES:
for k, v in metrics["label_aps"][name].items():
val = float("{:.4f}".format(v))
detail[
"{}/{}_AP_dist_{}".format(metric_prefix, name, k)
] = val
for k, v in metrics["label_tp_errors"][name].items():
val = float("{:.4f}".format(v))
detail["{}/{}_{}".format(metric_prefix, name, k)] = val
for k, v in metrics["tp_errors"].items():
val = float("{:.4f}".format(v))
detail[
"{}/{}".format(metric_prefix, self.ErrNameMapping[k])
] = val
detail["{}/NDS".format(metric_prefix)] = metrics["nd_score"]
detail["{}/mAP".format(metric_prefix)] = metrics["mean_ap"]
else:
from nuscenes.eval.tracking.evaluate import TrackingEval
nusc_eval = TrackingEval(
config=self.track3d_eval_configs,
result_path=result_path,
eval_set=eval_set_map[self.version],
output_dir=output_dir,
verbose=True,
nusc_version=self.version,
nusc_dataroot=self.data_root,
)
metrics = nusc_eval.main()
# record metrics
metrics = mmcv.load(osp.join(output_dir, "metrics_summary.json"))
print(metrics)
detail = dict()
metric_prefix = f"{result_name}_NuScenes"
keys = [
"amota",
"amotp",
"recall",
"motar",
"gt",
"mota",
"motp",
"mt",
"ml",
"faf",
"tp",
"fp",
"fn",
"ids",
"frag",
"tid",
"lgd",
]
for key in keys:
detail["{}/{}".format(metric_prefix, key)] = metrics[key]
return detail
def format_results(self, results, jsonfile_prefix=None, tracking=False):
assert isinstance(results, list), "results must be a list"
if jsonfile_prefix is None:
tmp_dir = tempfile.TemporaryDirectory()
jsonfile_prefix = osp.join(tmp_dir.name, "results")
else:
tmp_dir = None
if not ("pts_bbox" in results[0] or "img_bbox" in results[0]):
result_files = self._format_bbox(
results, jsonfile_prefix, tracking=tracking
)
else:
result_files = dict()
for name in results[0]:
print(f"\nFormating bboxes of {name}")
results_ = [out[name] for out in results]
tmp_file_ = jsonfile_prefix
result_files.update(
{
name: self._format_bbox(
results_, tmp_file_, tracking=tracking
)
}
)
return result_files, tmp_dir
def format_map_results(self, results, prefix=None):
submissions = {'results': {},}
for j, pred in enumerate(results):
'''
For each case, the result should be formatted as Dict{'vectors': [], 'scores': [], 'labels': []}
'vectors': List of vector, each vector is a array([[x1, y1], [x2, y2] ...]),
contain all vectors predicted in this sample.
            'scores': List of score(float),
contain scores of all instances in this sample.
'labels': List of label(int),
contain labels of all instances in this sample.
'''
if pred is None: # empty prediction
continue
pred = pred['img_bbox']
single_case = {'vectors': [], 'scores': [], 'labels': []}
token = self.data_infos[j]['token']
for i in range(len(pred['scores'])):
score = pred['scores'][i]
label = pred['labels'][i]
vector = pred['vectors'][i]
# A line should have >=2 points
if len(vector) < 2:
continue
single_case['vectors'].append(vector)
single_case['scores'].append(score)
single_case['labels'].append(label)
submissions['results'][token] = single_case
out_path = osp.join(prefix, 'submission_vector.json')
        print(f'saving submission results to {out_path}')
os.makedirs(os.path.dirname(out_path), exist_ok=True)
mmcv.dump(submissions, out_path)
return out_path
def format_motion_results(self, results, jsonfile_prefix=None, tracking=False, thresh=None):
nusc_annos = {}
mapped_class_names = self.CLASSES
print("Start to convert detection format...")
for sample_id, det in enumerate(mmcv.track_iter_progress(results)):
annos = []
boxes = output_to_nusc_box(
det['img_bbox'], threshold=None
)
sample_token = self.data_infos[sample_id]["token"]
boxes = lidar_nusc_box_to_global(
self.data_infos[sample_id],
boxes,
mapped_class_names,
self.det3d_eval_configs,
self.det3d_eval_version,
filter_with_cls_range=False,
)
for i, box in enumerate(boxes):
if thresh is not None and box.score < thresh:
continue
name = mapped_class_names[box.label]
if tracking and name in [
"barrier",
"traffic_cone",
"construction_vehicle",
]:
continue
if np.sqrt(box.velocity[0] ** 2 + box.velocity[1] ** 2) > 0.2:
if name in [
"car",
"construction_vehicle",
"bus",
"truck",
"trailer",
]:
attr = "vehicle.moving"
elif name in ["bicycle", "motorcycle"]:
attr = "cycle.with_rider"
else:
attr = NuScenes3DDataset.DefaultAttribute[name]
else:
if name in ["pedestrian"]:
attr = "pedestrian.standing"
elif name in ["bus"]:
attr = "vehicle.stopped"
else:
attr = NuScenes3DDataset.DefaultAttribute[name]
nusc_anno = dict(
sample_token=sample_token,
translation=box.center.tolist(),
size=box.wlh.tolist(),
rotation=box.orientation.elements.tolist(),
velocity=box.velocity[:2].tolist(),
)
if not tracking:
nusc_anno.update(
dict(
detection_name=name,
detection_score=box.score,
attribute_name=attr,
)
)
else:
nusc_anno.update(
dict(
tracking_name=name,
tracking_score=box.score,
tracking_id=str(box.token),
)
)
nusc_anno.update(
dict(
trajs=det['img_bbox']['trajs_3d'][i].numpy(),
)
)
annos.append(nusc_anno)
nusc_annos[sample_token] = annos
nusc_submissions = {
"meta": self.modality,
"results": nusc_annos,
}
return nusc_submissions
def _evaluate_single_motion(self,
results,
result_path,
logger=None,
metric='bbox',
result_name='pts_bbox'):
"""Evaluation for a single model in nuScenes protocol.
Args:
result_path (str): Path of the result file.
logger (logging.Logger | str | None): Logger used for printing
related information during evaluation. Default: None.
metric (str): Metric name used for evaluation. Default: 'bbox'.
result_name (str): Result name in the metric prefix.
Default: 'pts_bbox'.
Returns:
dict: Dictionary of evaluation details.
"""
from nuscenes import NuScenes
from .evaluation.motion.motion_eval_uniad import NuScenesEval as NuScenesEvalMotion
output_dir = result_path
nusc = NuScenes(
version=self.version, dataroot=self.data_root, verbose=False)
eval_set_map = {
'v1.0-mini': 'mini_val',
'v1.0-trainval': 'val',
}
nusc_eval = NuScenesEvalMotion(
nusc,
config=copy.deepcopy(self.det3d_eval_configs),
result_path=results,
eval_set=eval_set_map[self.version],
output_dir=output_dir,
verbose=False,
seconds=6)
metrics = nusc_eval.main(render_curves=False)
MOTION_METRICS = ['EPA', 'min_ade_err', 'min_fde_err', 'miss_rate_err']
class_names = ['car', 'pedestrian']
table = prettytable.PrettyTable()
table.field_names = ["class names"] + MOTION_METRICS
for class_name in class_names:
row_data = [class_name]
for m in MOTION_METRICS:
row_data.append('%.4f' % metrics[f'{class_name}_{m}'])
table.add_row(row_data)
print_log('\n'+str(table), logger=logger)
return metrics
def evaluate(
self,
results,
eval_mode,
metric=None,
logger=None,
jsonfile_prefix=None,
result_names=["img_bbox"],
show=False,
out_dir=None,
pipeline=None,
):
res_path = "results.pkl" if "trainval" in self.version else "results_mini.pkl"
res_path = osp.join(self.work_dir, res_path)
print('All Results write to', res_path)
mmcv.dump(results, res_path)
results_dict = dict()
if eval_mode['with_det']:
self.tracking = eval_mode["with_tracking"]
self.tracking_threshold = eval_mode["tracking_threshold"]
for metric in ["detection", "tracking"]:
tracking = metric == "tracking"
if tracking and not self.tracking:
continue
result_files, tmp_dir = self.format_results(
results, jsonfile_prefix=self.work_dir, tracking=tracking
)
if isinstance(result_files, dict):
for name in result_names:
ret_dict = self._evaluate_single(
result_files[name], tracking=tracking
)
results_dict.update(ret_dict)
elif isinstance(result_files, str):
ret_dict = self._evaluate_single(
result_files, tracking=tracking
)
results_dict.update(ret_dict)
if tmp_dir is not None:
tmp_dir.cleanup()
if eval_mode['with_map']:
from .evaluation.map.vector_eval import VectorEvaluate
self.map_evaluator = VectorEvaluate(self.eval_config)
result_path = self.format_map_results(results, prefix=self.work_dir)
map_results_dict = self.map_evaluator.evaluate(result_path, logger=logger)
results_dict.update(map_results_dict)
if eval_mode['with_motion']:
thresh = eval_mode["motion_threshhold"]
result_files = self.format_motion_results(results, jsonfile_prefix=self.work_dir, thresh=thresh)
motion_results_dict = self._evaluate_single_motion(result_files, self.work_dir, logger=logger)
results_dict.update(motion_results_dict)
if eval_mode['with_planning']:
from .evaluation.planning.planning_eval import planning_eval
planning_results_dict = planning_eval(results, self.eval_config, logger=logger)
results_dict.update(planning_results_dict)
if show or out_dir:
self.show(results, save_dir=out_dir, show=show, pipeline=pipeline)
# print main metrics for recording
metric_str = '\n'
if "img_bbox_NuScenes/NDS" in results_dict:
metric_str += f'mAP: {results_dict.get("img_bbox_NuScenes/mAP"):.4f}\n'
metric_str += f'mATE: {results_dict.get("img_bbox_NuScenes/mATE"):.4f}\n'
metric_str += f'mASE: {results_dict.get("img_bbox_NuScenes/mASE"):.4f}\n'
metric_str += f'mAOE: {results_dict.get("img_bbox_NuScenes/mAOE"):.4f}\n'
metric_str += f'mAVE: {results_dict.get("img_bbox_NuScenes/mAVE"):.4f}\n'
metric_str += f'mAAE: {results_dict.get("img_bbox_NuScenes/mAAE"):.4f}\n'
metric_str += f'NDS: {results_dict.get("img_bbox_NuScenes/NDS"):.4f}\n\n'
if "img_bbox_NuScenes/amota" in results_dict:
metric_str += f'AMOTA: {results_dict["img_bbox_NuScenes/amota"]:.4f}\n'
metric_str += f'AMOTP: {results_dict["img_bbox_NuScenes/amotp"]:.4f}\n'
metric_str += f'RECALL: {results_dict["img_bbox_NuScenes/recall"]:.4f}\n'
metric_str += f'MOTAR: {results_dict["img_bbox_NuScenes/motar"]:.4f}\n'
metric_str += f'MOTA: {results_dict["img_bbox_NuScenes/mota"]:.4f}\n'
metric_str += f'MOTP: {results_dict["img_bbox_NuScenes/motp"]:.4f}\n'
metric_str += f'IDS: {results_dict["img_bbox_NuScenes/ids"]}\n\n'
if "mAP_normal" in results_dict:
metric_str += f'ped_crossing= {results_dict["ped_crossing"]:.4f}\n'
metric_str += f'divider= {results_dict["divider"]:.4f}\n'
metric_str += f'boundary= {results_dict["boundary"]:.4f}\n'
metric_str += f'mAP_normal= {results_dict["mAP_normal"]:.4f}\n\n'
if "car_EPA" in results_dict:
metric_str += f'Car / Ped\n'
metric_str += f'epa= {results_dict["car_EPA"]:.4f} / {results_dict["pedestrian_EPA"]:.4f}\n'
metric_str += f'ade= {results_dict["car_min_ade_err"]:.4f} / {results_dict["pedestrian_min_ade_err"]:.4f}\n'
metric_str += f'fde= {results_dict["car_min_fde_err"]:.4f} / {results_dict["pedestrian_min_fde_err"]:.4f}\n'
metric_str += f'mr= {results_dict["car_miss_rate_err"]:.4f} / {results_dict["pedestrian_miss_rate_err"]:.4f}\n\n'
if "L2" in results_dict:
metric_str += f'obj_box_col: {(results_dict["obj_box_col"]*100):.3f}%\n'
metric_str += f'L2: {results_dict["L2"]:.4f}\n\n'
print_log(metric_str, logger=logger)
return results_dict
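    # eval_mode sketch (keys inferred from the branches above, values are placeholders;
    # note "motion_threshhold" is spelled exactly as the code reads it):
    #   eval_mode = dict(with_det=True, with_tracking=True, tracking_threshold=0.2,
    #                    with_map=True, with_motion=True, motion_threshhold=0.2,
    #                    with_planning=True)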
def show(self, results, save_dir=None, show=False, pipeline=None):
save_dir = "./" if save_dir is None else save_dir
save_dir = os.path.join(save_dir, "visual")
print_log(os.path.abspath(save_dir))
pipeline = Compose(pipeline)
if not os.path.exists(save_dir):
os.makedirs(save_dir)
fourcc = cv2.VideoWriter_fourcc(*"MJPG")
videoWriter = None
for i, result in enumerate(results):
if "img_bbox" in result.keys():
result = result["img_bbox"]
data_info = pipeline(self.get_data_info(i))
imgs = []
raw_imgs = data_info["img"]
lidar2img = data_info["img_metas"].data["lidar2img"]
pred_bboxes_3d = result["boxes_3d"][
result["scores_3d"] > self.vis_score_threshold
]
if "instance_ids" in result and self.tracking:
color = []
for id in result["instance_ids"].cpu().numpy().tolist():
color.append(
self.ID_COLOR_MAP[int(id % len(self.ID_COLOR_MAP))]
)
elif "labels_3d" in result:
color = []
for id in result["labels_3d"].cpu().numpy().tolist():
color.append(self.ID_COLOR_MAP[id])
else:
color = (255, 0, 0)
# ===== draw boxes_3d to images =====
for j, img_origin in enumerate(raw_imgs):
img = img_origin.copy()
if len(pred_bboxes_3d) != 0:
img = draw_lidar_bbox3d_on_img(
pred_bboxes_3d,
img,
lidar2img[j],
img_metas=None,
color=color,
thickness=3,
)
imgs.append(img)
# ===== draw boxes_3d to BEV =====
bev = draw_lidar_bbox3d_on_bev(
pred_bboxes_3d,
bev_size=img.shape[0] * 2,
color=color,
)
# ===== put text and concat =====
for j, name in enumerate(
[
"front",
"front right",
"front left",
"rear",
"rear left",
"rear right",
]
):
imgs[j] = cv2.rectangle(
imgs[j],
(0, 0),
(440, 80),
color=(255, 255, 255),
thickness=-1,
)
w, h = cv2.getTextSize(name, cv2.FONT_HERSHEY_SIMPLEX, 2, 2)[0]
text_x = int(220 - w / 2)
text_y = int(40 + h / 2)
imgs[j] = cv2.putText(
imgs[j],
name,
(text_x, text_y),
cv2.FONT_HERSHEY_SIMPLEX,
2,
(0, 0, 0),
2,
cv2.LINE_AA,
)
image = np.concatenate(
[
np.concatenate([imgs[2], imgs[0], imgs[1]], axis=1),
np.concatenate([imgs[5], imgs[3], imgs[4]], axis=1),
],
axis=0,
)
image = np.concatenate([image, bev], axis=1)
# ===== save video =====
if videoWriter is None:
videoWriter = cv2.VideoWriter(
os.path.join(save_dir, "video.avi"),
fourcc,
7,
image.shape[:2][::-1],
)
cv2.imwrite(os.path.join(save_dir, f"{i}.jpg"), image)
videoWriter.write(image)
videoWriter.release()
def output_to_nusc_box(detection, threshold=None):
box3d = detection["boxes_3d"]
scores = detection["scores_3d"].numpy()
labels = detection["labels_3d"].numpy()
if "instance_ids" in detection:
ids = detection["instance_ids"] # .numpy()
if threshold is not None:
if "cls_scores" in detection:
mask = detection["cls_scores"].numpy() >= threshold
else:
mask = scores >= threshold
box3d = box3d[mask]
scores = scores[mask]
labels = labels[mask]
        if "instance_ids" in detection:
            ids = ids[mask]
if hasattr(box3d, "gravity_center"):
box_gravity_center = box3d.gravity_center.numpy()
box_dims = box3d.dims.numpy()
nus_box_dims = box_dims[:, [1, 0, 2]]
box_yaw = box3d.yaw.numpy()
else:
box3d = box3d.numpy()
box_gravity_center = box3d[..., :3].copy()
box_dims = box3d[..., 3:6].copy()
nus_box_dims = box_dims[..., [1, 0, 2]]
box_yaw = box3d[..., 6].copy()
# TODO: check whether this is necessary
# with dir_offset & dir_limit in the head
# box_yaw = -box_yaw - np.pi / 2
box_list = []
for i in range(len(box3d)):
quat = pyquaternion.Quaternion(axis=[0, 0, 1], radians=box_yaw[i])
if hasattr(box3d, "gravity_center"):
velocity = (*box3d.tensor[i, 7:9], 0.0)
else:
velocity = (*box3d[i, 7:9], 0.0)
box = NuScenesBox(
box_gravity_center[i],
nus_box_dims[i],
quat,
label=labels[i],
score=scores[i],
velocity=velocity,
)
if "instance_ids" in detection:
box.token = ids[i]
box_list.append(box)
return box_list
def lidar_nusc_box_to_global(
info,
boxes,
classes,
eval_configs,
eval_version="detection_cvpr_2019",
filter_with_cls_range=True,
):
box_list = []
for i, box in enumerate(boxes):
# Move box to ego vehicle coord system
box.rotate(pyquaternion.Quaternion(info["lidar2ego_rotation"]))
box.translate(np.array(info["lidar2ego_translation"]))
# filter det in ego.
if filter_with_cls_range:
cls_range_map = eval_configs.class_range
radius = np.linalg.norm(box.center[:2], 2)
det_range = cls_range_map[classes[box.label]]
if radius > det_range:
continue
# Move box to global coord system
box.rotate(pyquaternion.Quaternion(info["ego2global_rotation"]))
box.translate(np.array(info["ego2global_translation"]))
box_list.append(box)
return box_list
def get_T_global(info):
lidar2ego = np.eye(4)
lidar2ego[:3, :3] = pyquaternion.Quaternion(
info["lidar2ego_rotation"]
).rotation_matrix
lidar2ego[:3, 3] = np.array(info["lidar2ego_translation"])
ego2global = np.eye(4)
ego2global[:3, :3] = pyquaternion.Quaternion(
info["ego2global_rotation"]
).rotation_matrix
ego2global[:3, 3] = np.array(info["ego2global_translation"])
return ego2global @ lidar2ego
from .transform import (
InstanceNameFilter,
CircleObjectRangeFilter,
NormalizeMultiviewImage,
NuScenesSparse4DAdaptor,
MultiScaleDepthMapGenerator,
)
from .augment import (
ResizeCropFlipImage,
BBoxRotation,
PhotoMetricDistortionMultiViewImage,
)
from .loading import LoadMultiViewImageFromFiles, LoadPointsFromFile
from .vectorize import VectorizeMap
__all__ = [
"InstanceNameFilter",
"ResizeCropFlipImage",
"BBoxRotation",
"CircleObjectRangeFilter",
"MultiScaleDepthMapGenerator",
"NormalizeMultiviewImage",
"PhotoMetricDistortionMultiViewImage",
"NuScenesSparse4DAdaptor",
"LoadMultiViewImageFromFiles",
"LoadPointsFromFile",
"VectorizeMap",
]
import torch
import numpy as np
from numpy import random
import mmcv
from mmdet.datasets.builder import PIPELINES
from PIL import Image
@PIPELINES.register_module()
class ResizeCropFlipImage(object):
def __call__(self, results):
aug_config = results.get("aug_config")
if aug_config is None:
return results
imgs = results["img"]
N = len(imgs)
new_imgs = []
for i in range(N):
img, mat = self._img_transform(
np.uint8(imgs[i]), aug_config,
)
new_imgs.append(np.array(img).astype(np.float32))
results["lidar2img"][i] = mat @ results["lidar2img"][i]
if "cam_intrinsic" in results:
results["cam_intrinsic"][i][:3, :3] *= aug_config["resize"]
# results["cam_intrinsic"][i][:3, :3] = (
# mat[:3, :3] @ results["cam_intrinsic"][i][:3, :3]
# )
results["img"] = new_imgs
results["img_shape"] = [x.shape[:2] for x in new_imgs]
return results
def _img_transform(self, img, aug_configs):
H, W = img.shape[:2]
resize = aug_configs.get("resize", 1)
resize_dims = (int(W * resize), int(H * resize))
crop = aug_configs.get("crop", [0, 0, *resize_dims])
flip = aug_configs.get("flip", False)
rotate = aug_configs.get("rotate", 0)
origin_dtype = img.dtype
if origin_dtype != np.uint8:
min_value = img.min()
            max_value = img.max()
            scale = 255 / (max_value - min_value)
img = (img - min_value) * scale
img = np.uint8(img)
img = Image.fromarray(img)
img = img.resize(resize_dims).crop(crop)
if flip:
img = img.transpose(method=Image.FLIP_LEFT_RIGHT)
img = img.rotate(rotate)
img = np.array(img).astype(np.float32)
if origin_dtype != np.uint8:
img = img.astype(np.float32)
img = img / scale + min_value
transform_matrix = np.eye(3)
transform_matrix[:2, :2] *= resize
transform_matrix[:2, 2] -= np.array(crop[:2])
if flip:
flip_matrix = np.array(
[[-1, 0, crop[2] - crop[0]], [0, 1, 0], [0, 0, 1]]
)
transform_matrix = flip_matrix @ transform_matrix
rotate = rotate / 180 * np.pi
rot_matrix = np.array(
[
[np.cos(rotate), np.sin(rotate), 0],
[-np.sin(rotate), np.cos(rotate), 0],
[0, 0, 1],
]
)
rot_center = np.array([crop[2] - crop[0], crop[3] - crop[1]]) / 2
rot_matrix[:2, 2] = -rot_matrix[:2, :2] @ rot_center + rot_center
transform_matrix = rot_matrix @ transform_matrix
extend_matrix = np.eye(4)
extend_matrix[:3, :3] = transform_matrix
return img, extend_matrix
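# Geometry note (a sketch of how the returned matrix is used, derived from __call__
# above): `extend_matrix` encodes resize -> crop -> optional flip -> rotate in pixel
# coordinates, so the caller updates each projection matrix in one step:
#   lidar2img_new = extend_matrix @ lidar2img_old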
@PIPELINES.register_module()
class BBoxRotation(object):
def __call__(self, results):
angle = results["aug_config"]["rotate_3d"]
rot_cos = np.cos(angle)
rot_sin = np.sin(angle)
rot_mat = np.array(
[
[rot_cos, -rot_sin, 0, 0],
[rot_sin, rot_cos, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 1],
]
)
rot_mat_inv = np.linalg.inv(rot_mat)
num_view = len(results["lidar2img"])
for view in range(num_view):
results["lidar2img"][view] = (
results["lidar2img"][view] @ rot_mat_inv
)
if "lidar2global" in results:
results["lidar2global"] = results["lidar2global"] @ rot_mat_inv
if "gt_bboxes_3d" in results:
results["gt_bboxes_3d"] = self.box_rotate(
results["gt_bboxes_3d"], angle
)
return results
@staticmethod
def box_rotate(bbox_3d, angle):
rot_cos = np.cos(angle)
rot_sin = np.sin(angle)
rot_mat_T = np.array(
[[rot_cos, rot_sin, 0], [-rot_sin, rot_cos, 0], [0, 0, 1]]
)
bbox_3d[:, :3] = bbox_3d[:, :3] @ rot_mat_T
bbox_3d[:, 6] += angle
if bbox_3d.shape[-1] > 7:
vel_dims = bbox_3d[:, 7:].shape[-1]
bbox_3d[:, 7:] = bbox_3d[:, 7:] @ rot_mat_T[:vel_dims, :vel_dims]
return bbox_3d
@PIPELINES.register_module()
class PhotoMetricDistortionMultiViewImage:
"""Apply photometric distortion to image sequentially, every transformation
is applied with a probability of 0.5. The position of random contrast is in
second or second to last.
1. random brightness
2. random contrast (mode 0)
3. convert color from BGR to HSV
4. random saturation
5. random hue
6. convert color from HSV to BGR
7. random contrast (mode 1)
8. randomly swap channels
Args:
brightness_delta (int): delta of brightness.
contrast_range (tuple): range of contrast.
saturation_range (tuple): range of saturation.
hue_delta (int): delta of hue.
"""
def __init__(
self,
brightness_delta=32,
contrast_range=(0.5, 1.5),
saturation_range=(0.5, 1.5),
hue_delta=18,
):
self.brightness_delta = brightness_delta
self.contrast_lower, self.contrast_upper = contrast_range
self.saturation_lower, self.saturation_upper = saturation_range
self.hue_delta = hue_delta
def __call__(self, results):
"""Call function to perform photometric distortion on images.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Result dict with images distorted.
"""
imgs = results["img"]
new_imgs = []
for img in imgs:
assert img.dtype == np.float32, (
"PhotoMetricDistortion needs the input image of dtype np.float32,"
' please set "to_float32=True" in "LoadImageFromFile" pipeline'
)
# random brightness
if random.randint(2):
delta = random.uniform(
-self.brightness_delta, self.brightness_delta
)
img += delta
# mode == 0 --> do random contrast first
# mode == 1 --> do random contrast last
mode = random.randint(2)
if mode == 1:
if random.randint(2):
alpha = random.uniform(
self.contrast_lower, self.contrast_upper
)
img *= alpha
# convert color from BGR to HSV
img = mmcv.bgr2hsv(img)
# random saturation
if random.randint(2):
img[..., 1] *= random.uniform(
self.saturation_lower, self.saturation_upper
)
# random hue
if random.randint(2):
img[..., 0] += random.uniform(-self.hue_delta, self.hue_delta)
img[..., 0][img[..., 0] > 360] -= 360
img[..., 0][img[..., 0] < 0] += 360
# convert color from HSV to BGR
img = mmcv.hsv2bgr(img)
# random contrast
if mode == 0:
if random.randint(2):
alpha = random.uniform(
self.contrast_lower, self.contrast_upper
)
img *= alpha
# randomly swap channels
if random.randint(2):
img = img[..., random.permutation(3)]
new_imgs.append(img)
results["img"] = new_imgs
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f"(\nbrightness_delta={self.brightness_delta},\n"
repr_str += "contrast_range="
repr_str += f"{(self.contrast_lower, self.contrast_upper)},\n"
repr_str += "saturation_range="
repr_str += f"{(self.saturation_lower, self.saturation_upper)},\n"
repr_str += f"hue_delta={self.hue_delta})"
return repr_str
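# Hedged pipeline sketch (values are the defaults above, shown only for illustration;
# requires float32 images, e.g. LoadMultiViewImageFromFiles(to_float32=True)):
#   dict(type="PhotoMetricDistortionMultiViewImage", brightness_delta=32,
#        contrast_range=(0.5, 1.5), saturation_range=(0.5, 1.5), hue_delta=18)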
import numpy as np
import mmcv
from mmdet.datasets.builder import PIPELINES
@PIPELINES.register_module()
class LoadMultiViewImageFromFiles(object):
"""Load multi channel images from a list of separate channel files.
Expects results['img_filename'] to be a list of filenames.
Args:
to_float32 (bool, optional): Whether to convert the img to float32.
Defaults to False.
color_type (str, optional): Color type of the file.
Defaults to 'unchanged'.
"""
def __init__(self, to_float32=False, color_type="unchanged"):
self.to_float32 = to_float32
self.color_type = color_type
def __call__(self, results):
"""Call function to load multi-view image from files.
Args:
results (dict): Result dict containing multi-view image filenames.
Returns:
dict: The result dict containing the multi-view image data.
Added keys and values are described below.
- filename (str): Multi-view image filenames.
- img (np.ndarray): Multi-view image arrays.
- img_shape (tuple[int]): Shape of multi-view image arrays.
- ori_shape (tuple[int]): Shape of original image arrays.
- pad_shape (tuple[int]): Shape of padded image arrays.
- scale_factor (float): Scale factor.
- img_norm_cfg (dict): Normalization configuration of images.
"""
filename = results["img_filename"]
# img is of shape (h, w, c, num_views)
img = np.stack(
[mmcv.imread(name, self.color_type) for name in filename], axis=-1
)
if self.to_float32:
img = img.astype(np.float32)
results["filename"] = filename
# unravel to list, see `DefaultFormatBundle` in formatting.py
# which will transpose each image separately and then stack into array
results["img"] = [img[..., i] for i in range(img.shape[-1])]
results["img_shape"] = img.shape
results["ori_shape"] = img.shape
# Set initial values for default meta_keys
results["pad_shape"] = img.shape
results["scale_factor"] = 1.0
num_channels = 1 if len(img.shape) < 3 else img.shape[2]
results["img_norm_cfg"] = dict(
mean=np.zeros(num_channels, dtype=np.float32),
std=np.ones(num_channels, dtype=np.float32),
to_rgb=False,
)
return results
def __repr__(self):
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__
repr_str += f"(to_float32={self.to_float32}, "
repr_str += f"color_type='{self.color_type}')"
return repr_str
@PIPELINES.register_module()
class LoadPointsFromFile(object):
"""Load Points From File.
Load points from file.
Args:
coord_type (str): The type of coordinates of points cloud.
Available options includes:
- 'LIDAR': Points in LiDAR coordinates.
- 'DEPTH': Points in depth coordinates, usually for indoor dataset.
- 'CAMERA': Points in camera coordinates.
load_dim (int, optional): The dimension of the loaded points.
Defaults to 6.
use_dim (list[int], optional): Which dimensions of the points to use.
Defaults to [0, 1, 2]. For KITTI dataset, set use_dim=4
or use_dim=[0, 1, 2, 3] to use the intensity dimension.
shift_height (bool, optional): Whether to use shifted height.
Defaults to False.
use_color (bool, optional): Whether to use color features.
Defaults to False.
file_client_args (dict, optional): Config dict of file clients,
refer to
https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py
for more details. Defaults to dict(backend='disk').
"""
def __init__(
self,
coord_type,
load_dim=6,
use_dim=[0, 1, 2],
shift_height=False,
use_color=False,
file_client_args=dict(backend="disk"),
):
self.shift_height = shift_height
self.use_color = use_color
if isinstance(use_dim, int):
use_dim = list(range(use_dim))
assert (
max(use_dim) < load_dim
), f"Expect all used dimensions < {load_dim}, got {use_dim}"
assert coord_type in ["CAMERA", "LIDAR", "DEPTH"]
self.coord_type = coord_type
self.load_dim = load_dim
self.use_dim = use_dim
self.file_client_args = file_client_args.copy()
self.file_client = None
def _load_points(self, pts_filename):
"""Private function to load point clouds data.
Args:
pts_filename (str): Filename of point clouds data.
Returns:
np.ndarray: An array containing point clouds data.
"""
if self.file_client is None:
self.file_client = mmcv.FileClient(**self.file_client_args)
try:
pts_bytes = self.file_client.get(pts_filename)
points = np.frombuffer(pts_bytes, dtype=np.float32)
except ConnectionError:
mmcv.check_file_exist(pts_filename)
if pts_filename.endswith(".npy"):
points = np.load(pts_filename)
else:
points = np.fromfile(pts_filename, dtype=np.float32)
return points
def __call__(self, results):
"""Call function to load points data from file.
Args:
results (dict): Result dict containing point clouds data.
Returns:
dict: The result dict containing the point clouds data.
Added key and value are described below.
- points (:obj:`BasePoints`): Point clouds data.
"""
pts_filename = results["pts_filename"]
points = self._load_points(pts_filename)
points = points.reshape(-1, self.load_dim)
points = points[:, self.use_dim]
attribute_dims = None
if self.shift_height:
floor_height = np.percentile(points[:, 2], 0.99)
height = points[:, 2] - floor_height
points = np.concatenate(
[points[:, :3], np.expand_dims(height, 1), points[:, 3:]], 1
)
attribute_dims = dict(height=3)
if self.use_color:
assert len(self.use_dim) >= 6
if attribute_dims is None:
attribute_dims = dict()
attribute_dims.update(
dict(
color=[
points.shape[1] - 3,
points.shape[1] - 2,
points.shape[1] - 1,
]
)
)
results["points"] = points
return results
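# Hedged config sketch (dims are placeholders; nuScenes lidar files commonly store
# 5 values per point, so load_dim=5, use_dim=[0, 1, 2, 3] is a typical choice):
#   dict(type="LoadPointsFromFile", coord_type="LIDAR", load_dim=5, use_dim=[0, 1, 2, 3])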
import numpy as np
import mmcv
from mmcv.parallel import DataContainer as DC
from mmdet.datasets.builder import PIPELINES
from mmdet.datasets.pipelines import to_tensor
@PIPELINES.register_module()
class MultiScaleDepthMapGenerator(object):
def __init__(self, downsample=1, max_depth=60):
if not isinstance(downsample, (list, tuple)):
downsample = [downsample]
self.downsample = downsample
self.max_depth = max_depth
def __call__(self, input_dict):
points = input_dict["points"][..., :3, None]
gt_depth = []
for i, lidar2img in enumerate(input_dict["lidar2img"]):
H, W = input_dict["img_shape"][i][:2]
pts_2d = (
np.squeeze(lidar2img[:3, :3] @ points, axis=-1)
+ lidar2img[:3, 3]
)
pts_2d[:, :2] /= pts_2d[:, 2:3]
U = np.round(pts_2d[:, 0]).astype(np.int32)
V = np.round(pts_2d[:, 1]).astype(np.int32)
depths = pts_2d[:, 2]
mask = np.logical_and.reduce(
[
V >= 0,
V < H,
U >= 0,
U < W,
depths >= 0.1,
# depths <= self.max_depth,
]
)
V, U, depths = V[mask], U[mask], depths[mask]
sort_idx = np.argsort(depths)[::-1]
V, U, depths = V[sort_idx], U[sort_idx], depths[sort_idx]
depths = np.clip(depths, 0.1, self.max_depth)
for j, downsample in enumerate(self.downsample):
if len(gt_depth) < j + 1:
gt_depth.append([])
h, w = (int(H / downsample), int(W / downsample))
u = np.floor(U / downsample).astype(np.int32)
v = np.floor(V / downsample).astype(np.int32)
depth_map = np.ones([h, w], dtype=np.float32) * -1
depth_map[v, u] = depths
gt_depth[j].append(depth_map)
input_dict["gt_depth"] = [np.stack(x) for x in gt_depth]
return input_dict
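# Behavior note (derived from the code above, not from external docs): points are
# sorted far-to-near before being scattered into depth_map, so when several points
# hit the same pixel the closest depth wins; pixels without a lidar hit keep -1.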
@PIPELINES.register_module()
class NuScenesSparse4DAdaptor(object):
    def __init__(self):
        pass
def __call__(self, input_dict):
input_dict["projection_mat"] = np.float32(
np.stack(input_dict["lidar2img"])
)
input_dict["image_wh"] = np.ascontiguousarray(
np.array(input_dict["img_shape"], dtype=np.float32)[:, :2][:, ::-1]
)
input_dict["T_global_inv"] = np.linalg.inv(input_dict["lidar2global"])
input_dict["T_global"] = input_dict["lidar2global"]
if "cam_intrinsic" in input_dict:
input_dict["cam_intrinsic"] = np.float32(
np.stack(input_dict["cam_intrinsic"])
)
input_dict["focal"] = input_dict["cam_intrinsic"][..., 0, 0]
if "instance_inds" in input_dict:
input_dict["instance_id"] = input_dict["instance_inds"]
if "gt_bboxes_3d" in input_dict:
input_dict["gt_bboxes_3d"][:, 6] = self.limit_period(
input_dict["gt_bboxes_3d"][:, 6], offset=0.5, period=2 * np.pi
)
input_dict["gt_bboxes_3d"] = DC(
to_tensor(input_dict["gt_bboxes_3d"]).float()
)
if "gt_labels_3d" in input_dict:
input_dict["gt_labels_3d"] = DC(
to_tensor(input_dict["gt_labels_3d"]).long()
)
imgs = [img.transpose(2, 0, 1) for img in input_dict["img"]]
imgs = np.ascontiguousarray(np.stack(imgs, axis=0))
input_dict["img"] = DC(to_tensor(imgs), stack=True)
for key in [
'gt_map_labels',
'gt_map_pts',
'gt_agent_fut_trajs',
'gt_agent_fut_masks',
]:
if key not in input_dict:
continue
input_dict[key] = DC(to_tensor(input_dict[key]), stack=False, cpu_only=False)
for key in [
'gt_ego_fut_trajs',
'gt_ego_fut_masks',
'gt_ego_fut_cmd',
'ego_status',
]:
if key not in input_dict:
continue
input_dict[key] = DC(to_tensor(input_dict[key]), stack=True, cpu_only=False, pad_dims=None)
return input_dict
def limit_period(
self, val: np.ndarray, offset: float = 0.5, period: float = np.pi
) -> np.ndarray:
limited_val = val - np.floor(val / period + offset) * period
return limited_val
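# Hedged usage sketch (illustrative values only): limit_period with offset=0.5 and
# period=2*pi, as used on the gt_bboxes_3d yaw above, wraps an angle into [-pi, pi):
#   >>> NuScenesSparse4DAdaptor().limit_period(np.array([3.5 * np.pi]), 0.5, 2 * np.pi)
#   array([-1.57079633])   # i.e. -0.5 * pi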
@PIPELINES.register_module()
class InstanceNameFilter(object):
"""Filter GT objects by their names.
Args:
classes (list[str]): List of class names to be kept for training.
"""
def __init__(self, classes):
self.classes = classes
self.labels = list(range(len(self.classes)))
def __call__(self, input_dict):
"""Call function to filter objects by their names.
Args:
input_dict (dict): Result dict from loading pipeline.
Returns:
dict: Results after filtering, 'gt_bboxes_3d', 'gt_labels_3d' \
keys are updated in the result dict.
"""
gt_labels_3d = input_dict["gt_labels_3d"]
gt_bboxes_mask = np.array(
[n in self.labels for n in gt_labels_3d], dtype=np.bool_
)
input_dict["gt_bboxes_3d"] = input_dict["gt_bboxes_3d"][gt_bboxes_mask]
input_dict["gt_labels_3d"] = input_dict["gt_labels_3d"][gt_bboxes_mask]
if "instance_inds" in input_dict:
input_dict["instance_inds"] = input_dict["instance_inds"][gt_bboxes_mask]
if "gt_agent_fut_trajs" in input_dict:
input_dict["gt_agent_fut_trajs"] = input_dict["gt_agent_fut_trajs"][gt_bboxes_mask]
input_dict["gt_agent_fut_masks"] = input_dict["gt_agent_fut_masks"][gt_bboxes_mask]
return input_dict
def __repr__(self):
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__
repr_str += f"(classes={self.classes})"
return repr_str
@PIPELINES.register_module()
class CircleObjectRangeFilter(object):
def __init__(
self, class_dist_thred=[52.5] * 5 + [31.5] + [42] * 3 + [31.5]
):
self.class_dist_thred = class_dist_thred
def __call__(self, input_dict):
gt_bboxes_3d = input_dict["gt_bboxes_3d"]
gt_labels_3d = input_dict["gt_labels_3d"]
dist = np.sqrt(
np.sum(gt_bboxes_3d[:, :2] ** 2, axis=-1)
)
mask = np.array([False] * len(dist))
for label_idx, dist_thred in enumerate(self.class_dist_thred):
mask = np.logical_or(
mask,
np.logical_and(gt_labels_3d == label_idx, dist <= dist_thred),
)
gt_bboxes_3d = gt_bboxes_3d[mask]
gt_labels_3d = gt_labels_3d[mask]
input_dict["gt_bboxes_3d"] = gt_bboxes_3d
input_dict["gt_labels_3d"] = gt_labels_3d
if "instance_inds" in input_dict:
input_dict["instance_inds"] = input_dict["instance_inds"][mask]
if "gt_agent_fut_trajs" in input_dict:
input_dict["gt_agent_fut_trajs"] = input_dict["gt_agent_fut_trajs"][mask]
input_dict["gt_agent_fut_masks"] = input_dict["gt_agent_fut_masks"][mask]
return input_dict
def __repr__(self):
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__
repr_str += f"(class_dist_thred={self.class_dist_thred})"
return repr_str
@PIPELINES.register_module()
class NormalizeMultiviewImage(object):
"""Normalize the image.
Added key is "img_norm_cfg".
Args:
mean (sequence): Mean values of 3 channels.
std (sequence): Std values of 3 channels.
to_rgb (bool): Whether to convert the image from BGR to RGB,
default is true.
"""
def __init__(self, mean, std, to_rgb=True):
self.mean = np.array(mean, dtype=np.float32)
self.std = np.array(std, dtype=np.float32)
self.to_rgb = to_rgb
def __call__(self, results):
"""Call function to normalize images.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Normalized results, 'img_norm_cfg' key is added into
result dict.
"""
results["img"] = [
mmcv.imnormalize(img, self.mean, self.std, self.to_rgb)
for img in results["img"]
]
results["img_norm_cfg"] = dict(
mean=self.mean, std=self.std, to_rgb=self.to_rgb
)
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f"(mean={self.mean}, std={self.std}, to_rgb={self.to_rgb})"
return repr_str
from typing import List, Tuple, Union, Dict
import numpy as np
from shapely.geometry import LineString
from numpy.typing import NDArray
from mmcv.parallel import DataContainer as DC
from mmdet.datasets.builder import PIPELINES
@PIPELINES.register_module(force=True)
class VectorizeMap(object):
"""Generate vectoized map and put into `semantic_mask` key.
Concretely, shapely geometry objects are converted into sample points (ndarray).
We use args `sample_num`, `sample_dist`, `simplify` to specify sampling method.
Args:
roi_size (tuple or list): bev range .
normalize (bool): whether to normalize points to range (0, 1).
coords_dim (int): dimension of point coordinates.
simplify (bool): whether to use simpily function. If true, `sample_num` \
and `sample_dist` will be ignored.
sample_num (int): number of points to interpolate from a polyline. Set to -1 to ignore.
sample_dist (float): interpolate distance. Set to -1 to ignore.
"""
def __init__(self,
roi_size: Union[Tuple, List],
normalize: bool,
coords_dim: int=2,
simplify: bool=False,
sample_num: int=-1,
sample_dist: float=-1,
permute: bool=False
):
self.coords_dim = coords_dim
self.sample_num = sample_num
self.sample_dist = sample_dist
self.roi_size = np.array(roi_size)
self.normalize = normalize
self.simplify = simplify
self.permute = permute
if sample_dist > 0:
assert sample_num < 0 and not simplify
self.sample_fn = self.interp_fixed_dist
elif sample_num > 0:
assert sample_dist < 0 and not simplify
self.sample_fn = self.interp_fixed_num
else:
assert simplify
def interp_fixed_num(self, line: LineString) -> NDArray:
''' Interpolate a line to fixed number of points.
Args:
line (LineString): line
Returns:
points (array): interpolated points, shape (N, 2)
'''
distances = np.linspace(0, line.length, self.sample_num)
sampled_points = np.array([list(line.interpolate(distance).coords)
for distance in distances]).squeeze()
return sampled_points
def interp_fixed_dist(self, line: LineString) -> NDArray:
''' Interpolate a line at fixed interval.
Args:
line (LineString): line
Returns:
points (array): interpolated points, shape (N, 2)
'''
distances = list(np.arange(self.sample_dist, line.length, self.sample_dist))
# make sure to sample at least two points when sample_dist > line.length
distances = [0,] + distances + [line.length,]
sampled_points = np.array([list(line.interpolate(distance).coords)
for distance in distances]).squeeze()
return sampled_points
def get_vectorized_lines(self, map_geoms: Dict) -> Dict:
        ''' Vectorize map elements. Iterate over the input dict and apply the
        specified sample function.
        Args:
            map_geoms (Dict): map geometries grouped by label.
        Returns:
            vectors (Dict): dict of vectorized map elements.
        '''
vectors = {}
for label, geom_list in map_geoms.items():
vectors[label] = []
for geom in geom_list:
if geom.geom_type == 'LineString':
if self.simplify:
line = geom.simplify(0.2, preserve_topology=True)
line = np.array(line.coords)
else:
line = self.sample_fn(geom)
line = line[:, :self.coords_dim]
if self.normalize:
line = self.normalize_line(line)
if self.permute:
line = self.permute_line(line)
vectors[label].append(line)
elif geom.geom_type == 'Polygon':
# polygon objects will not be vectorized
continue
else:
raise ValueError('map geoms must be either LineString or Polygon!')
return vectors
def normalize_line(self, line: NDArray) -> NDArray:
''' Convert points to range (0, 1).
Args:
            line (array): line points, shape (N, coords_dim)
Returns:
normalized (array): normalized points.
'''
origin = -np.array([self.roi_size[0]/2, self.roi_size[1]/2])
line[:, :2] = line[:, :2] - origin
        # map into [0, 1); eps keeps the upper bound strictly below 1
eps = 1e-5
line[:, :2] = line[:, :2] / (self.roi_size + eps)
return line
def permute_line(self, line: np.ndarray, padding=1e5):
'''
(num_pts, 2) -> (num_permute, num_pts, 2)
where num_permute = 2 * (num_pts - 1)
'''
is_closed = np.allclose(line[0], line[-1], atol=1e-3)
num_points = len(line)
permute_num = num_points - 1
permute_lines_list = []
if is_closed:
pts_to_permute = line[:-1, :] # throw away replicate start end pts
for shift_i in range(permute_num):
permute_lines_list.append(np.roll(pts_to_permute, shift_i, axis=0))
flip_pts_to_permute = np.flip(pts_to_permute, axis=0)
for shift_i in range(permute_num):
permute_lines_list.append(np.roll(flip_pts_to_permute, shift_i, axis=0))
else:
permute_lines_list.append(line)
permute_lines_list.append(np.flip(line, axis=0))
permute_lines_array = np.stack(permute_lines_list, axis=0)
if is_closed:
tmp = np.zeros((permute_num * 2, num_points, self.coords_dim))
tmp[:, :-1, :] = permute_lines_array
tmp[:, -1, :] = permute_lines_array[:, 0, :] # add replicate start end pts
permute_lines_array = tmp
else:
# padding
padding = np.full([permute_num * 2 - 2, num_points, self.coords_dim], padding)
permute_lines_array = np.concatenate((permute_lines_array, padding), axis=0)
return permute_lines_array
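    # Shape note (a reading of the code above): a closed line with N points yields
    # 2 * (N - 1) equivalent orderings (all cyclic shifts of the forward and reversed
    # traversals, with the duplicated endpoint re-appended); an open line yields only
    # the forward and reversed orderings, padded with `padding` up to the same size,
    # so the result is always (2 * (N - 1), N, coords_dim).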
def __call__(self, input_dict):
if "map_geoms" not in input_dict:
return input_dict
map_geoms = input_dict['map_geoms']
vectors = self.get_vectorized_lines(map_geoms)
if self.permute:
gt_map_labels, gt_map_pts = [], []
for label, vecs in vectors.items():
for vec in vecs:
gt_map_labels.append(label)
gt_map_pts.append(vec)
input_dict['gt_map_labels'] = np.array(gt_map_labels, dtype=np.int64)
input_dict['gt_map_pts'] = np.array(gt_map_pts, dtype=np.float32).reshape(-1, 2 * (self.sample_num - 1), self.sample_num, self.coords_dim)
else:
input_dict['vectors'] = DC(vectors, stack=False, cpu_only=True)
return input_dict
def __repr__(self):
repr_str = self.__class__.__name__
        repr_str += f'(simplify={self.simplify}, '
        repr_str += f'sample_num={self.sample_num}, '
        repr_str += f'sample_dist={self.sample_dist}, '
        repr_str += f'roi_size={self.roi_size}, '
        repr_str += f'normalize={self.normalize}, '
        repr_str += f'coords_dim={self.coords_dim})'
return repr_str
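# Minimal config sketch for this transform (roi_size and sample_num are placeholder
# values, not taken from any shipped config):
#   dict(type="VectorizeMap", roi_size=(60, 30), normalize=True, coords_dim=2,
#        sample_num=20, permute=True)
# With permute=True the transform writes 'gt_map_labels' / 'gt_map_pts'; otherwise it
# wraps the raw vectors dict in a DataContainer under 'vectors'.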
from .group_sampler import DistributedGroupSampler
from .distributed_sampler import DistributedSampler
from .sampler import SAMPLER, build_sampler
from .group_in_batch_sampler import (
GroupInBatchSampler,
)
import math
import torch
from torch.utils.data import DistributedSampler as _DistributedSampler
from .sampler import SAMPLER
import pdb
import sys
class ForkedPdb(pdb.Pdb):
def interaction(self, *args, **kwargs):
_stdin = sys.stdin
try:
sys.stdin = open("/dev/stdin")
pdb.Pdb.interaction(self, *args, **kwargs)
finally:
sys.stdin = _stdin
def set_trace():
ForkedPdb().set_trace(sys._getframe().f_back)
@SAMPLER.register_module()
class DistributedSampler(_DistributedSampler):
def __init__(
self, dataset=None, num_replicas=None, rank=None, shuffle=True, seed=0
):
super().__init__(
dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle
)
# for the compatibility from PyTorch 1.3+
self.seed = seed if seed is not None else 0
def __iter__(self):
# deterministically shuffle based on epoch
assert not self.shuffle
if "data_infos" in dir(self.dataset):
timestamps = [
x["timestamp"] / 1e6 for x in self.dataset.data_infos
]
vehicle_idx = [
x["lidar_path"].split("/")[-1][:4]
if "lidar_path" in x
else None
for x in self.dataset.data_infos
]
else:
timestamps = [
x["timestamp"] / 1e6
for x in self.dataset.datasets[0].data_infos
] * len(self.dataset.datasets)
vehicle_idx = [
x["lidar_path"].split("/")[-1][:4]
if "lidar_path" in x
else None
for x in self.dataset.datasets[0].data_infos
] * len(self.dataset.datasets)
sequence_splits = []
for i in range(len(timestamps)):
if i == 0 or (
abs(timestamps[i] - timestamps[i - 1]) > 4
or vehicle_idx[i] != vehicle_idx[i - 1]
):
sequence_splits.append([i])
else:
sequence_splits[-1].append(i)
indices = []
        prefix_sum = 0
        split_length = len(self.dataset) // self.num_replicas
        for i in range(len(sequence_splits)):
            if prefix_sum >= (self.rank + 1) * split_length:
                break
            elif prefix_sum >= self.rank * split_length:
                indices.extend(sequence_splits[i])
            prefix_sum += len(sequence_splits[i])
self.num_samples = len(indices)
return iter(indices)
# https://github.com/Divadi/SOLOFusion/blob/main/mmdet3d/datasets/samplers/infinite_group_each_sample_in_batch_sampler.py
import itertools
import copy
import numpy as np
import torch
import torch.distributed as dist
from mmcv.runner import get_dist_info
from torch.utils.data.sampler import Sampler
# https://github.com/open-mmlab/mmdetection/blob/3b72b12fe9b14de906d1363982b9fba05e7d47c1/mmdet/core/utils/dist_utils.py#L157
def sync_random_seed(seed=None, device="cuda"):
"""Make sure different ranks share the same seed.
All workers must call this function, otherwise it will deadlock.
This method is generally used in `DistributedSampler`,
because the seed should be identical across all processes
in the distributed group.
In distributed sampling, different ranks should sample non-overlapped
data in the dataset. Therefore, this function is used to make sure that
each rank shuffles the data indices in the same order based
on the same seed. Then different ranks could use different indices
to select non-overlapped data from the same data list.
Args:
seed (int, Optional): The seed. Default to None.
device (str): The device where the seed will be put on.
Default to 'cuda'.
Returns:
int: Seed to be used.
"""
if seed is None:
seed = np.random.randint(2**31)
assert isinstance(seed, int)
rank, world_size = get_dist_info()
if world_size == 1:
return seed
if rank == 0:
random_num = torch.tensor(seed, dtype=torch.int32, device=device)
else:
random_num = torch.tensor(0, dtype=torch.int32, device=device)
dist.broadcast(random_num, src=0)
return random_num.item()
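# Usage note (a sketch; every rank must reach this call or the broadcast deadlocks):
#   seed = sync_random_seed(None)  # rank 0 draws the seed, all ranks receive the same value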
class GroupInBatchSampler(Sampler):
"""
Pardon this horrendous name. Basically, we want every sample to be from its own group.
If batch size is 4 and # of GPUs is 8, each sample of these 32 should be operating on
its own group.
Shuffling is only done for group order, not done within groups.
"""
def __init__(
self,
dataset,
batch_size=1,
world_size=None,
rank=None,
seed=0,
skip_prob=0.,
sequence_flip_prob=0.,
):
_rank, _world_size = get_dist_info()
if world_size is None:
world_size = _world_size
if rank is None:
rank = _rank
self.dataset = dataset
self.batch_size = batch_size
self.world_size = world_size
self.rank = rank
self.seed = sync_random_seed(seed)
self.size = len(self.dataset)
assert hasattr(self.dataset, "flag")
self.flag = self.dataset.flag
self.group_sizes = np.bincount(self.flag)
self.groups_num = len(self.group_sizes)
self.global_batch_size = batch_size * world_size
assert self.groups_num >= self.global_batch_size
# Now, for efficiency, make a dict group_idx: List[dataset sample_idxs]
self.group_idx_to_sample_idxs = {
group_idx: np.where(self.flag == group_idx)[0].tolist()
for group_idx in range(self.groups_num)
}
# Get a generator per sample idx. Considering samples over all
# GPUs, each sample position has its own generator
self.group_indices_per_global_sample_idx = [
self._group_indices_per_global_sample_idx(
self.rank * self.batch_size + local_sample_idx
)
for local_sample_idx in range(self.batch_size)
]
# Keep track of a buffer of dataset sample idxs for each local sample idx
self.buffer_per_local_sample = [[] for _ in range(self.batch_size)]
self.aug_per_local_sample = [None for _ in range(self.batch_size)]
self.skip_prob = skip_prob
self.sequence_flip_prob = sequence_flip_prob
def _infinite_group_indices(self):
g = torch.Generator()
g.manual_seed(self.seed)
while True:
yield from torch.randperm(self.groups_num, generator=g).tolist()
def _group_indices_per_global_sample_idx(self, global_sample_idx):
yield from itertools.islice(
self._infinite_group_indices(),
global_sample_idx,
None,
self.global_batch_size,
)
def __iter__(self):
while True:
curr_batch = []
for local_sample_idx in range(self.batch_size):
skip = (
np.random.uniform() < self.skip_prob
and len(self.buffer_per_local_sample[local_sample_idx]) > 1
)
if len(self.buffer_per_local_sample[local_sample_idx]) == 0:
# Finished current group, refill with next group
# skip = False
new_group_idx = next(
self.group_indices_per_global_sample_idx[
local_sample_idx
]
)
self.buffer_per_local_sample[
local_sample_idx
] = copy.deepcopy(
self.group_idx_to_sample_idxs[new_group_idx]
)
if np.random.uniform() < self.sequence_flip_prob:
self.buffer_per_local_sample[
local_sample_idx
] = self.buffer_per_local_sample[local_sample_idx][
::-1
]
if self.dataset.keep_consistent_seq_aug:
self.aug_per_local_sample[
local_sample_idx
] = self.dataset.get_augmentation()
if not self.dataset.keep_consistent_seq_aug:
self.aug_per_local_sample[
local_sample_idx
] = self.dataset.get_augmentation()
if skip:
self.buffer_per_local_sample[local_sample_idx].pop(0)
curr_batch.append(
dict(
idx=self.buffer_per_local_sample[local_sample_idx].pop(
0
),
aug_config=self.aug_per_local_sample[local_sample_idx],
)
)
yield curr_batch
def __len__(self):
"""Length of base dataset."""
return self.size
def set_epoch(self, epoch):
self.epoch = epoch
# Copyright (c) OpenMMLab. All rights reserved.
import math
import numpy as np
import torch
from mmcv.runner import get_dist_info
from torch.utils.data import Sampler
from .sampler import SAMPLER
import random
from IPython import embed
@SAMPLER.register_module()
class DistributedGroupSampler(Sampler):
"""Sampler that restricts data loading to a subset of the dataset.
It is especially useful in conjunction with
:class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
process can pass a DistributedSampler instance as a DataLoader sampler,
and load a subset of the original dataset that is exclusive to it.
.. note::
Dataset is assumed to be of constant size.
Arguments:
dataset: Dataset used for sampling.
num_replicas (optional): Number of processes participating in
distributed training.
rank (optional): Rank of the current process within num_replicas.
seed (int, optional): random seed used to shuffle the sampler if
``shuffle=True``. This number should be identical across all
processes in the distributed group. Default: 0.
"""
def __init__(
self, dataset, samples_per_gpu=1, num_replicas=None, rank=None, seed=0
):
_rank, _num_replicas = get_dist_info()
if num_replicas is None:
num_replicas = _num_replicas
if rank is None:
rank = _rank
self.dataset = dataset
self.samples_per_gpu = samples_per_gpu
self.num_replicas = num_replicas
self.rank = rank
self.epoch = 0
self.seed = seed if seed is not None else 0
assert hasattr(self.dataset, "flag")
self.flag = self.dataset.flag
self.group_sizes = np.bincount(self.flag)
self.num_samples = 0
for i, j in enumerate(self.group_sizes):
self.num_samples += (
int(
math.ceil(
self.group_sizes[i]
* 1.0
/ self.samples_per_gpu
/ self.num_replicas
)
)
* self.samples_per_gpu
)
self.total_size = self.num_samples * self.num_replicas
def __iter__(self):
# deterministically shuffle based on epoch
g = torch.Generator()
g.manual_seed(self.epoch + self.seed)
indices = []
for i, size in enumerate(self.group_sizes):
if size > 0:
indice = np.where(self.flag == i)[0]
assert len(indice) == size
# add .numpy() to avoid bug when selecting indice in parrots.
# TODO: check whether torch.randperm() can be replaced by
# numpy.random.permutation().
indice = indice[
list(torch.randperm(int(size), generator=g).numpy())
].tolist()
extra = int(
math.ceil(
size * 1.0 / self.samples_per_gpu / self.num_replicas
)
) * self.samples_per_gpu * self.num_replicas - len(indice)
# pad indice
tmp = indice.copy()
for _ in range(extra // size):
indice.extend(tmp)
indice.extend(tmp[: extra % size])
indices.extend(indice)
assert len(indices) == self.total_size
indices = [
indices[j]
for i in list(
torch.randperm(
len(indices) // self.samples_per_gpu, generator=g
)
)
for j in range(
i * self.samples_per_gpu, (i + 1) * self.samples_per_gpu
)
]
# subsample
offset = self.num_samples * self.rank
indices = indices[offset : offset + self.num_samples]
assert len(indices) == self.num_samples
return iter(indices)
def __len__(self):
return self.num_samples
def set_epoch(self, epoch):
self.epoch = epoch
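# --- Hedged sketch (not part of the original file): the per-group padding arithmetic used
# by DistributedGroupSampler above, worked through with toy numbers. Each group is padded
# up to a multiple of samples_per_gpu * num_replicas so that every rank draws the same
# number of complete, single-group batches.
import math

def _demo_per_rank_samples(group_sizes, samples_per_gpu, num_replicas):
    num_samples = 0
    for size in group_sizes:
        num_samples += int(math.ceil(size / samples_per_gpu / num_replicas)) * samples_per_gpu
    return num_samples

if __name__ == "__main__":
    group_sizes = [7, 5]                     # e.g. two aspect-ratio groups
    samples_per_gpu, num_replicas = 2, 2
    n = _demo_per_rank_samples(group_sizes, samples_per_gpu, num_replicas)
    print(n, n * num_replicas)               # 8 16: each group is padded up to 8 samples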
from mmcv.utils.registry import Registry, build_from_cfg
SAMPLER = Registry("sampler")
def build_sampler(cfg, default_args):
return build_from_cfg(cfg, SAMPLER, default_args)
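# --- Hedged sketch (not part of the original file): how build_sampler resolves a config
# dict through a Registry. A toy registry and sampler stand in for the real SAMPLER
# entries; only the pattern (register_module + build_from_cfg on the 'type' key) is the
# point here.
from mmcv.utils.registry import Registry, build_from_cfg

_TOY_SAMPLER = Registry("toy_sampler")

@_TOY_SAMPLER.register_module()
class _ToySampler:
    def __init__(self, dataset, samples_per_gpu=1):
        self.dataset = dataset
        self.samples_per_gpu = samples_per_gpu

if __name__ == "__main__":
    cfg = dict(type="_ToySampler", samples_per_gpu=4)
    sampler = build_from_cfg(cfg, _TOY_SAMPLER, default_args=dict(dataset=list(range(10))))
    print(type(sampler).__name__, sampler.samples_per_gpu)   # _ToySampler 4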
import copy
import cv2
import numpy as np
import torch
from projects.mmdet3d_plugin.core.box3d import *
def box3d_to_corners(box3d):
if isinstance(box3d, torch.Tensor):
box3d = box3d.detach().cpu().numpy()
corners_norm = np.stack(np.unravel_index(np.arange(8), [2] * 3), axis=1)
corners_norm = corners_norm[[0, 1, 3, 2, 4, 5, 7, 6]]
# use relative origin [0.5, 0.5, 0.5]
corners_norm = corners_norm - np.array([0.5, 0.5, 0.5])
corners = box3d[:, None, [W, L, H]] * corners_norm.reshape([1, 8, 3])
# rotate around z axis
rot_cos = np.cos(box3d[:, YAW])
rot_sin = np.sin(box3d[:, YAW])
rot_mat = np.tile(np.eye(3)[None], (box3d.shape[0], 1, 1))
rot_mat[:, 0, 0] = rot_cos
rot_mat[:, 0, 1] = -rot_sin
rot_mat[:, 1, 0] = rot_sin
rot_mat[:, 1, 1] = rot_cos
corners = (rot_mat[:, None] @ corners[..., None]).squeeze(axis=-1)
corners += box3d[:, None, :3]
return corners
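# --- Hedged sketch (not part of the original file): box3d_to_corners above for a single
# box, with explicit (assumed) field indices. The real W/L/H/YAW constants come from
# projects.mmdet3d_plugin.core.box3d; the [x, y, z, w, l, h, yaw] layout is only an
# illustrative assumption for this sketch.
import numpy as np

def _demo_single_box_corners(box):
    x, y, z, w, l, h, yaw = box                              # assumed layout
    corners = np.stack(np.unravel_index(np.arange(8), [2] * 3), axis=1).astype(float)
    corners = corners[[0, 1, 3, 2, 4, 5, 7, 6]] - 0.5        # unit cube centred at origin
    corners *= np.array([w, l, h])                           # scale to box extents
    rot = np.array([[np.cos(yaw), -np.sin(yaw), 0.0],
                    [np.sin(yaw),  np.cos(yaw), 0.0],
                    [0.0,          0.0,         1.0]])
    return corners @ rot.T + np.array([x, y, z])             # rotate around z, then shift

if __name__ == "__main__":
    print(_demo_single_box_corners([1.0, 2.0, 0.0, 2.0, 4.0, 1.5, np.pi / 2]).shape)  # (8, 3)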
def plot_rect3d_on_img(
img, num_rects, rect_corners, color=(0, 255, 0), thickness=1
):
"""Plot the boundary lines of 3D rectangular on 2D images.
Args:
img (numpy.array): The numpy array of image.
num_rects (int): Number of 3D rectangulars.
rect_corners (numpy.array): Coordinates of the corners of 3D
rectangulars. Should be in the shape of [num_rect, 8, 2].
color (tuple[int], optional): The color to draw bboxes.
Default: (0, 255, 0).
thickness (int, optional): The thickness of bboxes. Default: 1.
"""
line_indices = (
(0, 1),
(0, 3),
(0, 4),
(1, 2),
(1, 5),
(3, 2),
(3, 7),
(4, 5),
(4, 7),
(2, 6),
(5, 6),
(6, 7),
)
h, w = img.shape[:2]
for i in range(num_rects):
corners = np.clip(rect_corners[i], -1e4, 1e5).astype(np.int32)
for start, end in line_indices:
if (
(corners[start, 1] >= h or corners[start, 1] < 0)
or (corners[start, 0] >= w or corners[start, 0] < 0)
) and (
(corners[end, 1] >= h or corners[end, 1] < 0)
or (corners[end, 0] >= w or corners[end, 0] < 0)
):
continue
if isinstance(color[0], int):
cv2.line(
img,
(corners[start, 0], corners[start, 1]),
(corners[end, 0], corners[end, 1]),
color,
thickness,
cv2.LINE_AA,
)
else:
cv2.line(
img,
(corners[start, 0], corners[start, 1]),
(corners[end, 0], corners[end, 1]),
color[i],
thickness,
cv2.LINE_AA,
)
return img.astype(np.uint8)
def draw_lidar_bbox3d_on_img(
bboxes3d, raw_img, lidar2img_rt, img_metas=None, color=(0, 255, 0), thickness=1
):
"""Project the 3D bbox on 2D plane and draw on input image.
Args:
bboxes3d (:obj:`LiDARInstance3DBoxes`):
3d bbox in lidar coordinate system to visualize.
raw_img (numpy.array): The numpy array of image.
lidar2img_rt (numpy.array, shape=[4, 4]): The projection matrix
according to the camera intrinsic parameters.
img_metas (dict): Useless here.
color (tuple[int], optional): The color to draw bboxes.
Default: (0, 255, 0).
thickness (int, optional): The thickness of bboxes. Default: 1.
"""
img = raw_img.copy()
# corners_3d = bboxes3d.corners
corners_3d = box3d_to_corners(bboxes3d)
num_bbox = corners_3d.shape[0]
pts_4d = np.concatenate(
[corners_3d.reshape(-1, 3), np.ones((num_bbox * 8, 1))], axis=-1
)
lidar2img_rt = copy.deepcopy(lidar2img_rt).reshape(4, 4)
if isinstance(lidar2img_rt, torch.Tensor):
lidar2img_rt = lidar2img_rt.cpu().numpy()
pts_2d = pts_4d @ lidar2img_rt.T
pts_2d[:, 2] = np.clip(pts_2d[:, 2], a_min=1e-5, a_max=1e5)
pts_2d[:, 0] /= pts_2d[:, 2]
pts_2d[:, 1] /= pts_2d[:, 2]
imgfov_pts_2d = pts_2d[..., :2].reshape(num_bbox, 8, 2)
return plot_rect3d_on_img(img, num_bbox, imgfov_pts_2d, color, thickness)
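# --- Hedged sketch (not part of the original file): the homogeneous projection step used
# in draw_lidar_bbox3d_on_img, on a few synthetic 3D points with a dummy 4x4 lidar2img
# matrix (identity here, purely to check shapes and the perspective divide).
import numpy as np

if __name__ == "__main__":
    pts_3d = np.array([[1.0, 0.5, 0.2], [2.0, -0.5, 0.1]])
    lidar2img = np.eye(4)                                    # stand-in projection matrix
    pts_4d = np.concatenate([pts_3d, np.ones((len(pts_3d), 1))], axis=-1)
    pts_2d = pts_4d @ lidar2img.T
    pts_2d[:, 2] = np.clip(pts_2d[:, 2], a_min=1e-5, a_max=1e5)
    uv = pts_2d[:, :2] / pts_2d[:, 2:3]                      # perspective divide
    print(uv)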
def draw_points_on_img(points, img, lidar2img_rt, color=(0, 255, 0), circle=4):
img = img.copy()
N = points.shape[0]
points = points.cpu().numpy()
lidar2img_rt = copy.deepcopy(lidar2img_rt).reshape(4, 4)
if isinstance(lidar2img_rt, torch.Tensor):
lidar2img_rt = lidar2img_rt.cpu().numpy()
pts_2d = (
np.sum(points[:, :, None] * lidar2img_rt[:3, :3], axis=-1)
+ lidar2img_rt[:3, 3]
)
pts_2d[..., 2] = np.clip(pts_2d[..., 2], a_min=1e-5, a_max=1e5)
pts_2d = pts_2d[..., :2] / pts_2d[..., 2:3]
pts_2d = np.clip(pts_2d, -1e4, 1e4).astype(np.int32)
for i in range(N):
for point in pts_2d[i]:
if isinstance(color[0], int):
color_tmp = color
else:
color_tmp = color[i]
cv2.circle(img, point.tolist(), circle, color_tmp, thickness=-1)
return img.astype(np.uint8)
def draw_lidar_bbox3d_on_bev(
bboxes_3d, bev_size, bev_range=115, color=(255, 0, 0), thickness=3):
if isinstance(bev_size, (list, tuple)):
bev_h, bev_w = bev_size
else:
bev_h, bev_w = bev_size, bev_size
bev = np.zeros([bev_h, bev_w, 3])
marking_color = (127, 127, 127)
bev_resolution = bev_range / bev_h
for cir in range(int(bev_range / 2 / 10)):
cv2.circle(
bev,
(int(bev_h / 2), int(bev_w / 2)),
int((cir + 1) * 10 / bev_resolution),
marking_color,
thickness=thickness,
)
cv2.line(
bev,
(0, int(bev_h / 2)),
(bev_w, int(bev_h / 2)),
marking_color,
)
cv2.line(
bev,
(int(bev_w / 2), 0),
(int(bev_w / 2), bev_h),
marking_color,
)
if len(bboxes_3d) != 0:
bev_corners = box3d_to_corners(bboxes_3d)[:, [0, 3, 4, 7]][
..., [0, 1]
]
xs = bev_corners[..., 0] / bev_resolution + bev_w / 2
ys = -bev_corners[..., 1] / bev_resolution + bev_h / 2
for obj_idx, (x, y) in enumerate(zip(xs, ys)):
for p1, p2 in ((0, 1), (0, 2), (1, 3), (2, 3)):
if isinstance(color[0], (list, tuple)):
tmp = color[obj_idx]
else:
tmp = color
cv2.line(
bev,
(int(x[p1]), int(y[p1])),
(int(x[p2]), int(y[p2])),
tmp,
thickness=thickness,
)
return bev.astype(np.uint8)
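# --- Hedged sketch (not part of the original file): the metre-to-pixel mapping implied by
# draw_lidar_bbox3d_on_bev, with the same convention (ego at the canvas centre, y flipped).
if __name__ == "__main__":
    bev_h = bev_w = 200
    bev_range = 115.0                     # metres covered by the full canvas height
    bev_resolution = bev_range / bev_h    # metres per pixel

    def world_to_pixel(x, y):
        px = x / bev_resolution + bev_w / 2
        py = -y / bev_resolution + bev_h / 2
        return int(px), int(py)

    print(world_to_pixel(0.0, 0.0))       # (100, 100): the ego sits at the canvas centre
    print(world_to_pixel(10.0, 5.0))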
def draw_lidar_bbox3d(bboxes_3d, imgs, lidar2imgs, color=(255, 0, 0)):
vis_imgs = []
for i, (img, lidar2img) in enumerate(zip(imgs, lidar2imgs)):
vis_imgs.append(
draw_lidar_bbox3d_on_img(bboxes_3d, img, lidar2img, color=color)
)
num_imgs = len(vis_imgs)
if num_imgs < 4 or num_imgs % 2 != 0:
vis_imgs = np.concatenate(vis_imgs, axis=1)
else:
vis_imgs = np.concatenate([
np.concatenate(vis_imgs[:num_imgs//2], axis=1),
np.concatenate(vis_imgs[num_imgs//2:], axis=1)
], axis=0)
bev = draw_lidar_bbox3d_on_bev(bboxes_3d, vis_imgs.shape[0], color=color)
vis_imgs = np.concatenate([bev, vis_imgs], axis=1)
return vis_imgs
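# --- Hedged sketch (not part of the original file): the multi-camera grid layout used by
# draw_lidar_bbox3d above. With an even number (>= 4) of views the per-camera images are
# arranged in two rows, otherwise in a single row.
import numpy as np

if __name__ == "__main__":
    num_imgs, h, w = 6, 4, 5
    views = [np.full((h, w, 3), i, dtype=np.uint8) for i in range(num_imgs)]
    if num_imgs < 4 or num_imgs % 2 != 0:
        grid = np.concatenate(views, axis=1)
    else:
        grid = np.concatenate([
            np.concatenate(views[:num_imgs // 2], axis=1),
            np.concatenate(views[num_imgs // 2:], axis=1),
        ], axis=0)
    print(grid.shape)                      # (8, 15, 3): two rows of three 4x5 views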
from .sparsedrive import SparseDrive
from .sparsedrive_head import SparseDriveHead
from .blocks import (
DeformableFeatureAggregation,
DenseDepthNet,
AsymmetricFFN,
)
from .instance_bank import InstanceBank
from .detection3d import (
SparseBox3DDecoder,
SparseBox3DTarget,
SparseBox3DRefinementModule,
SparseBox3DKeyPointsGenerator,
SparseBox3DEncoder,
)
from .map import *
from .motion import *
__all__ = [
"SparseDrive",
"SparseDriveHead",
"DeformableFeatureAggregation",
"DenseDepthNet",
"AsymmetricFFN",
"InstanceBank",
"SparseBox3DDecoder",
"SparseBox3DTarget",
"SparseBox3DRefinementModule",
"SparseBox3DKeyPointsGenerator",
"SparseBox3DEncoder",
]
import warnings
import math
import torch
import torch.nn as nn
from torch.nn.functional import linear
from torch.nn.init import xavier_uniform_, constant_
from mmcv.utils import deprecated_api_warning
from mmcv.runner import auto_fp16
from mmcv.runner.base_module import BaseModule
from mmcv.cnn.bricks.drop import build_dropout
from mmcv.cnn.bricks.registry import ATTENTION
import torch.utils.checkpoint as cp
from einops import rearrange
try:
from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func
print('Use flash_attn_unpadded_kvpacked_func')
except ImportError:  # newer flash-attn versions renamed the unpadded interface to *_varlen_*
from flash_attn.flash_attn_interface import flash_attn_varlen_kvpacked_func as flash_attn_unpadded_kvpacked_func
print('Use flash_attn_varlen_kvpacked_func')
from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis
def _in_projection_packed(q, k, v, w, b = None):
w_q, w_k, w_v = w.chunk(3)
if b is None:
b_q = b_k = b_v = None
else:
b_q, b_k, b_v = b.chunk(3)
return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
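# --- Hedged sketch (not part of the original file): checks that _in_projection_packed is
# equivalent to applying a separate linear map whose weight is the first row-wise third of
# the packed (3*E, E) weight, which is how FlashMHA below stores its in-projection.
import torch

if __name__ == "__main__":
    E, B, S = 8, 2, 3
    w = torch.randn(3 * E, E)
    b = torch.randn(3 * E)
    q = torch.randn(B, S, E)
    q_proj, _, _ = _in_projection_packed(q, q, q, w, b)
    w_q, _, _ = w.chunk(3)
    b_q, _, _ = b.chunk(3)
    print(q_proj.shape, torch.allclose(q_proj, q @ w_q.T + b_q))  # torch.Size([2, 3, 8]) True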
class FlashAttention(nn.Module):
"""Implement the scaled dot product attention with softmax.
Arguments
---------
softmax_scale: The temperature to use for the softmax attention.
(default: 1/sqrt(d_keys) where d_keys is computed at
runtime)
attention_dropout: The dropout rate to apply to the attention
(default: 0.1)
"""
def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
super().__init__()
self.softmax_scale = softmax_scale
self.dropout_p = attention_dropout
self.fp16_enabled = True
@auto_fp16(apply_to=('q', 'kv'), out_fp32=True)
def forward(self, q, kv,
causal=False,
key_padding_mask=None):
"""Implements the multihead softmax attention.
Arguments
---------
q: The tensor containing the query. (B, T, H, D)
kv: The tensor containing the key, and value. (B, S, 2, H, D)
key_padding_mask: a bool tensor of shape (B, S)
"""
assert q.dtype in [torch.float16, torch.bfloat16] and kv.dtype in [torch.float16, torch.bfloat16]
assert q.is_cuda and kv.is_cuda
assert q.shape[0] == kv.shape[0] and q.shape[-2] == kv.shape[-2] and q.shape[-1] == kv.shape[-1]
batch_size = q.shape[0]
seqlen_q, seqlen_k = q.shape[1], kv.shape[1]
if key_padding_mask is None:
q, kv = rearrange(q, 'b s ... -> (b s) ...'), rearrange(kv, 'b s ... -> (b s) ...')
max_sq, max_sk = seqlen_q, seqlen_k
cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32,
device=q.device)
cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32,
device=kv.device)
output = flash_attn_unpadded_kvpacked_func(
q, kv, cu_seqlens_q, cu_seqlens_k, max_sq, max_sk,
self.dropout_p if self.training else 0.0,
softmax_scale=self.softmax_scale, causal=causal
)
output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
else:
nheads = kv.shape[-2]
q = rearrange(q, 'b s ... -> (b s) ...')
max_sq = seqlen_q
cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32,
device=q.device)
x = rearrange(kv, 'b s two h d -> b s (two h d)')
x_unpad, indices, cu_seqlens_k, max_sk = unpad_input(x, key_padding_mask)
x_unpad = rearrange(x_unpad, 'nnz (two h d) -> nnz two h d', two=2, h=nheads)
output_unpad = flash_attn_unpadded_kvpacked_func(
q, x_unpad, cu_seqlens_q, cu_seqlens_k, max_sq, max_sk,
self.dropout_p if self.training else 0.0,
softmax_scale=self.softmax_scale, causal=causal
)
output = rearrange(output_unpad, '(b s) ... -> b s ...', b=batch_size)
return output, None
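# --- Hedged sketch (not part of the original file): the cumulative-sequence-length tensors
# expected by the unpadded flash-attention interface, shown for the no-padding branch of
# FlashAttention.forward above (batch of 2, query length 3, key length 5). q is flattened
# to (b*s, h, d) and kv to (b*s, 2, h, d) before the kernel call.
import torch

if __name__ == "__main__":
    batch_size, seqlen_q, seqlen_k = 2, 3, 5
    cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32)
    cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32)
    print(cu_seqlens_q.tolist(), cu_seqlens_k.tolist())   # [0, 3, 6] [0, 5, 10]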
class FlashMHA(nn.Module):
def __init__(self, embed_dim, num_heads, bias=True, batch_first=True, attention_dropout=0.0,
causal=False, device=None, dtype=None, **kwargs) -> None:
assert batch_first
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.embed_dim = embed_dim
self.causal = causal
self.bias = bias
self.num_heads = num_heads
assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
self.head_dim = self.embed_dim // num_heads
assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
self.in_proj_weight = nn.Parameter(torch.empty((3 * embed_dim, embed_dim)))
if bias:
self.in_proj_bias = nn.Parameter(torch.empty(3 * embed_dim))
else:
self.register_parameter('in_proj_bias', None)
self.inner_attn = FlashAttention(attention_dropout=attention_dropout, **factory_kwargs)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self._reset_parameters()
def _reset_parameters(self) -> None:
xavier_uniform_(self.in_proj_weight)
if self.in_proj_bias is not None:
constant_(self.in_proj_bias, 0.)
constant_(self.out_proj.bias, 0.)
def forward(self, q, k, v, key_padding_mask=None):
"""x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim)
key_padding_mask: bool tensor of shape (batch, seqlen)
"""
q, k, v = _in_projection_packed(q, k, v, self.in_proj_weight, self.in_proj_bias)
q = rearrange(q, 'b s (h d) -> b s h d', h=self.num_heads)
k = rearrange(k, 'b s (h d) -> b s h d', h=self.num_heads)
v = rearrange(v, 'b s (h d) -> b s h d', h=self.num_heads)
kv = torch.stack([k, v], dim=2)
context, attn_weights = self.inner_attn(q, kv, key_padding_mask=key_padding_mask, causal=self.causal)
return self.out_proj(rearrange(context, 'b s h d -> b s (h d)')), attn_weights
@ATTENTION.register_module()
class MultiheadFlashAttention(BaseModule):
"""A wrapper for ``torch.nn.MultiheadAttention``.
This module implements MultiheadAttention with identity connection,
and positional encoding is also passed as input.
Args:
embed_dims (int): The embedding dimension.
num_heads (int): Parallel attention heads.
attn_drop (float): A Dropout layer on attn_output_weights.
Default: 0.0.
proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
Default: 0.0.
dropout_layer (obj:`ConfigDict`): The dropout_layer used
when adding the shortcut.
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
batch_first (bool): When it is True, Key, Query and Value are shape of
(batch, n, embed_dim), otherwise (n, batch, embed_dim).
Default to False.
"""
def __init__(self,
embed_dims,
num_heads,
attn_drop=0.,
proj_drop=0.,
dropout_layer=dict(type='Dropout', drop_prob=0.),
init_cfg=None,
batch_first=True,
**kwargs):
super(MultiheadFlashAttention, self).__init__(init_cfg)
if 'dropout' in kwargs:
warnings.warn(
'The arguments `dropout` in MultiheadAttention '
'has been deprecated, now you can separately '
'set `attn_drop`(float), proj_drop(float), '
'and `dropout_layer`(dict) ', DeprecationWarning)
attn_drop = kwargs['dropout']
dropout_layer['drop_prob'] = kwargs.pop('dropout')
self.embed_dims = embed_dims
self.num_heads = num_heads
self.batch_first = True
self.attn = FlashMHA(
embed_dim=embed_dims,
num_heads=num_heads,
attention_dropout=attn_drop,
dtype=torch.float16,
device='cuda',
**kwargs
)
self.proj_drop = nn.Dropout(proj_drop)
self.dropout_layer = build_dropout(
dropout_layer) if dropout_layer else nn.Identity()
@deprecated_api_warning({'residual': 'identity'},
cls_name='MultiheadAttention')
def forward(self,
query,
key=None,
value=None,
identity=None,
query_pos=None,
key_pos=None,
attn_mask=None,
key_padding_mask=None,
**kwargs):
"""Forward function for `MultiheadAttention`.
**kwargs allow passing a more general data flow when combining
with other operations in `transformerlayer`.
Args:
query (Tensor): The input query with shape [num_queries, bs,
embed_dims] if self.batch_first is False, else
[bs, num_queries, embed_dims].
key (Tensor): The key tensor with shape [num_keys, bs,
embed_dims] if self.batch_first is False, else
[bs, num_keys, embed_dims] .
If None, the ``query`` will be used. Defaults to None.
value (Tensor): The value tensor with same shape as `key`.
Same in `nn.MultiheadAttention.forward`. Defaults to None.
If None, the `key` will be used.
identity (Tensor): This tensor, with the same shape as x,
will be used for the identity link.
If None, `x` will be used. Defaults to None.
query_pos (Tensor): The positional encoding for query, with
the same shape as `x`. If not None, it will
be added to `x` before forward function. Defaults to None.
key_pos (Tensor): The positional encoding for `key`, with the
same shape as `key`. Defaults to None. If not None, it will
be added to `key` before forward function. If None, and
`query_pos` has the same shape as `key`, then `query_pos`
will be used for `key_pos`. Defaults to None.
attn_mask (Tensor): ByteTensor mask with shape [num_queries,
num_keys]. Same in `nn.MultiheadAttention.forward`.
Defaults to None.
key_padding_mask (Tensor): ByteTensor with shape [bs, num_keys].
Defaults to None.
Returns:
Tensor: forwarded results with shape
[num_queries, bs, embed_dims]
if self.batch_first is False, else
[bs, num_queries, embed_dims].
"""
assert attn_mask is None, 'attn mask not supported now.'
if key is None:
key = query
if value is None:
value = key
if identity is None:
identity = query
if key_pos is None:
if query_pos is not None:
# use query_pos if key_pos is not available
if query_pos.shape == key.shape:
key_pos = query_pos
else:
warnings.warn(f'position encoding of key is '
f'missing in {self.__class__.__name__}.')
if query_pos is not None:
query = query + query_pos
if key_pos is not None:
key = key + key_pos
# The dataflow('key', 'query', 'value') of ``FlashAttention`` is (batch, num_query, embed_dims).
if not self.batch_first:
query = query.transpose(0, 1)
key = key.transpose(0, 1)
value = value.transpose(0, 1)
out = self.attn(
q=query,
k=key,
v=value,
key_padding_mask=key_padding_mask)[0]
if not self.batch_first:
out = out.transpose(0, 1)
return identity + self.dropout_layer(self.proj_drop(out))
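# --- Hedged sketch (not part of the original file): constructing the attention module
# above from a config dict via mmcv's ATTENTION registry, the way a transformer-layer
# config would reference it. This assumes the flash-attn package is importable and that
# this module has been imported so the class is registered; actually running forward()
# additionally needs a CUDA device and fp16/bf16 inputs, so only construction is shown.
from mmcv.utils import build_from_cfg as _demo_build_from_cfg

if __name__ == "__main__":
    attn_cfg = dict(
        type="MultiheadFlashAttention",
        embed_dims=256,
        num_heads=8,
        attn_drop=0.0,
        proj_drop=0.0,
    )
    attn = _demo_build_from_cfg(attn_cfg, ATTENTION)
    print(type(attn).__name__, attn.embed_dims, attn.num_heads)   # MultiheadFlashAttention 256 8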
def gen_sineembed_for_position(pos_tensor, hidden_dim=256):
"""Mostly copy-paste from https://github.com/IDEA-opensource/DAB-DETR/
"""
half_hidden_dim = hidden_dim // 2
scale = 2 * math.pi
dim_t = torch.arange(half_hidden_dim, dtype=torch.float32, device=pos_tensor.device)
dim_t = 10000 ** (2 * (dim_t // 2) / half_hidden_dim)
x_embed = pos_tensor[..., 0] * scale
y_embed = pos_tensor[..., 1] * scale
pos_x = x_embed[..., None] / dim_t
pos_y = y_embed[..., None] / dim_t
pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=-1).flatten(-2)
pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=-1).flatten(-2)
pos = torch.cat((pos_y, pos_x), dim=-1)
return pos
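# --- Hedged sketch (not part of the original file): shape check for
# gen_sineembed_for_position. An (..., 2) tensor of normalised (x, y) positions maps to an
# (..., hidden_dim) embedding, with the y-frequencies placed before the x-frequencies.
import torch

if __name__ == "__main__":
    pos = torch.rand(2, 5, 2)                             # (bs, num_queries, 2) in [0, 1]
    emb = gen_sineembed_for_position(pos, hidden_dim=256)
    print(emb.shape)                                      # torch.Size([2, 5, 256])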