Commit f3b13cad authored by yeshenglong1's avatar yeshenglong1
Browse files

Update README.md

parent 0797920d
import mmcv import mmcv
import numpy as np import numpy as np
from mmdet.datasets.builder import PIPELINES from mmdet.datasets.builder import PIPELINES
@PIPELINES.register_module(force=True)
class LoadMultiViewImagesFromFiles(object):
    """Load multi channel images from a list of separate channel files.

    Expects ``results['img_filenames']`` to be a list of filenames.
    (The original docstring said ``img_filename``; the code reads
    ``img_filenames``.)

    Args:
        to_float32 (bool): Whether to convert the img to float32.
            Defaults to False.
        color_type (str): Color type of the file. Defaults to 'unchanged'.
    """

    def __init__(self, to_float32=False, color_type='unchanged'):
        self.to_float32 = to_float32
        self.color_type = color_type

    def __call__(self, results):
        """Call function to load multi-view image from files.

        Args:
            results (dict): Result dict containing multi-view image
                filenames under the key 'img_filenames'.

        Returns:
            dict: The result dict containing the multi-view image data.
                Added keys and values are described below.

                - img (list[np.ndarray]): Multi-view image arrays.
                - img_shape (list[tuple[int]]): Shapes of image arrays.
                - ori_shape (list[tuple[int]]): Shapes of original arrays.
                - pad_shape (list[tuple[int]]): Shapes of padded arrays.
                - img_norm_cfg (dict): Normalization configuration (identity
                  mean/std; actual normalization happens later in the
                  pipeline).
        """
        filename = results['img_filenames']
        img = [mmcv.imread(name, self.color_type) for name in filename]
        if self.to_float32:
            img = [i.astype(np.float32) for i in img]
        results['img'] = img
        results['img_shape'] = [i.shape for i in img]
        results['ori_shape'] = [i.shape for i in img]
        # Set initial values for default meta_keys
        results['pad_shape'] = [i.shape for i in img]
        # results['scale_factor'] = 1.0
        num_channels = 1 if len(img[0].shape) < 3 else img[0].shape[2]
        results['img_norm_cfg'] = dict(
            mean=np.zeros(num_channels, dtype=np.float32),
            std=np.ones(num_channels, dtype=np.float32),
            to_rgb=False)
        results['img_fields'] = ['img']
        return results

    def __repr__(self):
        """str: Return a string that describes the module."""
        return f'{self.__class__.__name__} (to_float32={self.to_float32}, '\
            f"color_type='{self.color_type}')"
import numpy as np import numpy as np
from mmdet.datasets.builder import PIPELINES from mmdet.datasets.builder import PIPELINES
from shapely.geometry import LineString from shapely.geometry import LineString
@PIPELINES.register_module(force=True)
class PolygonizeLocalMapBbox(object):
    """Pre-Processing used by vectormapnet model.

    Converts vectorized map elements into quantized token sequences
    (polylines) plus keypoint bounding boxes, stored under 'polys'.

    Args:
        canvas_size (tuple or list): bev feature size
        coord_dim (int): dimension of point's coordinate
        num_class (int): number of classes
        threshold (float): threshold for minimum bounding box size
    """

    def __init__(self,
                 canvas_size=(200, 100),
                 coord_dim=2,
                 num_class=3,
                 threshold=6/200,
                 ):
        self.canvas_size = np.array(canvas_size)
        self.num_class = num_class
        # for keypoints
        self.threshold = threshold
        self.coord_dim = coord_dim
        # token 0 is reserved as the stop (EOS) symbol
        self.map_stop_idx = 0
        # coordinate tokens are shifted up by 1 to leave room for EOS
        self.coord_dim_start_idx = 1

    def format_polyline_map(self, vectors):
        """Quantize, flatten and EOS-pad each polyline; also build
        per-token masks and loss weights.

        Args:
            vectors (dict): label -> list of polylines, each (npts, 2).

        Returns:
            tuple: (polylines, masks, weights) — parallel lists of 1D
            arrays of length npts * coord_dim + 1.
        """
        polylines, polyline_masks, polyline_weights = [], [], []
        # quantilize each label's lines individually.
        for label, _lines in vectors.items():
            for polyline in _lines:
                if label == 2:
                    # boundary-like class: weight tokens by local segment
                    # length (evaluate_line already appends the EOS weight
                    # and normalizes)
                    polyline_weight = evaluate_line(polyline).reshape(-1)
                else:
                    polyline_weight = np.ones_like(polyline).reshape(-1)
                    polyline_weight = np.pad(
                        polyline_weight, ((0, 1),), constant_values=1.)
                    polyline_weight = polyline_weight/polyline_weight.sum()
                # flatten and quantilized
                fpolyline = quantize_verts(
                    polyline, self.canvas_size, self.coord_dim)
                fpolyline = fpolyline.reshape(-1)
                # reindex starting from 1, and add a zero stopping token(EOS)
                fpolyline = \
                    np.pad(fpolyline + self.coord_dim_start_idx, ((0, 1),),
                           constant_values=0)
                # BUGFIX: np.bool was removed in NumPy 1.24; use builtin bool
                fpolyline_msk = np.ones(fpolyline.shape, dtype=bool)
                polyline_masks.append(fpolyline_msk)
                polyline_weights.append(polyline_weight)
                polylines.append(fpolyline)
        polyline_map = polylines
        polyline_map_mask = polyline_masks
        polyline_map_weights = polyline_weights
        return polyline_map, polyline_map_mask, polyline_map_weights

    def format_keypoint(self, vectors):
        """Build per-element bounding-box keypoints in both float (canvas
        scale) and quantized token form.

        Args:
            vectors (dict): label -> list of polylines, each (npts, 2).

        Returns:
            tuple: (kps, kp_labels, qkps, qkp_msks) where kps is
            (nbox, 4) float32, kp_labels (nbox,), qkps/qkp_msks are the
            quantized tokens and their masks.
        """
        kps, kp_labels = [], []
        qkps, qkp_masks = [], []
        # quantilize each label's lines individually.
        for label, _lines in vectors.items():
            for polyline in _lines:
                kp = get_bbox(polyline, self.threshold)
                kps.append(kp)
                kp_labels.append(label)
                gkp = kp
                # flatten and quantilized
                fkp = quantize_verts(gkp, self.canvas_size, self.coord_dim)
                fkp = fkp.reshape(-1)
                # BUGFIX: np.bool was removed in NumPy 1.24; use builtin bool
                fkps_msk = np.ones(fkp.shape, dtype=bool)
                qkp_masks.append(fkps_msk)
                qkps.append(fkp)
        qkps = np.stack(qkps)
        qkp_msks = np.stack(qkp_masks)
        # format det
        kps = np.stack(kps, axis=0).astype(np.float32)*self.canvas_size
        kp_labels = np.array(kp_labels)
        # restrict the boundary (keep strictly inside the canvas)
        kps[..., 0] = np.clip(kps[..., 0], 0.1, self.canvas_size[0]-0.1)
        kps[..., 1] = np.clip(kps[..., 1], 0.1, self.canvas_size[1]-0.1)
        # nbox, boxsize(4)*coord_dim(2)
        kps = kps.reshape(kps.shape[0], -1)
        return kps, kp_labels, qkps, qkp_msks

    def Polygonization(self, input_dict):
        """Process vertices: build quantized polylines and keypoint boxes
        and store them under input_dict['polys'].
        """
        vectors = input_dict['vectors']
        n_lines = 0
        for label, lines in vectors.items():
            n_lines += len(lines)
        # empty scene: nothing to polygonize
        if not n_lines:
            input_dict['polys'] = []
            return input_dict
        polyline_map, polyline_map_mask, polyline_map_weight = \
            self.format_polyline_map(vectors)
        keypoint, keypoint_label, qkeypoint, qkeypoint_mask = \
            self.format_keypoint(vectors)
        # gather
        polys = {
            # for det
            'keypoint': keypoint,
            'det_label': keypoint_label,
            # for gen
            'gen_label': keypoint_label,
            'qkeypoint': qkeypoint,
            'qkeypoint_mask': qkeypoint_mask,
            'polylines': polyline_map,  # List[array]
            'polyline_masks': polyline_map_mask,  # List[array]
            'polyline_weights': polyline_map_weight
        }
        # Format outputs
        input_dict['polys'] = polys
        return input_dict

    def __call__(self, input_dict):
        input_dict = self.Polygonization(input_dict)
        return input_dict
def evaluate_line(polyline):
    """Compute per-token weights for a polyline.

    Each point is weighted by the length of its adjacent segments
    (endpoints use their single neighboring segment), normalized to sum
    to 1, then split across the point's two coordinate tokens, with one
    extra weight appended for the stop (EOS) token.

    Args:
        polyline (array): point coordinates, shape (npts, 2)

    Returns:
        array: token weights, shape (npts * 2 + 1,)
    """
    seg_len = np.linalg.norm(np.diff(polyline, axis=0), axis=-1)
    endpoint_w = seg_len[[0, -1]].copy()
    interior_w = 0.5 * (seg_len[:-1] + seg_len[1:])
    weights = np.concatenate((endpoint_w[:1], interior_w, endpoint_w[-1:]))
    total = weights.sum()
    if total == 0:
        # degenerate polyline: avoid division by zero
        total = 1
    weights = weights / total
    # each point emits two coordinate tokens; split its weight between them
    weights = np.repeat(weights, 2) / 2
    # append the weight for the stop (EOS) token
    weights = np.pad(weights, ((0, 1)),
                     constant_values=1 / (len(polyline) * 2))
    return weights
def quantize_verts(verts, canvas_size, coord_dim):
    """Quantize vertices from the normalized range [0, 1] to integer
    pixel coordinates in [0, canvas_size - 1].

    (The original docstring claimed an input range of [-1, 1] and output
    range [0, n_bits**2 - 1]; the code maps [0, 1] onto the canvas.)

    Args:
        verts (array): vertices coordinates, shape (seqlen, coords_dim)
        canvas_size (tuple): bev feature size
        coord_dim (int): dimension of point coordinates

    Returns:
        array: quantized vertices (int32), shape (seqlen, coord_dim)
    """
    min_range = 0
    max_range = 1
    # e.g. canvas of 200 bins -> integer values 0..199
    range_quantize = np.array(canvas_size) - 1
    verts_ratio = (verts[:, :coord_dim] - min_range) / (
        max_range - min_range)
    verts_quantize = verts_ratio * range_quantize[:coord_dim]
    return verts_quantize.astype('int32')
def get_bbox(polyline, threshold):
    """Compute an axis-aligned bounding box for a polyline, inflating
    boxes thinner than `threshold` along either axis.

    (The original docstring was wrongly copy-pasted from
    ``quantize_verts``.)

    Args:
        polyline (array): point coordinates, shape (seqlen, 2);
            assumed normalized to [0, 1] — the result is clipped to
            that range.
        threshold (float): minimum bbox extent along each axis

    Returns:
        array: bounding box corners [[minx, miny], [maxx, maxy]],
            clipped to [0, 1], shape (2, 2)
    """
    eps = 1e-4
    polyline = LineString(polyline)
    bbox = polyline.bounds
    minx, miny, maxx, maxy = bbox
    W, H = maxx-minx, maxy-miny
    if W < threshold or H < threshold:
        # inflate the geometry so the envelope has at least `threshold`
        # extent along the thin axis (eps guards a zero-size buffer)
        remain = max((threshold - min(W, H))/2, eps)
        bbox = polyline.buffer(remain).envelope.bounds
        minx, miny, maxx, maxy = bbox
    bbox_np = np.array([[minx, miny], [maxx, maxy]])
    bbox_np = np.clip(bbox_np, 0., 1.)
    return bbox_np
\ No newline at end of file
import numpy as np import numpy as np
import mmcv import mmcv
from mmdet.datasets.builder import PIPELINES from mmdet.datasets.builder import PIPELINES
@PIPELINES.register_module(force=True)
class Normalize3D(object):
    """Normalize each image in every 'img_fields' entry.

    Added key is "img_norm_cfg".

    Args:
        mean (sequence): Mean values of 3 channels.
        std (sequence): Std values of 3 channels.
        to_rgb (bool): Whether to convert the image from BGR to RGB,
            default is true.
    """

    def __init__(self, mean, std, to_rgb=True):
        self.mean = np.array(mean, dtype=np.float32)
        self.std = np.array(std, dtype=np.float32)
        self.to_rgb = to_rgb

    def __call__(self, results):
        """Normalize every image list referenced by 'img_fields'.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Normalized results; 'img_norm_cfg' key is added into
                the result dict.
        """
        for field in results.get('img_fields', ['img']):
            normalized = []
            for image in results[field]:
                normalized.append(
                    mmcv.imnormalize(image, self.mean, self.std, self.to_rgb))
            results[field] = normalized
        results['img_norm_cfg'] = dict(
            mean=self.mean, std=self.std, to_rgb=self.to_rgb)
        return results

    def __repr__(self):
        return (f'{self.__class__.__name__}'
                f'(mean={self.mean}, std={self.std}, to_rgb={self.to_rgb})')
@PIPELINES.register_module(force=True)
class PadMultiViewImages(object):
    """Pad multi-view images and change intrinsics.

    There are two padding modes: (1) pad to a fixed size and (2) pad to
    the minimum size that is divisible by some number.
    Added keys are "img_shape", "img_fixed_size", "img_size_divisor".
    If set `change_intrinsics=True`, keys 'cam_intrinsics' and 'ego2img'
    will be changed.

    Args:
        size (tuple, optional): Fixed padding size, (h, w).
        size_divisor (int, optional): The divisor of padded size.
        pad_val (float, optional): Padding value, 0 by default.
        change_intrinsics (bool): whether to update intrinsics.
    """

    def __init__(self, size=None, size_divisor=None, pad_val=0,
                 change_intrinsics=False):
        self.size = size
        self.size_divisor = size_divisor
        self.pad_val = pad_val
        # only one of size and size_divisor should be valid
        assert size is not None or size_divisor is not None
        assert size is None or size_divisor is None
        self.change_intrinsics = change_intrinsics

    def _pad_img(self, results):
        """Pad images according to ``self.size`` or ``self.size_divisor``."""
        original_shape = [img.shape for img in results['img']]
        for key in results.get('img_fields', ['img']):
            if self.size is not None:
                padded_img = [mmcv.impad(
                    img, shape=self.size, pad_val=self.pad_val)
                    for img in results[key]]
            elif self.size_divisor is not None:
                padded_img = [mmcv.impad_to_multiple(
                    img, self.size_divisor, pad_val=self.pad_val)
                    for img in results[key]]
            results[key] = padded_img
        if self.change_intrinsics:
            post_intrinsics, post_ego2imgs = [], []
            for img, oshape, cam_intrinsic, ego2img in zip(
                    results['img'], original_shape,
                    results['cam_intrinsics'], results['ego2img']):
                # NOTE(review): padding does not rescale pixels, yet the
                # intrinsics are scaled by the shape ratio — looks like a
                # resize-style update; confirm this is intended for padding.
                scaleW = img.shape[1] / oshape[1]
                scaleH = img.shape[0] / oshape[0]
                rot_resize_matrix = np.array([
                    [scaleW, 0, 0, 0],
                    [0, scaleH, 0, 0],
                    [0, 0, 1, 0],
                    [0, 0, 0, 1]])
                post_intrinsic = rot_resize_matrix[:3, :3] @ cam_intrinsic
                post_ego2img = rot_resize_matrix @ ego2img
                post_intrinsics.append(post_intrinsic)
                post_ego2imgs.append(post_ego2img)
            results.update({
                'cam_intrinsics': post_intrinsics,
                'ego2img': post_ego2imgs,
            })
        results['img_shape'] = [img.shape for img in padded_img]
        results['img_fixed_size'] = self.size
        results['img_size_divisor'] = self.size_divisor

    def __call__(self, results):
        """Call function to pad images.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Updated result dict.
        """
        self._pad_img(results)
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(size={self.size}, '
        repr_str += f'size_divisor={self.size_divisor}, '
        # BUGFIX: the original closed the paren after pad_val and omitted
        # the separator, yielding "pad_val=0)change_intrinsics=False)"
        repr_str += f'pad_val={self.pad_val}, '
        repr_str += f'change_intrinsics={self.change_intrinsics})'
        return repr_str
@PIPELINES.register_module(force=True)
class ResizeMultiViewImages(object):
    """Resize multi-view images and optionally update camera intrinsics.

    If set `change_intrinsics=True`, keys 'cam_intrinsics' and 'ego2img'
    will be changed to match the resized images.

    Args:
        size (tuple, optional): resize target size, (h, w).
        change_intrinsics (bool): whether to update intrinsics.
    """

    def __init__(self, size, change_intrinsics=True):
        self.size = size
        self.change_intrinsics = change_intrinsics

    def __call__(self, results: dict):
        resized_imgs, new_intrinsics, new_ego2imgs = [], [], []
        views = zip(results['img'],
                    results['cam_intrinsics'],
                    results['ego2img'])
        for image, cam_intrinsic, ego2img in views:
            # NOTE: mmcv.imresize expect (w, h) shape
            resized, w_scale, h_scale = mmcv.imresize(
                image, (self.size[1], self.size[0]), return_scale=True)
            resized_imgs.append(resized)
            # homogeneous scaling matrix for the (x, y) image axes
            scale_mat = np.array([
                [w_scale, 0, 0, 0],
                [0, h_scale, 0, 0],
                [0, 0, 1, 0],
                [0, 0, 0, 1]])
            new_intrinsics.append(scale_mat[:3, :3] @ cam_intrinsic)
            new_ego2imgs.append(scale_mat @ ego2img)
        results['img'] = resized_imgs
        results['img_shape'] = [im.shape for im in resized_imgs]
        if self.change_intrinsics:
            results.update({
                'cam_intrinsics': new_intrinsics,
                'ego2img': new_ego2imgs,
            })
        return results

    def __repr__(self):
        return (f'{self.__class__.__name__}(size={self.size}, '
                f'change_intrinsics={self.change_intrinsics})')
\ No newline at end of file
import numpy as np import numpy as np
from mmdet.datasets.builder import PIPELINES from mmdet.datasets.builder import PIPELINES
from shapely.geometry import LineString from shapely.geometry import LineString
from numpy.typing import NDArray from numpy.typing import NDArray
from typing import List, Tuple, Union, Dict from typing import List, Tuple, Union, Dict
@PIPELINES.register_module(force=True) @PIPELINES.register_module(force=True)
class VectorizeMap(object): class VectorizeMap(object):
"""Generate vectoized map and put into `semantic_mask` key. """Generate vectoized map and put into `semantic_mask` key.
Concretely, shapely geometry objects are converted into sample points (ndarray). Concretely, shapely geometry objects are converted into sample points (ndarray).
We use args `sample_num`, `sample_dist`, `simplify` to specify sampling method. We use args `sample_num`, `sample_dist`, `simplify` to specify sampling method.
Args: Args:
roi_size (tuple or list): bev range . roi_size (tuple or list): bev range .
normalize (bool): whether to normalize points to range (0, 1). normalize (bool): whether to normalize points to range (0, 1).
coords_dim (int): dimension of point coordinates. coords_dim (int): dimension of point coordinates.
simplify (bool): whether to use simpily function. If true, `sample_num` \ simplify (bool): whether to use simpily function. If true, `sample_num` \
and `sample_dist` will be ignored. and `sample_dist` will be ignored.
sample_num (int): number of points to interpolate from a polyline. Set to -1 to ignore. sample_num (int): number of points to interpolate from a polyline. Set to -1 to ignore.
sample_dist (float): interpolate distance. Set to -1 to ignore. sample_dist (float): interpolate distance. Set to -1 to ignore.
""" """
def __init__(self, def __init__(self,
roi_size: Union[Tuple, List], roi_size: Union[Tuple, List],
normalize: bool, normalize: bool,
coords_dim: int, coords_dim: int,
simplify: bool=False, simplify: bool=False,
sample_num: int=-1, sample_num: int=-1,
sample_dist: float=-1, sample_dist: float=-1,
): ):
self.coords_dim = coords_dim self.coords_dim = coords_dim
self.sample_num = sample_num self.sample_num = sample_num
self.sample_dist = sample_dist self.sample_dist = sample_dist
self.roi_size = np.array(roi_size) self.roi_size = np.array(roi_size)
self.normalize = normalize self.normalize = normalize
self.simplify = simplify self.simplify = simplify
self.sample_fn = None self.sample_fn = None
if sample_dist > 0: if sample_dist > 0:
assert sample_num < 0 and not simplify assert sample_num < 0 and not simplify
self.sample_fn = self.interp_fixed_dist self.sample_fn = self.interp_fixed_dist
if sample_num > 0: if sample_num > 0:
assert sample_dist < 0 and not simplify assert sample_dist < 0 and not simplify
self.sample_fn = self.interp_fixed_num self.sample_fn = self.interp_fixed_num
def interp_fixed_num(self, line: LineString) -> NDArray: def interp_fixed_num(self, line: LineString) -> NDArray:
''' Interpolate a line to fixed number of points. ''' Interpolate a line to fixed number of points.
Args: Args:
line (LineString): line line (LineString): line
Returns: Returns:
points (array): interpolated points, shape (N, 2) points (array): interpolated points, shape (N, 2)
''' '''
distances = np.linspace(0, line.length, self.sample_num) distances = np.linspace(0, line.length, self.sample_num)
sampled_points = np.array([list(line.interpolate(distance).coords) sampled_points = np.array([list(line.interpolate(distance).coords)
for distance in distances]).squeeze() for distance in distances]).squeeze()
return sampled_points return sampled_points
def interp_fixed_dist(self, line: LineString) -> NDArray: def interp_fixed_dist(self, line: LineString) -> NDArray:
''' Interpolate a line at fixed interval. ''' Interpolate a line at fixed interval.
Args: Args:
line (LineString): line line (LineString): line
Returns: Returns:
points (array): interpolated points, shape (N, 2) points (array): interpolated points, shape (N, 2)
''' '''
distances = list(np.arange(self.sample_dist, line.length, self.sample_dist)) distances = list(np.arange(self.sample_dist, line.length, self.sample_dist))
# make sure to sample at least two points when sample_dist > line.length # make sure to sample at least two points when sample_dist > line.length
distances = [0,] + distances + [line.length,] distances = [0,] + distances + [line.length,]
sampled_points = np.array([list(line.interpolate(distance).coords) sampled_points = np.array([list(line.interpolate(distance).coords)
for distance in distances]).squeeze() for distance in distances]).squeeze()
return sampled_points return sampled_points
def get_vectorized_lines(self, map_geoms: Dict) -> Dict: def get_vectorized_lines(self, map_geoms: Dict) -> Dict:
''' Vectorize map elements. Iterate over the input dict and apply the ''' Vectorize map elements. Iterate over the input dict and apply the
specified sample funcion. specified sample funcion.
Args: Args:
line (LineString): line line (LineString): line
Returns: Returns:
vectors (array): dict of vectorized map elements. vectors (array): dict of vectorized map elements.
''' '''
vectors = {} vectors = {}
for label, geom_list in map_geoms.items(): for label, geom_list in map_geoms.items():
vectors[label] = [] vectors[label] = []
for geom in geom_list: for geom in geom_list:
if geom.geom_type == 'LineString': if geom.geom_type == 'LineString':
geom = LineString(np.array(geom.coords)[:, :self.coords_dim]) geom = LineString(np.array(geom.coords)[:, :self.coords_dim])
if self.simplify: if self.simplify:
line = geom.simplify(0.2, preserve_topology=True) line = geom.simplify(0.2, preserve_topology=True)
line = np.array(line.coords) line = np.array(line.coords)
elif self.sample_fn: elif self.sample_fn:
line = self.sample_fn(geom) line = self.sample_fn(geom)
else: else:
line = np.array(line.coords) line = np.array(line.coords)
if self.normalize: if self.normalize:
line = self.normalize_line(line) line = self.normalize_line(line)
vectors[label].append(line) vectors[label].append(line)
elif geom.geom_type == 'Polygon': elif geom.geom_type == 'Polygon':
# polygon objects will not be vectorized # polygon objects will not be vectorized
continue continue
else: else:
raise ValueError('map geoms must be either LineString or Polygon!') raise ValueError('map geoms must be either LineString or Polygon!')
return vectors return vectors
def normalize_line(self, line: NDArray) -> NDArray:
    '''Map the xy coordinates of a polyline into the open interval (0, 1).

    Operates in place on the first two columns and returns the array.

    Args:
        line (NDArray): (N, coords_dim) array of points; only the xy
            columns are normalized.

    Returns:
        NDArray: the same array with xy mapped into (0, 1).
    '''
    half_w = self.roi_size[0] / 2
    half_h = self.roi_size[1] / 2
    # Shift so the ROI's lower-left corner becomes the origin.
    line[:, :2] = line[:, :2] + np.array([half_w, half_h])
    # Pad the divisor so results land strictly inside (0, 1).
    padding = 2
    line[:, :2] = line[:, :2] / (self.roi_size + padding)
    return line
def __call__(self, input_dict):
    '''Vectorize ``input_dict['map_geoms']`` and store it under ``vectors``.'''
    geoms = input_dict['map_geoms']
    input_dict['vectors'] = self.get_vectorized_lines(geoms)
    return input_dict
def __repr__(self):
    '''Return a human-readable summary of the vectorizer configuration.'''
    # BUGFIX: the original interleaved stray ')' characters and missing
    # separators, producing e.g. "...sample_num=20), ...roi_size=(60, 30))
    # normalize=True)coords_dim=2)".
    repr_str = self.__class__.__name__
    repr_str += f'(simplify={self.simplify}, '
    repr_str += f'sample_num={self.sample_num}, '
    repr_str += f'sample_dist={self.sample_dist}, '
    repr_str += f'roi_size={self.roi_size}, '
    repr_str += f'normalize={self.normalize}, '
    repr_str += f'coords_dim={self.coords_dim})'
    return repr_str
\ No newline at end of file
from .backbones import * from .backbones import *
from .heads import * from .heads import *
from .losses import * from .losses import *
from .mapers import * from .mapers import *
from .transformer_utils import * from .transformer_utils import *
from .assigner import * from .assigner import *
from .assigner import HungarianLinesAssigner from .assigner import HungarianLinesAssigner
from .match_cost import MapQueriesCost, BBoxLogitsCost, DynamicLinesCost, IoUCostC, BBoxCostC, LinesCost, LinesFixNumChamferCost, ClsSigmoidCost from .match_cost import MapQueriesCost, BBoxLogitsCost, DynamicLinesCost, IoUCostC, BBoxCostC, LinesCost, LinesFixNumChamferCost, ClsSigmoidCost
import torch import torch
from mmdet.core.bbox.builder import BBOX_ASSIGNERS from mmdet.core.bbox.builder import BBOX_ASSIGNERS
from mmdet.core.bbox.assigners import AssignResult from mmdet.core.bbox.assigners import AssignResult
from mmdet.core.bbox.assigners import BaseAssigner from mmdet.core.bbox.assigners import BaseAssigner
from mmdet.core.bbox.match_costs import build_match_cost from mmdet.core.bbox.match_costs import build_match_cost
try: try:
from scipy.optimize import linear_sum_assignment from scipy.optimize import linear_sum_assignment
except ImportError: except ImportError:
linear_sum_assignment = None linear_sum_assignment = None
@BBOX_ASSIGNERS.register_module()
class HungarianLinesAssigner(BaseAssigner):
    """Computes one-to-one matching between line predictions and ground truth.

    This class computes an assignment between the targets and the
    predictions based on the costs (a weighted sum of a classification
    cost and a line regression cost, built via ``build_match_cost``).
    The targets don't include the no_object class, so generally there are
    more predictions than targets. After the one-to-one matching, the
    un-matched predictions are treated as backgrounds. Thus each query
    prediction will be assigned with `0` or a positive integer indicating
    the ground truth index:

    - 0: negative sample, no assigned gt
    - positive integer: positive sample, index (1-based) of assigned gt

    Args:
        cost (dict): config of the combined matching cost.
        pc_range (list | None, optional): point-cloud range; stored as-is.
    """

    def __init__(self,
                 cost=dict(
                     type='MapQueriesCost',
                     cls_cost=dict(type='ClassificationCost', weight=1.),
                     reg_cost=dict(type='LinesCost', weight=1.0),
                 ),
                 pc_range=None,
                 **kwargs):
        self.pc_range = pc_range
        self.cost = build_match_cost(cost)

    def assign(self,
               preds: dict,
               gts: dict,
               gt_bboxes_ignore=None,
               eps=1e-7):
        """Compute one-to-one matching based on the weighted costs.

        This method assigns each query prediction to a ground truth or
        background. The `assigned_gt_inds` with -1 means don't care,
        0 means negative sample, and a positive number is the index
        (1-based) of the assigned gt.

        The assignment is done in the following steps, the order matters:

        1. assign every prediction to -1
        2. compute the weighted costs
        3. do Hungarian matching on CPU based on the costs
        4. assign all to 0 (background) first, then for each matched pair
           between predictions and gts, treat this prediction as
           foreground and assign the corresponding gt index (plus 1).

        Args:
            preds (dict): holds 'lines' [num_query, num_points, 2] and
                'scores' [num_query, num_class] (optionally 'masks').
            gts (dict): holds 'lines' [num_gt, num_points, 2] and
                'labels' of shape (num_gt,) (optionally 'masks').
            gt_bboxes_ignore (Tensor, optional): must be None; ignoring
                gt boxes is not supported.
            eps (int | float, optional): numerical-stability knob kept
                for interface compatibility (unused here). Default 1e-7.

        Returns:
            :obj:`AssignResult`: The assigned result.

        Raises:
            ImportError: if scipy is not installed.
            ValueError: if ``linear_sum_assignment`` rejects the cost
                matrix (e.g. it contains NaN or inf).
        """
        assert gt_bboxes_ignore is None, \
            'Only case when gt_bboxes_ignore is None is supported.'
        num_gts, num_lines = gts['lines'].size(0), preds['lines'].size(0)

        # 1. assign -1 by default
        assigned_gt_inds = \
            preds['lines'].new_full((num_lines,), -1, dtype=torch.long)
        assigned_labels = \
            preds['lines'].new_full((num_lines,), -1, dtype=torch.long)
        if num_gts == 0 or num_lines == 0:
            # No ground truth or predictions, return empty assignment.
            if num_gts == 0:
                # No ground truth, assign all to background.
                assigned_gt_inds[:] = 0
            return AssignResult(
                num_gts, assigned_gt_inds, None, labels=assigned_labels)

        # 2. compute the weighted costs
        cost = self.cost(preds, gts)

        # 3. do Hungarian matching on CPU using linear_sum_assignment
        cost = cost.detach().cpu().numpy()
        if linear_sum_assignment is None:
            raise ImportError('Please run "pip install scipy" '
                              'to install scipy first.')
        try:
            matched_row_inds, matched_col_inds = linear_sum_assignment(cost)
        except ValueError as err:
            # BUGFIX: was a bare `except:` that printed the cost range and
            # dropped into an interactive ipdb debugger (left-over
            # debugging code, hangs non-interactive training). Surface an
            # informative error instead.
            raise ValueError(
                'linear_sum_assignment failed on cost matrix with '
                'max {} and min {}'.format(cost.max(), cost.min())
            ) from err
        matched_row_inds = torch.from_numpy(matched_row_inds).to(
            preds['lines'].device)
        matched_col_inds = torch.from_numpy(matched_col_inds).to(
            preds['lines'].device)

        # 4. assign backgrounds and foregrounds
        # assign all indices to backgrounds first
        assigned_gt_inds[:] = 0
        # assign foregrounds based on matching results
        assigned_gt_inds[matched_row_inds] = matched_col_inds + 1
        assigned_labels[matched_row_inds] = gts['labels'][matched_col_inds]
        return AssignResult(
            num_gts, assigned_gt_inds, None, labels=assigned_labels)
\ No newline at end of file
import torch import torch
from mmdet.core.bbox.match_costs.builder import MATCH_COST from mmdet.core.bbox.match_costs.builder import MATCH_COST
from mmdet.core.bbox.match_costs import build_match_cost from mmdet.core.bbox.match_costs import build_match_cost
from mmdet.core.bbox.iou_calculators import bbox_overlaps from mmdet.core.bbox.iou_calculators import bbox_overlaps
from mmdet.core.bbox.transforms import bbox_cxcywh_to_xyxy from mmdet.core.bbox.transforms import bbox_cxcywh_to_xyxy
def chamfer_distance(pred, gt):
    '''Truncated bidirectional Chamfer distance between two point sets.

    Args:
        pred: Tensor of shape [num_points, 2].
        gt: Tensor of shape [num_gt, 2].

    Returns:
        Scalar tensor: mean nearest-neighbor distance pred->gt plus
        gt->pred, each per-point term clamped at 2.0.
    '''
    # All pairwise L2 distances: [num_points, num_gt].
    pairwise = torch.cdist(pred, gt, p=2)
    # pred -> gt direction.
    forward = pairwise.min(dim=-1).values.clamp(max=2.0).mean()
    # gt -> pred direction.
    backward = pairwise.min(dim=0).values.clamp(max=2.0).mean()
    return forward + backward
@MATCH_COST.register_module()
class ClsSigmoidCost:
    """Classification matching cost based on sigmoid scores.

    Args:
        weight (int | float, optional): loss_weight. Defaults to 1.
    """

    def __init__(self, weight=1.):
        self.weight = weight

    def __call__(self, cls_pred, gt_labels):
        """Compute the negated matched score for every (query, gt) pair.

        Args:
            cls_pred (Tensor): Predicted classification logits, shape
                [num_query, num_class].
            gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,).

        Returns:
            torch.Tensor: cls_cost value with weight,
                shape [num_query, num_gt].
        """
        # Following the official DETR repo, the NLL is approximated by
        # 1 - score[gt_label]; the constant 1 does not change the
        # matching, so it is omitted.
        scores = cls_pred.sigmoid()
        matched_scores = scores[:, gt_labels]
        return -matched_scores * self.weight
@MATCH_COST.register_module()
class LinesFixNumChamferCost(object):
    """Pairwise truncated Chamfer-distance cost between point sequences.

    Args:
        weight (int | float, optional): loss_weight. Defaults to 1.
    """

    def __init__(self, weight=1.):
        self.weight = weight

    def __call__(self, lines_pred, gt_lines):
        """
        Args:
            lines_pred (Tensor): predicted normalized lines:
                [num_query, num_points, 2]
            gt_lines (Tensor): Ground truth lines
                [num_gt, num_points, 2]

        Returns:
            torch.Tensor: reg_cost value with weight,
                shape [num_pred, num_gt]
        """
        num_preds, num_gts = lines_pred.size(0), gt_lines.size(0)
        if num_preds == 0 or num_gts == 0:
            return lines_pred.new_zeros((num_preds, num_gts))
        # PERF: the original ran a Python double loop calling
        # chamfer_distance() once per (pred, gt) pair. The same truncated
        # bidirectional Chamfer distance is computed here with a single
        # batched cdist over all pairs.
        n_pts_pred = lines_pred.size(1)
        n_pts_gt = gt_lines.size(1)
        pred = lines_pred.unsqueeze(1).expand(
            num_preds, num_gts, n_pts_pred, -1).reshape(
                num_preds * num_gts, n_pts_pred, -1)
        gt = gt_lines.unsqueeze(0).expand(
            num_preds, num_gts, n_pts_gt, -1).reshape(
                num_preds * num_gts, n_pts_gt, -1)
        # [P*G, n_pts_pred, n_pts_gt]
        dist = torch.cdist(pred, gt, p=2)
        # Per-point nearest-neighbor distance, clamped at 2.0, averaged
        # in each direction (matches chamfer_distance()).
        dist_pred = dist.min(-1).values.clamp(max=2.0).mean(-1)
        dist_gt = dist.min(-2).values.clamp(max=2.0).mean(-1)
        dist_mat = (dist_pred + dist_gt).view(num_preds, num_gts)
        return dist_mat * self.weight
@MATCH_COST.register_module()
class LinesCost(object):
    """Mean per-coordinate L1 cost between fixed-size point sequences.

    Args:
        weight (int | float, optional): loss_weight. Defaults to 1.
    """

    def __init__(self, weight=1.):
        self.weight = weight

    def __call__(self, lines_pred, gt_lines, **kwargs):
        """
        Args:
            lines_pred (Tensor): predicted normalized lines:
                [num_query, num_points, 2]
            gt_lines (Tensor): Ground truth lines
                [num_gt, num_points, 2]

        Returns:
            torch.Tensor: reg_cost value with weight,
                shape [num_pred, num_gt]
        """
        # NOTE: the original also built a reversed copy of gt_lines
        # (torch.flip) that was never used; that dead code is removed.
        pred_flat = lines_pred.flatten(1, 2)
        gt_flat = gt_lines.flatten(1, 2)
        # Average the L1 distance over the flattened coordinate dim.
        num_coords = pred_flat.size(-1)
        dist_mat = torch.cdist(pred_flat, gt_flat, p=1) / num_coords
        return dist_mat * self.weight
@MATCH_COST.register_module()
class BBoxCostC:
    """BBoxL1Cost.

    Args:
        weight (int | float, optional): loss_weight
        box_format (str, optional): 'xyxy' for DETR, 'xywh' for Sparse_RCNN.
            NOTE(review): currently validated but not applied — both
            formats are scored with plain L1 on the raw coordinates
            (the conversion code was commented out upstream).

    Examples:
        >>> import torch
        >>> self = BBoxCostC()
        >>> bbox_pred = torch.rand(1, 4)
        >>> gt_bboxes = torch.FloatTensor([[0, 0, 2, 4], [1, 2, 3, 4]])
        >>> cost = self(bbox_pred, gt_bboxes)
        >>> cost.shape
        torch.Size([1, 2])
    """

    def __init__(self, weight=1., box_format='xyxy'):
        self.weight = weight
        assert box_format in ['xyxy', 'xywh']
        self.box_format = box_format

    def __call__(self, bbox_pred, gt_bboxes):
        """
        Args:
            bbox_pred (Tensor): Predicted boxes with normalized coordinates
                (cx, cy, w, h), which are all in range [0, 1]. Shape
                [num_query, 4].
            gt_bboxes (Tensor): Ground truth boxes with normalized
                coordinates (x1, y1, x2, y2). Shape [num_gt, 4].

        Returns:
            torch.Tensor: bbox_cost value with weight,
                shape [num_query, num_gt]
        """
        # box_format conversion is intentionally not applied (it was
        # commented out in the original); the cost is a plain L1 distance
        # on whatever coordinate encoding the caller provides.
        bbox_cost = torch.cdist(bbox_pred, gt_bboxes, p=1)
        return bbox_cost * self.weight
@MATCH_COST.register_module()
class IoUCostC:
    """IoU-based matching cost.

    Args:
        iou_mode (str, optional): iou mode such as 'iou' | 'giou'.
            Defaults to 'giou'.
        weight (int | float, optional): loss weight. Defaults to 1.
        box_format (str, optional): 'xyxy' or 'xywh' input encoding.
            Defaults to 'xywh'.

    Examples:
        >>> import torch
        >>> self = IoUCostC(box_format='xyxy')
        >>> bboxes = torch.FloatTensor([[1,1, 2, 2], [2, 2, 3, 4]])
        >>> gt_bboxes = torch.FloatTensor([[0, 0, 2, 4], [1, 2, 3, 4]])
        >>> self(bboxes, gt_bboxes)
        tensor([[-0.1250, 0.1667],
                [ 0.1667, -0.5000]])
    """

    def __init__(self, iou_mode='giou', weight=1., box_format='xywh'):
        self.weight = weight
        self.iou_mode = iou_mode
        assert box_format in ['xyxy', 'xywh']
        self.box_format = box_format

    def __call__(self, bboxes, gt_bboxes):
        """
        Args:
            bboxes (Tensor): Predicted boxes, shape [num_query, 4].
            gt_bboxes (Tensor): Ground truth boxes, shape [num_gt, 4].

        Returns:
            torch.Tensor: iou_cost value with weight,
                shape [num_query, num_gt]
        """
        if self.box_format == 'xywh':
            # Convert (cx, cy, w, h) to corner format before the overlap.
            bboxes = bbox_cxcywh_to_xyxy(bboxes)
            gt_bboxes = bbox_cxcywh_to_xyxy(gt_bboxes)
        # overlaps: [num_bboxes, num_gt]
        overlaps = bbox_overlaps(
            bboxes, gt_bboxes, mode=self.iou_mode, is_aligned=False)
        # Higher overlap -> lower cost; the constant 1 that would make
        # this a proper (1 - IoU) does not change the matching, so it
        # is omitted.
        return -overlaps * self.weight
@MATCH_COST.register_module()
class DynamicLinesCost(object):
    """LinesL1Cost.

    Point-wise L2 cost between polylines, reduced over points using the
    prediction/GT point-validity masks.

    Args:
        weight (int | float, optional): loss_weight
    """
    def __init__(self, weight=1.):
        # Multiplier applied to the final cost matrix.
        self.weight = weight
    def __call__(self, lines_pred, lines_gt, masks_pred, masks_gt):
        """
        Args:
            lines_pred (Tensor): predicted normalized lines:
                [nP, num_points, 2]
            lines_gt (Tensor): Ground truth lines
                [nG, num_points, 2]
            masks_pred: [nP, num_points] point-validity logits for preds
                (thresholded via sigmoid > 0.5 below)
            masks_gt: [nG, num_points] point-validity mask for gts
        Returns:
            dist_mat: reg_cost value with weight
                shape [nP, nG]
        """
        dist_mat = self.cal_dist(lines_pred, lines_gt)
        dist_mat = self.get_dynamic_line(dist_mat, masks_pred, masks_gt)
        dist_mat = dist_mat * self.weight
        return dist_mat
    def cal_dist(self, x1, x2):
        '''Pairwise L2 distance between corresponding point indices.
        Args:
            x1: B1,N,2
            x2: B2,N,2
        Return:
            dist_mat: B1,B2,N  (distance between x1[i,n] and x2[j,n])
        '''
        # Move the point dim to the batch position so cdist compares
        # point n of every x1 line against point n of every x2 line.
        x1 = x1.permute(1, 0, 2)
        x2 = x2.permute(1, 0, 2)
        dist_mat = torch.cdist(x1, x2, p=2)
        dist_mat = dist_mat.permute(1, 2, 0)
        return dist_mat
    def get_dynamic_line(self, mat, m1, m2):
        '''
        get dynamic line with difference approach
        mat: N1xN2xnpts
        m1: N1xnpts  (logits, thresholded at sigmoid > 0.5)
        m2: N2xnpts  (presumably already a 0/1 mask -- TODO confirm)
        '''
        # nPxnGxnum_points
        m1 = m1.unsqueeze(1).sigmoid() > 0.5
        m2 = m2.unsqueeze(0)
        # Assuming m2 is 0/1: 0 where both invalid, 0.5 where exactly one
        # is valid, 1 where both are valid.
        valid_points_mask = (m1 + m2)/2.
        average_factor_mask = valid_points_mask.sum(-1) > 0
        # NOTE(review): `masked_fill` on this *boolean* mask fills the
        # False positions with 1 and keeps True (== 1) elsewhere, so
        # `average_factor` is effectively all ones and the division below
        # is a no-op — `mat` becomes a masked *sum*, not the average the
        # comment claims. Presumably this was meant to be
        # `valid_points_mask.sum(-1).masked_fill(~average_factor_mask, 1)`.
        # Confirm before changing: trained models may depend on the
        # current scaling.
        average_factor = average_factor_mask.masked_fill(
            ~average_factor_mask, 1)
        # takes the average
        mat = mat * valid_points_mask
        mat = mat.sum(-1) / average_factor
        return mat
@MATCH_COST.register_module()
class BBoxLogitsCost(object):
    """BBoxLogits.

    Likelihood-style cost between per-token bbox logits and discrete
    ground-truth bbox tokens.

    Args:
        weight (int | float, optional): loss_weight
    """
    def __init__(self, weight=1.):
        self.weight = weight
    def calNLL(self, logits, value):
        '''
        Gather each gt token's logit for every (pred, gt) pair.
        Args:
            logits: B1, 8, cls_dim -- presumably already log-probs, since
                no log_softmax is applied here; TODO confirm with caller.
            value: B2, 8,
        Return:
            log_likelihood: B1,B2,8
        '''
        # Broadcast preds against gts: [B1,1,8,cls_dim] vs [1,B2,8,1].
        logits = logits[:, None]
        value = value[None]
        value = value.long().unsqueeze(-1)
        # Align both to a common [B1,B2,8,cls_dim] shape; only the first
        # channel of the broadcast value is kept as the gather index.
        value, log_pmf = torch.broadcast_tensors(value, logits)
        value = value[..., :1]
        return log_pmf.gather(-1, value).squeeze(-1)
    def __call__(self, bbox_pred, bbox_gt, **kwargs):
        """
        Args:
            bbox_pred: nproposal, 4*2, pos_dim
            bbox_gt: ngt, 4*2
        Returns:
            cost: nproposal, ngt
        """
        # NOTE(review): the gathered values are averaged without negation;
        # if `logits` are log-probs a true NLL cost would be the negative
        # of this — confirm the sign convention with the consumer.
        cost = self.calNLL(bbox_pred, bbox_gt).mean(-1)
        return cost * self.weight
@MATCH_COST.register_module()
class MapQueriesCost(object):
    """Combined matching cost for map queries: cls + reg (+ optional IoU).

    Args:
        cls_cost (dict): config for the classification cost.
        reg_cost (dict): config for the line regression cost.
        iou_cost (dict | None, optional): config for an optional IoU cost.
    """

    def __init__(self, cls_cost, reg_cost, iou_cost=None):
        self.cls_cost = build_match_cost(cls_cost)
        self.reg_cost = build_match_cost(reg_cost)
        if iou_cost is not None:
            self.iou_cost = build_match_cost(iou_cost)
        else:
            self.iou_cost = None

    def __call__(self, preds: dict, gts: dict):
        """Return the summed cost matrix for (preds, gts)."""
        # classification and bboxcost.
        total = self.cls_cost(preds['scores'], gts['labels'])

        # regression cost; point masks are only meaningful for the
        # dynamic-length cost.
        reg_kwargs = {}
        if 'masks' in preds and 'masks' in gts:
            assert isinstance(self.reg_cost, DynamicLinesCost), ' Issues!!'
            reg_kwargs = dict(masks_pred=preds['masks'],
                              masks_gt=gts['masks'])
        total = total + self.reg_cost(preds['lines'], gts['lines'],
                                      **reg_kwargs)

        # optional IoU cost
        if self.iou_cost is not None:
            total = total + self.iou_cost(preds['lines'], gts['lines'])
        return total
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
class NoiseSythesis(nn.Module): class NoiseSythesis(nn.Module):
def __init__(self, def __init__(self,
p, scale=0.01, shift_scale=(8,5), p, scale=0.01, shift_scale=(8,5),
scaling_size=(0.1,0.1), canvas_size=(200, 100), scaling_size=(0.1,0.1), canvas_size=(200, 100),
bbox_type='sce', bbox_type='sce',
poly_coord_dim=2, poly_coord_dim=2,
bbox_coord_dim=2, bbox_coord_dim=2,
quantify=True): quantify=True):
super(NoiseSythesis, self).__init__() super(NoiseSythesis, self).__init__()
self.p = p self.p = p
self.scale = scale self.scale = scale
self.bbox_type = bbox_type self.bbox_type = bbox_type
self.quantify = quantify self.quantify = quantify
self.poly_coord_dim = poly_coord_dim self.poly_coord_dim = poly_coord_dim
self.bbox_coord_dim = bbox_coord_dim self.bbox_coord_dim = bbox_coord_dim
self.transforms = [self.random_shifting, self.random_scaling] self.transforms = [self.random_shifting, self.random_scaling]
# self.transforms = [self.random_scaling] # self.transforms = [self.random_scaling]
self.register_buffer('canvas_size', torch.tensor(canvas_size)) self.register_buffer('canvas_size', torch.tensor(canvas_size))
self.register_buffer('shift_scale', torch.tensor(shift_scale).float()) self.register_buffer('shift_scale', torch.tensor(shift_scale).float())
self.register_buffer('scaling_size', torch.tensor(scaling_size)) self.register_buffer('scaling_size', torch.tensor(scaling_size))
def random_scaling(self, bbox): def random_scaling(self, bbox):
''' '''
bbox: B, paramter_num, 2 bbox: B, paramter_num, 2
''' '''
device = bbox.device device = bbox.device
dtype = bbox.dtype dtype = bbox.dtype
B = bbox.shape[0] B = bbox.shape[0]
noise = (torch.rand(B, device=device)*2-1)[:,None,None] # [-1,1] noise = (torch.rand(B, device=device)*2-1)[:,None,None] # [-1,1]
scale = self.scaling_size.to(device) scale = self.scaling_size.to(device)
scale = (noise * scale) + 1 scale = (noise * scale) + 1
scaled_bbox = bbox * scale scaled_bbox = bbox * scale
# recenterization # recenterization
coffset = scaled_bbox.mean(-2) - bbox.float().mean(-2) coffset = scaled_bbox.mean(-2) - bbox.float().mean(-2)
scaled_bbox = scaled_bbox - coffset[:,None] scaled_bbox = scaled_bbox - coffset[:,None]
return scaled_bbox.round().type(dtype) return scaled_bbox.round().type(dtype)
def random_shifting(self, bbox): def random_shifting(self, bbox):
''' '''
bbox: B, paramter_num, 2 bbox: B, paramter_num, 2
''' '''
device = bbox.device device = bbox.device
batch_size = bbox.shape[0] batch_size = bbox.shape[0]
shift_scale = self.shift_scale shift_scale = self.shift_scale
scale = (bbox.max(1)[0] - bbox.min(1)[0]) * 0.1 scale = (bbox.max(1)[0] - bbox.min(1)[0]) * 0.1
scale = torch.where(scale < shift_scale, scale, shift_scale) scale = torch.where(scale < shift_scale, scale, shift_scale)
noise = (torch.rand(batch_size, 2, device=device)*2-1) # [-1,1] noise = (torch.rand(batch_size, 2, device=device)*2-1) # [-1,1]
offset = (noise * scale).round().type(bbox.dtype) offset = (noise * scale).round().type(bbox.dtype)
shifted_bbox = bbox + offset[:, None] shifted_bbox = bbox + offset[:, None]
return shifted_bbox return shifted_bbox
def gaussian_noise_bbox(self, bbox): def gaussian_noise_bbox(self, bbox):
dtype = bbox.dtype dtype = bbox.dtype
batch_size = bbox.shape[0] batch_size = bbox.shape[0]
scale = (self.canvas_size * self.scale)[:self.bbox_coord_dim] scale = (self.canvas_size * self.scale)[:self.bbox_coord_dim]
noisy_bbox = torch.normal(bbox.type(torch.float), scale) noisy_bbox = torch.normal(bbox.type(torch.float), scale)
if self.quantify: if self.quantify:
noisy_bbox = noisy_bbox.round().type(dtype) noisy_bbox = noisy_bbox.round().type(dtype)
# prevent out of bound case # prevent out of bound case
for i in range(self.bbox_coord_dim): for i in range(self.bbox_coord_dim):
noisy_bbox[...,i] =\ noisy_bbox[...,i] =\
torch.clamp(noisy_bbox[...,0],1,self.canvas_size[i]) torch.clamp(noisy_bbox[...,0],1,self.canvas_size[i])
else: else:
noisy_bbox = noisy_bbox.type(torch.float) noisy_bbox = noisy_bbox.type(torch.float)
return noisy_bbox return noisy_bbox
def gaussian_noise_poly(self, polyline, polyline_mask): def gaussian_noise_poly(self, polyline, polyline_mask):
device = polyline.device device = polyline.device
batchsize = polyline.shape[0] batchsize = polyline.shape[0]
scale = self.canvas_size * self.scale scale = self.canvas_size * self.scale
polyline = F.pad(polyline,(0,self.poly_coord_dim-1)) polyline = F.pad(polyline,(0,self.poly_coord_dim-1))
polyline = polyline.view(batchsize,-1, self.poly_coord_dim) polyline = polyline.view(batchsize,-1, self.poly_coord_dim)
mask = F.pad(polyline_mask[:,1:],(0,self.poly_coord_dim)) mask = F.pad(polyline_mask[:,1:],(0,self.poly_coord_dim))
noisy_polyline = torch.normal(polyline.type(torch.float), scale) noisy_polyline = torch.normal(polyline.type(torch.float), scale)
if self.quantify: if self.quantify:
noisy_polyline = noisy_polyline.round().type(polyline.dtype) noisy_polyline = noisy_polyline.round().type(polyline.dtype)
# prevent out of bound case # prevent out of bound case
for i in range(self.poly_coord_dim): for i in range(self.poly_coord_dim):
noisy_polyline[...,i] =\ noisy_polyline[...,i] =\
torch.clamp(noisy_polyline[...,i],0,self.canvas_size[i]) torch.clamp(noisy_polyline[...,i],0,self.canvas_size[i])
else: else:
noisy_polyline = noisy_polyline.type(torch.float) noisy_polyline = noisy_polyline.type(torch.float)
noisy_polyline = noisy_polyline.view(batchsize,-1) * mask noisy_polyline = noisy_polyline.view(batchsize,-1) * mask
noisy_polyline = noisy_polyline[:,:-(self.poly_coord_dim-1)] noisy_polyline = noisy_polyline[:,:-(self.poly_coord_dim-1)]
return noisy_polyline return noisy_polyline
def random_apply(self, bbox): def random_apply(self, bbox):
for t in self.transforms: for t in self.transforms:
if self.p < torch.rand(1): if self.p < torch.rand(1):
continue continue
bbox = t(bbox) bbox = t(bbox)
# prevent out of bound case # prevent out of bound case
bbox[...,0] =\ bbox[...,0] =\
torch.clamp(bbox[...,0],0,self.canvas_size[0]) torch.clamp(bbox[...,0],0,self.canvas_size[0])
bbox[...,1] =\ bbox[...,1] =\
torch.clamp(bbox[...,1],0,self.canvas_size[1]) torch.clamp(bbox[...,1],0,self.canvas_size[1])
return bbox return bbox
def simple_aug(self, batch): def simple_aug(self, batch):
# augment bbox # augment bbox
if self.bbox_type in ['sce', 'xyxy']: if self.bbox_type in ['sce', 'xyxy']:
fbbox = batch['bbox_flat'] fbbox = batch['bbox_flat']
seq_len = fbbox.shape[0] seq_len = fbbox.shape[0]
bbox = fbbox.view(seq_len, -1, 2) bbox = fbbox.view(seq_len, -1, 2)
bbox = self.gaussian_noise_bbox(bbox) bbox = self.gaussian_noise_bbox(bbox)
fbbox_aug = bbox.view(seq_len, -1) fbbox_aug = bbox.view(seq_len, -1)
aug_mask = torch.rand(fbbox.shape,device=fbbox.device) aug_mask = torch.rand(fbbox.shape,device=fbbox.device)
fbbox = torch.where(aug_mask<self.p, fbbox_aug, fbbox) fbbox = torch.where(aug_mask<self.p, fbbox_aug, fbbox)
elif self.bbox_type == 'rxyxy': elif self.bbox_type == 'rxyxy':
fbbox = self.rbbox_aug(batch) fbbox = self.rbbox_aug(batch)
elif self.bbox_type == 'convex_hull': elif self.bbox_type == 'convex_hull':
fbbox = self.convex_hull_aug(batch) fbbox = self.convex_hull_aug(batch)
# augment # augment
polyline = batch['polylines'] polyline = batch['polylines']
polyline_mask = batch['polyline_masks'] polyline_mask = batch['polyline_masks']
polyline_aug = self.gaussian_noise_poly(polyline, polyline_mask) polyline_aug = self.gaussian_noise_poly(polyline, polyline_mask)
aug_mask = torch.rand(polyline.shape,device=polyline.device) aug_mask = torch.rand(polyline.shape,device=polyline.device)
polyline = torch.where(aug_mask<self.p, polyline_aug, polyline) polyline = torch.where(aug_mask<self.p, polyline_aug, polyline)
return polyline, fbbox return polyline, fbbox
def rbbox_aug(self, batch): def rbbox_aug(self, batch):
return None return None
def convex_hull_aug(self,batch): def convex_hull_aug(self,batch):
return None return None
def __call__(self, batch, simple_aug=False): def __call__(self, batch, simple_aug=False):
if simple_aug: if simple_aug:
return self.simple_aug(batch) return self.simple_aug(batch)
else: else:
fbbox = batch['bbox_flat'] fbbox = batch['bbox_flat']
seq_len = fbbox.shape[0] seq_len = fbbox.shape[0]
bbox = fbbox.view(seq_len, -1, self.bbox_coord_dim) bbox = fbbox.view(seq_len, -1, self.bbox_coord_dim)
aug_bbox = self.random_apply(bbox) aug_bbox = self.random_apply(bbox)
aug_bbox_flat = aug_bbox.view(seq_len, -1) aug_bbox_flat = aug_bbox.view(seq_len, -1)
return aug_bbox_flat return aug_bbox_flat
import copy import copy
import math import math
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from mmdet3d.models.builder import BACKBONES from mmdet3d.models.builder import BACKBONES
from mmdet.models import build_backbone, build_neck from mmdet.models import build_backbone, build_neck
class UpsampleBlock(nn.Module):
    """3x3 conv + GroupNorm + ReLU followed by a 2x bilinear upsample."""

    def __init__(self, ins, outs):
        super(UpsampleBlock, self).__init__()
        self.gn = nn.GroupNorm(32, outs)
        # 'same' padding: the convolution preserves spatial size.
        self.conv = nn.Conv2d(ins, outs, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return the activated feature map at twice the input resolution."""
        out = self.relu(self.gn(self.conv(x)))
        return self.upsample2x(out)

    def upsample2x(self, x):
        """Bilinearly resize ``x`` to exactly double its height and width."""
        height, width = x.shape[-2:]
        return F.interpolate(x, size=(height * 2, width * 2),
                             mode='bilinear', align_corners=True)
class Upsample(nn.Module):
    """Bring multi-scale feature maps to one resolution and sum them.

    For each entry in ``zoom_size`` a tower of ``log2(scale)`` upsampling
    blocks is built (identity when the scale factor is already 1);
    ``forward`` applies each tower to its feature map and adds the results.
    """

    def __init__(self,
                 zoom_size=(2, 4, 8),
                 in_channels=128,
                 out_channels=128,
                 ):
        super(Upsample, self).__init__()
        self.out_channels = out_channels
        # Templates: the first block adapts channels, later ones keep them.
        head_block = UpsampleBlock(in_channels, out_channels)
        tail_block = UpsampleBlock(out_channels, out_channels)
        towers = []
        for factor in zoom_size:
            depth = int(math.log2(factor))
            if depth < 1:
                # Already at the target resolution; pass through unchanged.
                towers.append(nn.Identity())
                continue
            blocks = [copy.deepcopy(head_block)]
            blocks.extend(copy.deepcopy(tail_block)
                          for _ in range(depth - 1))
            towers.append(nn.Sequential(*blocks))
        self.fscale = nn.ModuleList(towers)

    def init_weights(self):
        """Kaiming-init every conv weight; zero every conv bias."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_uniform_(module.weight, a=1)
                nn.init.constant_(module.bias, 0)

    def forward(self, imgs):
        """Upsample each feature map with its tower and sum the results."""
        return sum(tower(fmap) for tower, fmap in zip(self.fscale, imgs))
@BACKBONES.register_module()
class IPMEncoder(nn.Module):
    '''
    Encode multi-camera image features into a BEV feature map via inverse
    perspective mapping (IPM), optionally fused with a PointPillars lidar
    branch. The BEV grid is sampled at several fixed heights and the
    per-height features are concatenated before the output convolution.
    '''
    def __init__(self,
                 img_backbone,
                 img_neck,
                 upsample,
                 xbound=[-30.0, 30.0, 0.5],
                 ybound=[-15.0, 15.0, 0.5],
                 zbound=[-10.0, 10.0, 20.0],
                 heights=[-1.1, 0, 0.5, 1.1],
                 pretrained=None,
                 out_channels=128,
                 num_cam=6,
                 use_lidar=False,
                 use_image=True,
                 lidar_dim=128,
                 ):
        """Build backbones, BEV plane buffers and output convolutions.

        Args:
            img_backbone (dict): Config for the image backbone.
            img_neck (dict): Config for the image neck.
            upsample (dict): Kwargs for the ``Upsample`` fusion module.
            xbound/ybound/zbound (list): (min, max, step) BEV extents;
                zbound is only forwarded to the lidar pillar encoder.
            heights (list): z values of the sampled BEV planes (one level
                per height).
            pretrained (str, optional): Passed to ``init_weights``.
            out_channels (int): Channels of the returned BEV feature.
            num_cam (int): Number of cameras.
            use_lidar/use_image (bool): Which branches to enable.
            lidar_dim (int): Channels of the pillar encoder output.
        """
        super(IPMEncoder, self).__init__()
        self.x_bound = xbound
        self.y_bound = ybound
        self.heights = heights
        self.num_cam = num_cam
        # NOTE(review): num_x/num_y are computed but never used below.
        num_x = int((xbound[1] - xbound[0]) / xbound[2])
        num_y = int((ybound[1] - ybound[0]) / ybound[2])
        self.img_backbone = build_backbone(img_backbone)
        self.img_neck = build_neck(img_neck)
        self.upsample = Upsample(**upsample)
        self.use_image = use_image
        self.use_lidar = use_lidar
        if self.use_lidar:
            # PointPillarEncoder is defined elsewhere in the project.
            self.pp = PointPillarEncoder(lidar_dim, xbound, ybound, zbound)
            # Image-branch head; "+3" accounts for the xyz grid that ipm()
            # concatenates onto each level's features.
            # NOTE(review): built even when use_image is False (unused then).
            self.outconvs =\
                nn.Conv2d((self.upsample.out_channels+3)*len(heights), out_channels//2,
                          kernel_size=3, stride=1, padding=1)  # same
            # When both branches are on, each contributes half the channels.
            if self.use_image:
                _out_channels = out_channels//2
            else:
                _out_channels = out_channels
            self.outconvs_lidar =\
                nn.Conv2d(lidar_dim, _out_channels,
                          kernel_size=3, stride=1, padding=1)  # same
        else:
            self.outconvs =\
                nn.Conv2d((self.upsample.out_channels+3)*len(heights), out_channels,
                          kernel_size=3, stride=1, padding=1)  # same
        self.init_weights(pretrained=pretrained)
        # One BEV plane of 3D points per sampling height.
        bev_planes = [construct_plane_grid(
            xbound, ybound, h) for h in self.heights]
        # Registered as a buffer so it follows the module's device/dtype.
        self.register_buffer('bev_planes', torch.stack(
            bev_planes),)  # nlvl, bH, bW, 3 (x, y, z)
        # NOTE(review): masked_embeds appears unused within this class.
        self.masked_embeds = nn.Embedding(len(heights), out_channels)
    def init_weights(self, pretrained=None):
        """Initialize model weights."""
        self.img_backbone.init_weights()
        self.img_neck.init_weights()
        self.upsample.init_weights()
        # Xavier-init the output convolutions (weights only; dim>1 filters
        # out the 1-D bias parameters).
        for p in self.outconvs.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        if self.use_lidar:
            for p in self.outconvs_lidar.parameters():
                if p.dim() > 1:
                    nn.init.xavier_uniform_(p)
            for p in self.pp.parameters():
                if p.dim() > 1:
                    nn.init.xavier_uniform_(p)
    def extract_img_feat(self, imgs):
        '''
        Extract image features and fuse the multi-scale maps into one.
        Args:
            imgs: B, n_cam, C, iH, iW
        Returns:
            img_feat: B * n_cam, C, H, W
        '''
        B, n_cam, C, iH, iW = imgs.shape
        # Fold cameras into the batch dimension for the 2D backbone.
        imgs = imgs.view(B * n_cam, C, iH, iW)
        img_feats = self.img_backbone(imgs)
        # reduce the channel dim
        img_feats = self.img_neck(img_feats)
        # fuse the multi-scale feature maps into a single one
        img_feat = self.upsample(img_feats)
        return img_feat
    def forward(self, imgs, img_metas, *args, points=None, **kwargs):
        '''
        Args:
            imgs: torch.Tensor of shape [B, N, 3, H, W]
                N: number of cams
            img_metas: list of dicts, one per sample; 'ego2img' holds the
                per-camera ego -> image projection matrices.
                # N=6, ['CAM_FRONT', 'CAM_FRONT_RIGHT', 'CAM_FRONT_LEFT', 'CAM_BACK', 'CAM_BACK_LEFT', 'CAM_BACK_RIGHT']
                ego2cam: [B, N, 4, 4]
                cam_intrinsics: [B, N, 3, 3]
                cam2ego_rotations: [B, N, 3, 3]
                cam2ego_translations: [B, N, 3]
                ...
            points: (ptensor, pmask) lidar input; only used when use_lidar.
        Outs:
            bev_feature: torch.Tensor of shape [B, C*nlvl, bH, bW]
        '''
        # NOTE(review): if both use_image and use_lidar are False, bev_feat
        # is never assigned and this raises NameError.
        if self.use_image:
            self.B = imgs.shape[0]
            # Get transform matrix
            ego2cam = []
            for img_meta in img_metas:
                ego2cam.append(img_meta['ego2img'])
            img_shape = imgs.shape[-2:]
            ego2cam = np.asarray(ego2cam)
            # Image backbone
            img_feats = self.extract_img_feat(imgs)
            # IPM: lift per-camera features onto the BEV planes
            bev_feat, bev_feat_mask = self.ipm(img_feats, ego2cam, img_shape)
            # merge the (C, nlvl) dims so all levels feed one convolution
            bev_feat = bev_feat.flatten(1, 2)
            bev_feat = self.outconvs(bev_feat)
        if self.use_lidar:
            lidar_feat = self.get_lidar_feature(points)
            if self.use_image:
                # Channel-concat: each branch supplies out_channels//2.
                bev_feat = torch.cat([bev_feat,lidar_feat],dim=1)
            else:
                bev_feat = lidar_feat
        return bev_feat
    def ipm(self, cam_feat, ego2cam, img_shape):
        '''
        Inverse-project camera features onto the stacked BEV planes.
        Args:
            cam_feat: B*ncam, C, cH, cW
            ego2cam: (B, ncam, 4, 4) ego -> image matrices (numpy).
            img_shape: tuple(H, W)
        Returns:
            project_feat: B, C+3, nlvl, bH, bW (xyz grid concatenated)
            bev_feat_mask: B, 1, nlvl, bH, bW — True where at least one
                camera observed the BEV cell.
        '''
        C = cam_feat.shape[1]
        bev_grid = self.bev_planes.unsqueeze(0).repeat(self.B, 1, 1, 1, 1)
        nlvl, bH, bW = bev_grid.shape[1:4]
        bev_grid = bev_grid.flatten(1, 3)  # B, nlvl*W*H, 3
        # Find points in cam coords
        # bev_grid_pos: B*ncam, nlvl*bH*bW, 2
        bev_grid_pos, bev_cam_mask = get_campos(bev_grid, ego2cam, img_shape)
        # B*cam, nlvl*bH, bW, 2 — grid_sample wants a 2D sampling grid
        bev_grid_pos = bev_grid_pos.unflatten(-2, (nlvl*bH, bW))
        # project feat from 2D to bev plane
        projected_feature = F.grid_sample(
            cam_feat, bev_grid_pos, align_corners=False).view(self.B, -1, C, nlvl, bH, bW)  # B,cam,C,nlvl,bH,bW
        # B,cam,nlvl,bH,bW
        bev_feat_mask = bev_cam_mask.unflatten(-1, (nlvl, bH, bW))
        # eliminate the ncam dim:
        # the bev feature is the mean over the cameras that see each cell
        bev_feat_mask = bev_feat_mask.unsqueeze(2)
        projected_feature = (projected_feature*bev_feat_mask).sum(1)
        num_feat = bev_feat_mask.sum(1)
        # masked_fill avoids division by zero for unobserved cells
        projected_feature = projected_feature / \
            num_feat.masked_fill(num_feat == 0, 1)
        # concatenate a position information
        # projected_feature: B, C+3, nlvl, bH, bW
        bev_grid = bev_grid.view(self.B, nlvl, bH, bW,
                                 3).permute(0, 4, 1, 2, 3)
        projected_feature = torch.cat(
            (projected_feature, bev_grid), dim=1)
        return projected_feature, bev_feat_mask.sum(1) > 0
    def get_lidar_feature(self, points):
        """Run the PointPillars branch and project to output channels."""
        ptensor, pmask = points
        lidar_feature = self.pp(ptensor, pmask)
        # bev_grid = self.bev_planes[...,:-1].unsqueeze(0).repeat(self.B, 1, 1, 1, 1)
        # bev_grid = bev_grid[:,0]
        # bev_grid = bev_grid.permute(0, 3, 1, 2)
        # lidar_feature = torch.cat(
        #     (lidar_feature, bev_grid), dim=1)
        lidar_feature = self.outconvs_lidar(lidar_feature)
        return lidar_feature
def construct_plane_grid(xbound, ybound, height: float, dtype=torch.float32):
    '''Build a horizontal BEV grid of 3D points at a fixed height.

    Args:
        xbound: (min, max, step) extent along x.
        ybound: (min, max, step) extent along y.
        height: z value shared by every grid point.
        dtype: dtype of the returned tensor.

    Returns:
        plane: tensor of shape (H, W, 3) holding (x, y, z) per cell.
    '''
    xmin, xmax, xstep = xbound
    ymin, ymax, ystep = ybound
    n_cols = int((xmax - xmin) / xstep)
    n_rows = int((ymax - ymin) / ystep)
    xs = torch.linspace(xmin, xmax, n_cols, dtype=dtype)
    ys = torch.linspace(ymin, ymax, n_rows, dtype=dtype)
    # grid_y/grid_x: [num_y, num_x]
    grid_y, grid_x = torch.meshgrid(ys, xs)
    grid_z = torch.full_like(grid_x, height)
    # [num_y, num_x, 3]
    return torch.stack([grid_x, grid_y, grid_z], dim=-1)
def get_campos(reference_points, ego2cam, img_shape):
    '''
    Find each reference point's corresponding pixel in each camera.

    Args:
        reference_points: [B, num_query, 3] points in ego coordinates.
        ego2cam: (B, num_cam, 4, 4) ego -> image projection matrices
            (array-like; converted onto reference_points' device/dtype).
        img_shape: (H, W) of the input images, used to normalize pixels.

    Returns:
        reference_points_cam: (B*num_cam, num_query, 2) sampling positions
            normalized to [-1, 1] for ``F.grid_sample``.
        mask: (B, num_cam, num_query) bool, True where the point is in
            front of the camera and inside the image. num_query == W*H.
    '''
    ego2cam = reference_points.new_tensor(ego2cam)  # (B, N, 4, 4)
    reference_points = reference_points.clone()
    B, num_query = reference_points.shape[:2]
    num_cam = ego2cam.shape[1]
    # Homogeneous coordinates: (B, num_query, 4)
    reference_points = torch.cat(
        (reference_points, torch.ones_like(reference_points[..., :1])), -1)
    # Broadcast points and matrices over the camera dimension.
    reference_points = reference_points.view(
        B, 1, num_query, 4).repeat(1, num_cam, 1, 1).unsqueeze(-1)
    ego2cam = ego2cam.view(
        B, num_cam, 1, 4, 4).repeat(1, 1, num_query, 1, 1)
    # Projected points: (B, num_cam, num_query, 4)
    reference_points_cam = (ego2cam @ reference_points).squeeze(-1)
    eps = 1e-9
    # Keep only points strictly in front of the camera.
    mask = (reference_points_cam[..., 2:3] > eps)
    # BUGFIX: the original wrote ``x / z + eps``, adding eps to the result
    # of the division instead of guarding the denominator — z == 0 still
    # produced NaN/Inf that could leak through grid_sample. Clamp the depth
    # instead; affected points are masked out below anyway.
    reference_points_cam = reference_points_cam[..., 0:2] / \
        torch.clamp(reference_points_cam[..., 2:3], min=eps)
    # Normalize pixel coordinates to [0, 1].
    reference_points_cam[..., 0] /= img_shape[1]
    reference_points_cam[..., 1] /= img_shape[0]
    # from 0~1 to -1~1 (grid_sample convention)
    reference_points_cam = (reference_points_cam - 0.5) * 2
    # Also require the projection to land inside the image.
    mask = (mask & (reference_points_cam[..., 0:1] > -1.0)
            & (reference_points_cam[..., 0:1] < 1.0)
            & (reference_points_cam[..., 1:2] > -1.0)
            & (reference_points_cam[..., 1:2] < 1.0))
    # (B, num_cam, num_query)
    mask = mask.view(B, num_cam, num_query)
    reference_points_cam = reference_points_cam.view(B*num_cam, num_query, 2)
    return reference_points_cam, mask
def _test():
    """Placeholder for a module smoke test; intentionally a no-op."""
    return None


if __name__ == '__main__':
    _test()
from .base_map_head import BaseMapHead from .base_map_head import BaseMapHead
from .dg_head import DGHead from .dg_head import DGHead
from .map_element_detector import MapElementDetector from .map_element_detector import MapElementDetector
from .polyline_generator import PolylineGenerator from .polyline_generator import PolylineGenerator
\ No newline at end of file
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
import torch.nn as nn import torch.nn as nn
from mmcv.runner import auto_fp16 from mmcv.runner import auto_fp16
from mmcv.utils import print_log from mmcv.utils import print_log
from mmdet.utils import get_root_logger from mmdet.utils import get_root_logger
class BaseMapHead(nn.Module, metaclass=ABCMeta):
    """Abstract base class for map heads.

    Subclasses must implement :meth:`loss` and :meth:`post_process`;
    ``forward`` here is an fp16-aware placeholder.
    """

    def __init__(self):
        super(BaseMapHead, self).__init__()
        # Flag read by mmcv's fp16 utilities.
        self.fp16_enabled = False

    def init_weights(self, pretrained=None):
        """Initialize the weights in detector.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
        """
        if pretrained is None:
            return
        log = get_root_logger()
        print_log(f'load model from: {pretrained}', logger=log)

    @auto_fp16(apply_to=('img', ))
    def forward(self, *args, **kwargs):
        """Placeholder forward; concrete heads override this."""
        pass

    @abstractmethod
    def loss(self, pred, gt):
        """Compute the training loss.

        Returns:
            dict with keys ``loss`` (torch.Tensor), ``log_vars``
            (dict mapping str to float) and ``num_samples`` (int).
        """
        return

    @abstractmethod
    def post_process(self, pred):
        """Convert raw predictions to vectorized map outputs.

        The output format must stay consistent with the evaluation code.
        """
        return
# The causal decoder below is adapted from
# https://github.com/alexmt-scale/causal-transformer-decoder
# with modifications so that it fits the PolyGen-style decoding used here.
import torch import torch
import torch.nn as nn import torch.nn as nn
from typing import Optional from typing import Optional
from torch import Tensor from torch import Tensor
from mmcv.cnn.bricks.registry import ATTENTION from mmcv.cnn.bricks.registry import ATTENTION
from mmcv.utils import build_from_cfg from mmcv.utils import build_from_cfg
def build_attention(cfg, default_args=None):
    """Construct an attention module from the mmcv ATTENTION registry.

    Args:
        cfg (dict): Config dict for the attention module.
        default_args (dict, optional): Default kwargs merged into ``cfg``.

    Returns:
        The instantiated attention module.
    """
    attn = build_from_cfg(cfg, ATTENTION, default_args)
    return attn
class CausalTransformerDecoder(nn.TransformerDecoder):
    """Implementation of a transformer decoder based on torch implementation but
    more efficient. The difference is that it doesn't need to recompute the
    embeddings of all the past decoded tokens but instead uses a cache to
    store them. This makes use of the fact that the attention of a decoder is
    causal, so new predicted tokens don't affect the old tokens' embedding bc
    the corresponding attention cells are masked.
    The complexity goes from seq_len^3 to seq_len^2.
    This only happens in eval mode.
    In training mode, teacher forcing makes these optimizations unnecessary. Hence the
    Decoder acts like a regular nn.TransformerDecoder (except that the attention tgt
    masks are handled for you).
    """
    def forward(
        self,
        tgt: Tensor,
        memory: Optional[Tensor] = None,
        cache: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        causal_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Args:
            tgt (Tensor): current_len_output x bsz x hidden_dim
            memory (Tensor): len_encoded_seq x bsz x hidden_dim
            cache (Optional[Tensor]):
                n_layers x (current_len_output - 1) x bsz x hidden_dim
                If current_len_output == 1, nothing is cached yet, so cache
                should be None. Same if the module is in training mode.
            others (Optional[Tensor]): see official documentations
        Returns:
            output (Tensor): current_len_output x bsz x hidden_dim
            cache (Optional[Tensor]): n_layers x current_len_output x bsz x hidden_dim
                Only returns it when module is in eval mode (no caching in training)
        """
        output = tgt
        if self.training:
            # Training: teacher forcing — run the full sequence through
            # every layer; no incremental cache is needed or allowed.
            if cache is not None:
                raise ValueError(
                    "cache parameter should be None in training mode")
            for mod in self.layers:
                output = mod(
                    output,
                    memory,
                    memory_mask=memory_mask,
                    tgt_key_padding_mask=tgt_key_padding_mask,
                    memory_key_padding_mask=memory_key_padding_mask,
                    causal_mask=causal_mask,
                    only_last=False,
                )
            # cache is always None here; returned for a uniform signature.
            return output, cache
        else:
            # Eval: incremental decoding. Each layer computes only the
            # newest token's embedding (only_last) when a cache exists.
            new_token_cache = []
            for i, mod in enumerate(self.layers):
                output = mod(output, memory,
                             memory_mask=memory_mask,
                             tgt_key_padding_mask=tgt_key_padding_mask,
                             memory_key_padding_mask=memory_key_padding_mask,
                             causal_mask=causal_mask,
                             only_last=True if cache is not None else False)
                new_token_cache.append(output)
                # use the pre_calculated intermediate parameters: prepend
                # the cached prefix so the next layer attends over the
                # full sequence.
                if cache is not None:
                    output = torch.cat([cache[i], output], dim=0)
            # Grow the cache along the sequence dimension (dim=1), or
            # create it on the first step.
            if cache is not None:
                new_cache = torch.cat(
                    [cache, torch.stack(new_token_cache, dim=0)], dim=1)
            else:
                new_cache = torch.stack(new_token_cache, dim=0)
            return output, new_cache
class CausalTransformerDecoderLayer(nn.TransformerDecoderLayer): class CausalTransformerDecoderLayer(nn.TransformerDecoderLayer):
    def __init__(self, *args, re_zero=True, norm_first=True, map_attn_cfg=None, **kwargs):
        '''Decoder layer with ReZero-style residual scaling.

        Args:
            re_zero: If True, each residual branch is scaled by a learnable
                alpha initialized at zero (ReZero); otherwise the scales are
                the constant 1.0.
            norm_first: Must be True (pre-norm layout); enforced in forward.
            map_attn_cfg: Optional mmcv attention config; when given, an
                extra map-attention module is built.
            *args, **kwargs: Forwarded to ``nn.TransformerDecoderLayer``.
        '''
        super(CausalTransformerDecoderLayer, self).__init__(*args, **kwargs)
        if re_zero:
            # One learnable zero-initialized scale per residual branch
            # (self-attn, cross-attn, feed-forward).
            self.res_weight1 = nn.Parameter(torch.FloatTensor([0, ]))
            self.res_weight2 = nn.Parameter(torch.FloatTensor([0, ]))
            self.res_weight3 = nn.Parameter(torch.FloatTensor([0, ]))
        else:
            self.res_weight1 = 1.
            self.res_weight2 = 1.
            self.res_weight3 = 1.
        self.norm_first = norm_first
        self.map_attn = None
        if map_attn_cfg is not None:
            self.map_attn = build_attention(map_attn_cfg)
def forward( def forward(
self, self,
tgt: Tensor, tgt: Tensor,
memory: Optional[Tensor] = None, memory: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None,
causal_mask: Optional[Tensor] = None, causal_mask: Optional[Tensor] = None,
query: Optional[Tensor] = None, query: Optional[Tensor] = None,
only_last=False) -> Tensor: only_last=False) -> Tensor:
""" """
Args: Args:
see CausalTransformerDecoder see CausalTransformerDecoder
query is not None model will perform query stream query is not None model will perform query stream
Returns: Returns:
Tensor: Tensor:
If training: embedding of the whole layer: seq_len x bsz x hidden_dim If training: embedding of the whole layer: seq_len x bsz x hidden_dim
If eval mode: embedding of last token: 1 x bsz x hidden_dim If eval mode: embedding of last token: 1 x bsz x hidden_dim
""" """
if not self.norm_first: if not self.norm_first:
raise ValueError( raise ValueError(
"norm_first parameter should be True!") "norm_first parameter should be True!")
if self.training: if self.training:
# the official Pytorch implementation # the official Pytorch implementation
x = tgt x = tgt
if query is not None: if query is not None:
x = query x = query
x = x + self.res_weight1 * \ x = x + self.res_weight1 * \
self._sa_block(self.norm1(x), self.norm1(tgt), causal_mask, self._sa_block(self.norm1(x), self.norm1(tgt), causal_mask,
tgt_key_padding_mask) tgt_key_padding_mask)
if memory is not None: if memory is not None:
x = x + self.res_weight2 * \ x = x + self.res_weight2 * \
self._mha_block(self.norm2(x), memory, self._mha_block(self.norm2(x), memory,
memory_mask, memory_key_padding_mask) memory_mask, memory_key_padding_mask)
x = x + self.res_weight3*self._ff_block(self.norm3(x)) x = x + self.res_weight3*self._ff_block(self.norm3(x))
return x return x
# This part is adapted from the official Pytorch implementation # This part is adapted from the official Pytorch implementation
# So that only the last token gets modified and returned. # So that only the last token gets modified and returned.
# we follow the pre-LN trans in https://arxiv.org/pdf/2002.04745v1.pdf . # we follow the pre-LN trans in https://arxiv.org/pdf/2002.04745v1.pdf .
x = tgt x = tgt
if query is not None: if query is not None:
x = query x = query
if only_last: if only_last:
x = x[-1:] x = x[-1:]
if causal_mask is not None: if causal_mask is not None:
attn_mask = causal_mask attn_mask = causal_mask
if only_last: if only_last:
attn_mask = attn_mask[-1:] # XXX attn_mask = attn_mask[-1:] # XXX
else: else:
attn_mask = None attn_mask = None
# efficient self attention # efficient self attention
x = x + self.res_weight1 * \ x = x + self.res_weight1 * \
self._sa_block(self.norm1(x), self.norm1(tgt), attn_mask, self._sa_block(self.norm1(x), self.norm1(tgt), attn_mask,
tgt_key_padding_mask) tgt_key_padding_mask)
# encoder-decoder attention # encoder-decoder attention
if memory is not None: if memory is not None:
x = x + self.res_weight2 * \ x = x + self.res_weight2 * \
self._mha_block(self.norm2(x), memory, self._mha_block(self.norm2(x), memory,
memory_mask, memory_key_padding_mask) memory_mask, memory_key_padding_mask)
# final feed-forward network # final feed-forward network
x = x + self.res_weight3*self._ff_block(self.norm3(x)) x = x + self.res_weight3*self._ff_block(self.norm3(x))
return x return x
# self-attention block # self-attention block
def _sa_block(self, x: Tensor, mem: Tensor, def _sa_block(self, x: Tensor, mem: Tensor,
attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor: attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor:
x = self.self_attn(x, mem, mem, x = self.self_attn(x, mem, mem,
attn_mask=attn_mask, attn_mask=attn_mask,
key_padding_mask=key_padding_mask, key_padding_mask=key_padding_mask,
need_weights=False)[0] need_weights=False)[0]
return self.dropout1(x) return self.dropout1(x)
# multihead attention block # multihead attention block
def _mha_block(self, x: Tensor, mem: Tensor, def _mha_block(self, x: Tensor, mem: Tensor,
attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor: attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor:
x = self.multihead_attn(x, mem, mem, x = self.multihead_attn(x, mem, mem,
attn_mask=attn_mask, attn_mask=attn_mask,
key_padding_mask=key_padding_mask, key_padding_mask=key_padding_mask,
need_weights=False)[0] need_weights=False)[0]
return self.dropout2(x) return self.dropout2(x)
# feed forward block # feed forward block
def _ff_block(self, x: Tensor) -> Tensor: def _ff_block(self, x: Tensor) -> Tensor:
x = self.linear2(self.dropout(self.activation(self.linear1(x)))) x = self.linear2(self.dropout(self.activation(self.linear1(x))))
return self.dropout3(x) return self.dropout3(x)
class PolygenTransformerEncoderLayer(nn.TransformerEncoderLayer): class PolygenTransformerEncoderLayer(nn.TransformerEncoderLayer):
def __init__(self, *args, re_zero=True, norm_first=True, **kwargs): def __init__(self, *args, re_zero=True, norm_first=True, **kwargs):
''' '''
Args: Args:
re_zero: If True, alpha scale residuals with zero init. re_zero: If True, alpha scale residuals with zero init.
''' '''
super(PolygenTransformerEncoderLayer, self).__init__(*args, **kwargs) super(PolygenTransformerEncoderLayer, self).__init__(*args, **kwargs)
if re_zero: if re_zero:
self.res_weight1 = nn.Parameter(torch.FloatTensor([0, ])) self.res_weight1 = nn.Parameter(torch.FloatTensor([0, ]))
self.res_weight2 = nn.Parameter(torch.FloatTensor([0, ])) self.res_weight2 = nn.Parameter(torch.FloatTensor([0, ]))
else: else:
self.res_weight1 = 1. self.res_weight1 = 1.
self.res_weight2 = 1. self.res_weight2 = 1.
self.norm_first = norm_first self.norm_first = norm_first
def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor: def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
r"""Pass the input through the encoder layer. r"""Pass the input through the encoder layer.
Args: Args:
src: the sequence to the encoder layer (required). src: the sequence to the encoder layer (required).
src_mask: the mask for the src sequence (optional). src_mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional). src_key_padding_mask: the mask for the src keys per batch (optional).
Shape: Shape:
see the docs in Transformer class. see the docs in Transformer class.
""" """
# see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
x = src x = src
if self.norm_first: if self.norm_first:
x = x + self.res_weight1*self._sa_block(self.norm1(x), src_mask, x = x + self.res_weight1*self._sa_block(self.norm1(x), src_mask,
src_key_padding_mask) src_key_padding_mask)
x = x + self.res_weight2*self._ff_block(self.norm2(x)) x = x + self.res_weight2*self._ff_block(self.norm2(x))
else: else:
x = self.norm1( x = self.norm1(
x + self.res_weight1*self._sa_block(x, src_mask, src_key_padding_mask)) x + self.res_weight1*self._sa_block(x, src_mask, src_key_padding_mask))
x = self.norm2(x + self.res_weight2*self._ff_block(x)) x = self.norm2(x + self.res_weight2*self._ff_block(x))
return x return x
# self-attention block # self-attention block
def _sa_block(self, x: Tensor, def _sa_block(self, x: Tensor,
attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor: attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor:
x = self.self_attn(x, x, x, x = self.self_attn(x, x, x,
attn_mask=attn_mask, attn_mask=attn_mask,
key_padding_mask=key_padding_mask, key_padding_mask=key_padding_mask,
need_weights=False)[0] need_weights=False)[0]
return self.dropout1(x) return self.dropout1(x)
# feed forward block # feed forward block
def _ff_block(self, x: Tensor) -> Tensor: def _ff_block(self, x: Tensor) -> Tensor:
x = self.linear2(self.dropout(self.activation(self.linear1(x)))) x = self.linear2(self.dropout(self.activation(self.linear1(x))))
return self.dropout2(x) return self.dropout2(x)
def generate_square_subsequent_mask(sz: int, device: str = "cpu") -> torch.Tensor:
    """Build an additive causal attention mask of shape (sz, sz).

    Position i may attend to positions j <= i (value ``0.0``); strictly
    future positions hold ``-inf`` so softmax assigns them zero weight.
    """
    allowed = torch.tril(torch.ones(sz, sz)).bool()
    mask = torch.zeros(sz, sz).masked_fill(~allowed, float("-inf"))
    return mask.to(device=device)
\ No newline at end of file
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import Tensor from torch import Tensor
def generate_square_subsequent_mask(sz: int, condition_len: int = 1, bool_out=False, device: str = "cpu") -> torch.Tensor:
    """Build a causal attention mask whose conditioning prefix attends freely.

    The first ``condition_len`` tokens may attend to each other without the
    causal restriction (used when the sequence starts with a condition);
    ``condition_len <= 1`` leaves the mask purely causal.

    Args:
        sz: sequence length; the mask is (sz, sz).
        condition_len: length of the bidirectional prefix.
        bool_out: if True, return the boolean keep-mask (True = attend);
            otherwise return the additive float mask (0.0 keep / -inf drop).
    """
    keep = torch.tril(torch.ones(sz, sz)).bool()
    if condition_len > 1:
        # The condition block is fully visible to itself.
        keep[:condition_len, :condition_len] = 1
    if bool_out:
        return keep.to(device=device)
    additive = torch.zeros(sz, sz).masked_fill(~keep, float("-inf"))
    return additive.to(device=device)
def dequantize_verts(verts, canvas_size: Tensor, add_noise=False):
    """Map quantized (integer) vertices back to continuous values in [-1, 1].

    Inverse of :func:`quantize_verts`. (The original docstring said
    "Quantizes" — this function dequantizes.)

    Args:
        verts: integer vertex tensor (grid coordinates).
        canvas_size: number of quantization steps per coordinate.
        add_noise: if True, add uniform noise of up to one quantization bin
            so dequantized values spread over the bin instead of sitting on
            its lower edge.

    Returns:
        Float32 tensor of continuous vertex coordinates.
    """
    min_range = -1
    max_range = 1
    # NOTE(review): quantize_verts uses canvas_size - 1 as the step count;
    # confirm the off-by-one between the two directions is intended.
    range_quantize = canvas_size
    verts = verts.type(torch.float32)
    verts = verts * (max_range - min_range) / range_quantize + min_range
    if add_noise:
        # BUG FIX: the noise magnitude must be one quantization bin,
        # (max_range - min_range) / range_quantize. The previous code
        # multiplied by range_quantize itself, producing noise on the order
        # of the whole canvas and destroying the dequantized coordinates.
        verts += torch.rand_like(verts) * \
            (max_range - min_range) / range_quantize
    return verts
def quantize_verts(
        verts,
        canvas_size: Tensor):
    """Map continuous vertices from [-1, 1] onto the integer canvas grid.

    Args:
        verts: float tensor, e.g. shape (seqlen, 2), with values in [-1, 1].
        canvas_size: number of discrete positions per axis; outputs lie in
            [0, canvas_size - 1].

    Returns:
        int32 tensor of grid coordinates (fractional part truncated).
    """
    lo, hi = -1, 1
    n_steps = canvas_size - 1
    scaled = (verts - lo) / (hi - lo) * n_steps
    return scaled.type(torch.int32)
def top_k_logits(logits, k):
    """Push logits outside the top-k down to -1e9 (``k == 0`` disables).

    NOTE(review): the cutoff is the minimum over *all* top-k values in the
    tensor (one global threshold), not a per-row k-th largest value; this
    matches the reference PolyGen implementation.
    """
    if k == 0:
        return logits
    topk_vals = torch.topk(logits, k=k)[0]
    cutoff = topk_vals.min()
    fill = torch.ones_like(logits) * -1e9
    return torch.where(logits < cutoff, fill, logits)
def top_p_logits(logits, p):
    """Masks logits using nucleus (top-p) sampling.

    Candidates are sorted by descending probability; everything outside the
    smallest prefix whose cumulative probability exceeds ``p`` has 1e9
    subtracted so it is effectively never sampled.

    Args:
        logits: tensor of shape [batch, seq, dim].
        p: nucleus probability in (0, 1]; ``p == 1`` disables masking.

    Returns:
        Tensor reshaped to [-1, seq, dim] with out-of-nucleus logits pushed
        to approximately -1e9.
    """
    if p == 1:
        return logits
    else:
        seq, dim = logits.shape[1:]
        logits = logits.view(-1, dim)
        sort_indices = torch.argsort(logits, dim=-1, descending=True)
        probs = F.softmax(logits, dim=-1).gather(-1, sort_indices)
        # Exclusive cumulative sum: the top-1 candidate always has
        # cumulative probability 0 here, so at least one index survives.
        cumprobs = torch.cumsum(probs, dim=-1) - probs
        # The top 1 candidate always will not be masked.
        # This way ensures at least 1 indices will be selected.
        sort_mask = (cumprobs > p).type(logits.dtype)
        # NOTE(review): batch_indices is computed but never used — dead code.
        batch_indices = torch.repeat_interleave(
            torch.arange(logits.shape[0]).unsqueeze(-1), dim, dim=-1)
        # Scatter the sorted-order drop mask back to vocabulary order.
        top_p_mask = torch.zeros_like(logits)
        top_p_mask = top_p_mask.scatter_add(-1, sort_indices, sort_mask)
        # NOTE(review): in-place subtraction on a view of the input — this
        # also mutates the caller's tensor; confirm that is intended.
        logits -= top_p_mask * 1e9
        return logits.view(-1, seq, dim)
import copy import copy
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from mmcv.cnn import Conv2d, Linear from mmcv.cnn import Conv2d, Linear
from mmcv.runner import force_fp32 from mmcv.runner import force_fp32
from torch.distributions.categorical import Categorical from torch.distributions.categorical import Categorical
from mmdet.core import multi_apply, reduce_mean from mmdet.core import multi_apply, reduce_mean
from mmdet.models import HEADS from mmdet.models import HEADS
from .detr_head import DETRMapFixedNumHead from .detr_head import DETRMapFixedNumHead
@HEADS.register_module(force=True) @HEADS.register_module(force=True)
class DETRBboxHead(DETRMapFixedNumHead): class DETRBboxHead(DETRMapFixedNumHead):
    def __init__(self, *args, canvas_size=(400, 200), discrete_output=True, separate_detect=True,
                 mode='xyxy', bbox_size=None, coord_dim=2, kp_coord_dim=2,
                 **kwargs):
        """DETR-style bounding-box head.

        Args:
            canvas_size (tuple[int]): extent of the discrete output canvas.
            discrete_output (bool): if True, regress categorical logits over
                canvas positions; otherwise continuous sigmoid coordinates.
            separate_detect (bool): if True, run with per-instance class-label
                conditioning and a reduced (binary/single-logit) classifier.
            mode (str): box parameterisation; 'sce' uses 3 values per box,
                any other mode 2 (overridden by ``bbox_size`` when given).
            bbox_size (int, optional): explicit number of values per box.
            kp_coord_dim (int): keypoint coordinate dimensionality.
            coord_dim (int): coordinate dimensionality per box value.
        """
        # Plain attribute first: the parent __init__ calls _init_branch,
        # which reads self.canvas_size before the buffer can be registered.
        self.canvas_size = canvas_size  # hard code
        self.separate_detect = separate_detect
        self.discrete_output = discrete_output
        self.bbox_size = 3 if mode=='sce' else 2
        if bbox_size is not None:
            self.bbox_size = bbox_size
        self.coord_dim = coord_dim  # for xyz
        self.kp_coord_dim = kp_coord_dim
        super(DETRBboxHead, self).__init__(*args, **kwargs)
        # Re-register canvas_size as a buffer so it follows the module's
        # device moves and is saved in the state dict.
        del self.canvas_size
        self.register_buffer('canvas_size', torch.tensor(canvas_size))
        self._init_embedding()
def _init_embedding(self): def _init_embedding(self):
# for bbox parameter xstart, ystart, xend, yend # for bbox parameter xstart, ystart, xend, yend
self.bbox_embedding = nn.Embedding(4, self.embed_dims) self.bbox_embedding = nn.Embedding(4, self.embed_dims)
self.label_embed = nn.Embedding( self.label_embed = nn.Embedding(
self.num_classes, self.embed_dims) self.num_classes, self.embed_dims)
self.img_coord_embed = nn.Linear(2, self.embed_dims) self.img_coord_embed = nn.Linear(2, self.embed_dims)
def _init_branch(self,): def _init_branch(self,):
"""Initialize classification branch and regression branch of head.""" """Initialize classification branch and regression branch of head."""
# add sigmoid or not # add sigmoid or not
if self.separate_detect: if self.separate_detect:
if self.cls_out_channels == self.num_classes+1: if self.cls_out_channels == self.num_classes+1:
self.cls_out_channels = 2 self.cls_out_channels = 2
else: else:
self.cls_out_channels = 1 self.cls_out_channels = 1
fc_cls = Linear(self.embed_dims, self.cls_out_channels) fc_cls = Linear(self.embed_dims, self.cls_out_channels)
reg_branch = [] reg_branch = []
for _ in range(self.num_reg_fcs): for _ in range(self.num_reg_fcs):
reg_branch.append(Linear(self.embed_dims, self.embed_dims)) reg_branch.append(Linear(self.embed_dims, self.embed_dims))
reg_branch.append(nn.LayerNorm(self.embed_dims)) reg_branch.append(nn.LayerNorm(self.embed_dims))
reg_branch.append(nn.ReLU()) reg_branch.append(nn.ReLU())
if self.discrete_output: if self.discrete_output:
reg_branch.append(nn.Linear( reg_branch.append(nn.Linear(
self.embed_dims, max(self.canvas_size), bias=True,)) self.embed_dims, max(self.canvas_size), bias=True,))
else: else:
reg_branch.append(nn.Linear( reg_branch.append(nn.Linear(
self.embed_dims, self.bbox_size*self.coord_dim, bias=True,)) self.embed_dims, self.bbox_size*self.coord_dim, bias=True,))
reg_branch = nn.Sequential(*reg_branch) reg_branch = nn.Sequential(*reg_branch)
def _get_clones(module, N): def _get_clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
num_pred = self.transformer.decoder.num_layers num_pred = self.transformer.decoder.num_layers
if self.iterative: if self.iterative:
fc_cls = _get_clones(fc_cls, num_pred) fc_cls = _get_clones(fc_cls, num_pred)
reg_branch = _get_clones(reg_branch, num_pred) reg_branch = _get_clones(reg_branch, num_pred)
self.pre_branches = nn.ModuleDict([ self.pre_branches = nn.ModuleDict([
('cls', fc_cls), ('cls', fc_cls),
('reg', reg_branch), ]) ('reg', reg_branch), ])
    def _prepare_context(self, batch, context):
        """Prepare class label and vertex context.

        Builds (a) an optional per-instance class-label embedding and
        (b) the projected BEV feature map augmented with a dense 2-D
        coordinate embedding.

        Returns:
            tuple: (global_context_embedding or None,
                    sequential_context_embeddings of shape [B, C, H, W]).
        """
        global_context_embedding = None
        if self.separate_detect:
            global_context_embedding = self.label_embed(batch['class_label'])
        # Image context: per-instance BEV assignment when separate_detect,
        # otherwise the whole batch feature map.
        if self.separate_detect:
            # assign_bev is defined elsewhere in this module.
            image_embeddings = assign_bev(
                context['bev_embeddings'], batch['batch_idx'])
        else:
            image_embeddings = context['bev_embeddings']
        image_embeddings = self.input_proj(
            image_embeddings)  # only change feature size
        device = image_embeddings.device
        # Add a 2D coordinate grid embedding spanning [-1, 1] on both axes.
        B, C, H, W = image_embeddings.shape
        Ws = torch.linspace(-1., 1., W)
        Hs = torch.linspace(-1., 1., H)
        image_coords = torch.stack(
            torch.meshgrid(Hs, Ws), dim=-1).to(device)
        image_coord_embeddings = self.img_coord_embed(image_coords)
        # In-place add; (1, C, H, W) broadcasts over the batch.
        image_embeddings += image_coord_embeddings[None].permute(0, 3, 1, 2)
        # NOTE(review): this reshape keeps (B, C, H, W) unchanged; the
        # flatten-to-sequence presumably happens downstream — confirm.
        sequential_context_embeddings = image_embeddings.reshape(
            B, C, H, W)
        return (global_context_embedding, sequential_context_embeddings)
def forward(self, batch, context, img_metas=None): def forward(self, batch, context, img_metas=None):
''' '''
Args: Args:
bev_feature (List[Tensor]): shape [B, C, H, W] bev_feature (List[Tensor]): shape [B, C, H, W]
feature in bev view feature in bev view
img_metas img_metas
Outs: Outs:
preds_dict (Dict): preds_dict (Dict):
all_cls_scores (Tensor): Classification score of all all_cls_scores (Tensor): Classification score of all
decoder layers, has shape decoder layers, has shape
[nb_dec, bs, num_query, cls_out_channels]. [nb_dec, bs, num_query, cls_out_channels].
all_lines_preds (Tensor): all_lines_preds (Tensor):
[nb_dec, bs, num_query, num_points, 2]. [nb_dec, bs, num_query, num_points, 2].
''' '''
(global_context_embedding, sequential_context_embeddings) =\ (global_context_embedding, sequential_context_embeddings) =\
self._prepare_context(batch, context) self._prepare_context(batch, context)
if self.separate_detect: if self.separate_detect:
query_embedding = self.query_embedding.weight[None] + \ query_embedding = self.query_embedding.weight[None] + \
global_context_embedding[:, None] global_context_embedding[:, None]
else: else:
B = sequential_context_embeddings.shape[0] B = sequential_context_embeddings.shape[0]
query_embedding = self.query_embedding.weight[None].repeat(B, 1, 1) query_embedding = self.query_embedding.weight[None].repeat(B, 1, 1)
x = sequential_context_embeddings x = sequential_context_embeddings
B, C, H, W = x.shape B, C, H, W = x.shape
masks = x.new_zeros((B, H, W)) masks = x.new_zeros((B, H, W))
pos_embed = self.positional_encoding(masks) pos_embed = self.positional_encoding(masks)
# outs_dec: [nb_dec, bs, num_query, embed_dim] # outs_dec: [nb_dec, bs, num_query, embed_dim]
outs_dec, _ = self.transformer(x, masks.type(torch.bool), query_embedding, outs_dec, _ = self.transformer(x, masks.type(torch.bool), query_embedding,
pos_embed) pos_embed)
outputs = [] outputs = []
for i, query_feat in enumerate(outs_dec): for i, query_feat in enumerate(outs_dec):
outputs.append(self.get_prediction(query_feat)) outputs.append(self.get_prediction(query_feat))
return outputs return outputs
def get_prediction(self, query_feat): def get_prediction(self, query_feat):
ocls = self.pre_branches['cls'](query_feat) ocls = self.pre_branches['cls'](query_feat)
if self.discrete_output: if self.discrete_output:
pos = [] pos = []
for i in range(4): for i in range(4):
pos_embeds = self.bbox_embedding.weight[i] pos_embeds = self.bbox_embedding.weight[i]
_pos = self.pre_branches['reg'](query_feat+pos_embeds) _pos = self.pre_branches['reg'](query_feat+pos_embeds)
pos.append(_pos) pos.append(_pos)
# # y mask # # y mask
# _vert_mask = torch.arange(logits.shape[-1], device=logits.device) # _vert_mask = torch.arange(logits.shape[-1], device=logits.device)
# vertices_mask_y = (_vert_mask < self.canvas_size[1]+1) # vertices_mask_y = (_vert_mask < self.canvas_size[1]+1)
# logits[:,1::2] = logits[:,1::2]*vertices_mask_y - ~vertices_mask_y*1e9 # logits[:,1::2] = logits[:,1::2]*vertices_mask_y - ~vertices_mask_y*1e9
logits = torch.stack(pos, dim=-2)/1. logits = torch.stack(pos, dim=-2)/1.
lines = Categorical(logits=logits) lines = Categorical(logits=logits)
else: else:
lines = self.pre_branches['reg'](query_feat).sigmoid() lines = self.pre_branches['reg'](query_feat).sigmoid()
lines = lines.unflatten(-1, (self.bbox_size, self.coord_dim))*self.canvas_size lines = lines.unflatten(-1, (self.bbox_size, self.coord_dim))*self.canvas_size
lines = lines.flatten(-2) lines = lines.flatten(-2)
return dict( return dict(
lines=lines, # [bs, num_query, 4, num_canvas_size] lines=lines, # [bs, num_query, 4, num_canvas_size]
scores=ocls, # [bs, num_query, num_class] scores=ocls, # [bs, num_query, num_class]
) )
    @force_fp32(apply_to=('score_pred', 'lines_pred', 'gt_lines'))
    def _get_target_single(self,
                           score_pred,
                           lines_pred,
                           gt_labels,
                           gt_lines,
                           gt_bboxes_ignore=None):
        """
        Compute regression and classification targets for one image.
        Outputs from a single decoder layer of a single feature level are used.
        Args:
            score_pred (Tensor): Box score logits from a single decoder layer
                for one image. Shape [num_query, cls_out_channels].
            lines_pred (Tensor):
                shape [num_query, num_points, 2] (or with a trailing
                vocabulary dim when ``discrete_output``).
            gt_lines (Tensor):
                shape [num_gt, num_points, 2].
            gt_labels (torch.LongTensor)
                shape [num_gt, ]
        Returns:
            tuple[Tensor]: a tuple containing the following for one image.
                - labels (LongTensor): Labels of each image.
                    shape [num_query, 1]
                - label_weights (Tensor]): Label weights of each image.
                    shape [num_query, 1]
                - lines_target (Tensor): Lines targets of each image.
                    shape [num_query, num_points, 2]
                - lines_weights (Tensor): Lines weights of each image.
                    shape [num_query, num_points, 2]
                - pos_inds (Tensor): Sampled positive indices for each image.
                - neg_inds (Tensor): Sampled negative indices for each image.
                - pos_gt_inds (Tensor): GT index matched to each positive.
        """
        num_pred_lines = len(lines_pred)
        # Hungarian-style assignment between predictions and ground truth,
        # then sampling of positive/negative query indices.
        assign_result = self.assigner.assign(preds=dict(lines=lines_pred, scores=score_pred,),
                                             gts=dict(lines=gt_lines,
                                                      labels=gt_labels, ),
                                             gt_bboxes_ignore=gt_bboxes_ignore)
        sampling_result = self.sampler.sample(
            assign_result, lines_pred, gt_lines)
        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds
        pos_gt_inds = sampling_result.pos_assigned_gt_inds
        # label targets 0: foreground, 1: background (separate_detect mode);
        # otherwise the background class index is self.num_classes.
        if self.separate_detect:
            labels = gt_lines.new_full((num_pred_lines, ), 1, dtype=torch.long)
        else:
            labels = gt_lines.new_full(
                (num_pred_lines, ), self.num_classes, dtype=torch.long)
        labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
        label_weights = gt_lines.new_ones(num_pred_lines)
        # bbox targets since lines_pred's last dimension is the vocabulary
        # and ground truth dose not have this dimension: drop it so the
        # target shape matches the ground truth (integer class targets).
        if self.discrete_output:
            lines_target = torch.zeros_like(lines_pred[..., 0]).long()
            lines_weights = torch.zeros_like(lines_pred[..., 0])
        else:
            lines_target = torch.zeros_like(lines_pred)
            lines_weights = torch.zeros_like(lines_pred)
        lines_target[pos_inds] = sampling_result.pos_gt_bboxes.type(
            lines_target.dtype)
        lines_weights[pos_inds] = 1.0
        # Normalise weights per query so each matched line contributes
        # equally regardless of its number of elements (guard n == 0).
        n = lines_weights.sum(-1, keepdim=True)
        lines_weights = lines_weights / n.masked_fill(n == 0, 1)
        return (labels, label_weights, lines_target, lines_weights,
                pos_inds, neg_inds, pos_gt_inds)
# @force_fp32(apply_to=('preds', 'gts')) # @force_fp32(apply_to=('preds', 'gts'))
    def get_targets(self, preds, gts, gt_bboxes_ignore_list=None):
        """
        Compute regression and classification targets for a batch of images.
        Outputs from a single decoder layer of a single feature level are used.
        Args:
            preds (dict): 'scores' per image and 'lines' (a Categorical when
                ``discrete_output``, otherwise a Tensor).
            gts (dict): ground-truth 'bbox' / 'class_label' (and 'bbox_mask'
                in separate_detect mode).
            gt_bboxes_ignore_list (list[Tensor], optional): Bounding
                boxes which can be ignored for each image. Must be None.
        Returns:
            tuple:
                - new_gts (dict): batched 'labels', 'label_weights',
                  'bboxs', 'bboxs_weights' lists (one entry per image).
                - num_total_pos (int): Number of positive samples in all
                  images.
                - num_total_neg (int): Number of negative samples in all
                  images.
                - pos_inds_list (list[Tensor]): positive query indices.
                - pos_gt_inds_list (list[Tensor]): matched gt indices.
        """
        assert gt_bboxes_ignore_list is None, \
            'Only supports for gt_bboxes_ignore setting to None.'
        # format the inputs
        if self.separate_detect:
            # Keep only valid (unmasked) boxes; in this mode every instance
            # gets foreground class 0.
            bbox = [b[m] for b, m in zip(gts['bbox'], gts['bbox_mask'])]
            class_label = torch.zeros_like(gts['bbox_mask']).long()
            class_label = [b[m] for b, m in zip(class_label, gts['bbox_mask'])]
        else:
            class_label = gts['class_label']
            bbox = gts['bbox']
        if self.discrete_output:
            # Use the raw categorical logits for assignment.
            lines_pred = preds['lines'].logits
        else:
            lines_pred = preds['lines']
        bbox = [b.float() for b in bbox]
        # Per-image target computation.
        (labels_list, label_weights_list,
         lines_targets_list, lines_weights_list,
         pos_inds_list, neg_inds_list,pos_gt_inds_list) = multi_apply(
            self._get_target_single,
            preds['scores'], lines_pred,
            class_label, bbox,
            gt_bboxes_ignore=gt_bboxes_ignore_list)
        num_total_pos = sum((inds.numel() for inds in pos_inds_list))
        num_total_neg = sum((inds.numel() for inds in neg_inds_list))
        new_gts = dict(
            labels=labels_list,
            label_weights=label_weights_list,
            bboxs=lines_targets_list,
            bboxs_weights=lines_weights_list,
        )
        return new_gts, num_total_pos, num_total_neg, pos_inds_list, pos_gt_inds_list
    # @force_fp32(apply_to=('preds', 'gts'))
    def loss_single(self,
                    preds: dict,
                    gts: dict,
                    gt_bboxes_ignore_list=None,
                    reduction='none'):
        """
        Compute classification and regression losses for the outputs of a
        single decoder layer.

        Args:
            preds (dict): predictions for the whole batch.
                - 'scores' (Tensor): classification logits,
                  shape [bs, num_query, cls_out_channels].
                - 'lines': line predictions. When ``self.discrete_output``
                  is True this is a distribution object exposing ``.logits``
                  and ``.log_prob``; otherwise a plain Tensor.
            gts (dict): ground truth consumed by ``self.get_targets``. When
                ``self.separate_detect`` is True, 'bbox' and 'bbox_mask' are
                used and all class labels are set to 0 (binary objectness);
                otherwise 'class_label' and 'bbox' are used directly.
            gt_bboxes_ignore_list (list[Tensor], optional): must be None.
            reduction (str): NOTE(review) unused in this method — confirm
                whether it was meant to be forwarded to ``self.reg_loss``.

        Returns:
            tuple: (loss_dict with keys 'cls' and 'reg', pos_inds_list,
            pos_gt_inds_list).
        """
        # Get target for each sample
        new_gts, num_total_pos, num_total_neg, pos_inds_list, pos_gt_inds_list =\
            self.get_targets(preds, gts, gt_bboxes_ignore_list)
        # Batched all data: stack per-image targets into [bs, ...] tensors.
        for k, v in new_gts.items():
            new_gts[k] = torch.stack(v, dim=0)
        # construct weighted avg_factor to match with the official DETR repo
        cls_avg_factor = num_total_pos * 1.0 + \
            num_total_neg * self.bg_cls_weight
        if self.sync_cls_avg_factor:
            # Average the normalizer across GPUs so every rank scales alike.
            cls_avg_factor = reduce_mean(
                preds['scores'].new_tensor([cls_avg_factor]))
        cls_avg_factor = max(cls_avg_factor, 1)
        # Classification loss
        if self.separate_detect:
            # Binary objectness: label 0 marks positives (see get_targets).
            loss_cls = self.bce_loss(
                preds['scores'], new_gts['labels'], new_gts['label_weights'], cls_avg_factor)
        else:
            # since the inputs needs the second dim is the class dim, we permute the prediction.
            cls_scores = preds['scores'].reshape(-1, self.cls_out_channels)
            cls_labels = new_gts['labels'].reshape(-1)
            cls_weights = new_gts['label_weights'].reshape(-1)
            loss_cls = self.loss_cls(
                cls_scores, cls_labels, cls_weights, avg_factor=cls_avg_factor)
        # Compute the average number of gt boxes accross all gpus, for
        # normalization purposes
        num_total_pos = loss_cls.new_tensor([num_total_pos])
        num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1).item()
        # position NLL loss
        if self.discrete_output:
            # Negative log-likelihood of target coordinates under the
            # predicted discrete distribution, masked by the target weights.
            loss_reg = -(preds['lines'].log_prob(new_gts['bboxs']) *
                         new_gts['bboxs_weights']).sum()/(num_total_pos)
        else:
            loss_reg = self.reg_loss(
                preds['lines'], new_gts['bboxs'], new_gts['bboxs_weights'], avg_factor=num_total_pos)
        loss_dict = dict(
            cls=loss_cls,
            reg=loss_reg,
        )
        return loss_dict, pos_inds_list, pos_gt_inds_list
def bce_loss(self, logits, label, weights, cls_avg_factor): def bce_loss(self, logits, label, weights, cls_avg_factor):
''' binary ce plog(p) + (1-p)log(1-p) ''' binary ce plog(p) + (1-p)log(1-p)
logits: B,n,1 logits: B,n,1
label: label:
''' '''
p = logits.squeeze(-1).sigmoid() p = logits.squeeze(-1).sigmoid()
pos_msk = label == 0 pos_msk = label == 0
neg_msk = ~pos_msk neg_msk = ~pos_msk
loss_cls = -(p.log()*pos_msk + (1-p).log()*neg_msk) loss_cls = -(p.log()*pos_msk + (1-p).log()*neg_msk)
loss_cls = (loss_cls * weights).sum()/cls_avg_factor loss_cls = (loss_cls * weights).sum()/cls_avg_factor
return loss_cls return loss_cls
    def post_process(self, preds_dicts: list, **kwargs):
        '''
        Select the top-scoring lines from the last decoder layer.

        Args:
            preds_dicts (list[dict]): per-decoder-layer predictions; only
                the last entry is used. Each dict holds:
                scores (Tensor): classification logits,
                    shape [bs, num_query, cls_out_channels].
                lines (Tensor): line predictions,
                    shape [bs, num_query, ...].
        Outs:
            result_dict (dict):
                'bbox' / 'scores' / 'labels': per-sample lists of kept
                    predictions (top ``self.max_lines`` each).
                'bbox_flat': all kept lines concatenated over the batch and
                    rounded to int32 (assumes coordinates are already in a
                    discrete/pixel space — TODO confirm).
                'lines_cls': concatenated class ids (long).
                'lines_bs_idx': batch index of each kept line (long).
        '''
        preds = preds_dicts[-1]
        batched_cls_scores = preds['scores']
        batched_lines_preds = preds['lines']
        batch_size = batched_cls_scores.size(0)
        device = batched_cls_scores.device
        result_dict = {
            'bbox': [],
            'scores': [],
            'labels': [],
            'bbox_flat': [],
            'lines_cls': [],
            'lines_bs_idx': [],
        }
        for i in range(batch_size):
            cls_scores = batched_cls_scores[i]
            det_preds = batched_lines_preds[i]
            max_num = self.max_lines
            if self.loss_cls.use_sigmoid:
                # Sigmoid head: top-k over the flattened (query, class)
                # scores, then recover class and query indices.
                cls_scores = cls_scores.sigmoid()
                scores, valid_idx = cls_scores.view(-1).topk(max_num)
                det_labels = valid_idx % self.num_classes
                valid_idx = valid_idx // self.num_classes
                det_preds = det_preds[valid_idx]
            else:
                # Softmax head: drop the background column ([..., :-1]),
                # take the best class per query, then keep top-k queries.
                scores, det_labels = F.softmax(cls_scores, dim=-1)[..., :-1].max(-1)
                scores, valid_idx = scores.topk(max_num)
                det_preds = det_preds[valid_idx]
                det_labels = det_labels[valid_idx]
            nline = len(valid_idx)
            result_dict['bbox'].append(det_preds)
            result_dict['scores'].append(scores)
            result_dict['labels'].append(det_labels)
            result_dict['lines_bs_idx'].extend([i]*nline)
        # for down stream polyline
        _bboxs = torch.cat(result_dict['bbox'], dim=0)
        # quantize the data
        result_dict['bbox_flat'] = torch.round(_bboxs).type(torch.int32)
        result_dict['lines_cls'] = torch.cat(
            result_dict['labels'], dim=0).long()
        result_dict['lines_bs_idx'] = torch.tensor(
            result_dict['lines_bs_idx'], device=device).long()
        return result_dict
def assign_bev(feat, idx):
    """Select entries of a BEV feature container by index."""
    selected = feat[idx]
    return selected
\ No newline at end of file
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import copy import copy
from mmdet.models import HEADS from mmdet.models import HEADS
from mmcv.cnn import Conv2d from mmcv.cnn import Conv2d
from mmcv.cnn import Linear, build_activation_layer, bias_init_with_prob from mmcv.cnn import Linear, build_activation_layer, bias_init_with_prob
from mmcv.cnn.bricks.transformer import build_positional_encoding from mmcv.cnn.bricks.transformer import build_positional_encoding
from mmdet.models.utils import build_transformer from mmdet.models.utils import build_transformer
from mmcv.runner import force_fp32 from mmcv.runner import force_fp32
from mmdet.core import (multi_apply, build_assigner, build_sampler, from mmdet.core import (multi_apply, build_assigner, build_sampler,
reduce_mean) reduce_mean)
from mmdet.models.utils.transformer import inverse_sigmoid from mmdet.models.utils.transformer import inverse_sigmoid
from mmdet.models import build_loss from mmdet.models import build_loss
from .base_map_head import BaseMapHead from .base_map_head import BaseMapHead
@HEADS.register_module() @HEADS.register_module()
class DETRMapFixedNumHead(BaseMapHead): class DETRMapFixedNumHead(BaseMapHead):
    def __init__(self,
                 num_classes=3,
                 in_channels=128,
                 num_query=100,
                 max_lines=50,
                 score_thre=0.2,
                 num_reg_fcs=2,
                 num_points=100,
                 iterative=False,
                 patch_size=None,
                 sync_cls_avg_factor=True,
                 transformer: dict = None,
                 positional_encoding: dict = None,
                 loss_cls: dict = None,
                 loss_reg: dict = None,
                 train_cfg: dict = None,
                 init_cfg=None,
                 **kwargs):
        """DETR-style map head predicting a fixed number of polylines.

        Args:
            num_classes (int): number of foreground line classes.
            in_channels (int): channels of the input BEV feature map.
            num_query (int): number of object queries.
            max_lines (int): top-k lines kept in post-processing.
            score_thre (float): score threshold (stored; not used in the
                methods visible in this file section).
            num_reg_fcs (int): hidden Linear layers in the regression branch.
            num_points (int): points predicted per line.
            iterative (bool): if True, every decoder layer gets its own
                cls/reg head clone instead of sharing one.
            patch_size (tuple, optional): BEV patch extent; stored swapped
                as (patch_size[1], patch_size[0]) — presumably converting an
                (h, w) input to (x, y), TODO confirm.
            sync_cls_avg_factor (bool): average the cls-loss normalizer
                across GPUs.
            transformer / positional_encoding / loss_cls / loss_reg /
                train_cfg (dict): mmdet-style config dicts.
            init_cfg: accepted for config compatibility; unused here.
        """
        super().__init__()
        assigner = train_cfg['assigner']
        self.assigner = build_assigner(assigner)
        # DETR sampling=False, so use PseudoSampler
        sampler_cfg = dict(type='PseudoSampler')
        self.sampler = build_sampler(sampler_cfg, context=self)
        self.train_cfg = train_cfg
        self.max_lines = max_lines
        self.score_thre = score_thre
        self.num_query = num_query
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.num_points = num_points
        # branch
        # if loss_cls.use_sigmoid:
        if loss_cls['use_sigmoid']:
            # Sigmoid head: one logit per foreground class.
            self.cls_out_channels = num_classes
        else:
            # Softmax head needs an extra background column.
            self.cls_out_channels = num_classes+1
        self.iterative = iterative
        self.num_reg_fcs = num_reg_fcs
        if patch_size is not None:
            self.register_buffer('patch_size', torch.tensor(
                (patch_size[1], patch_size[0])),)
        # Must run before _init_layers/_init_branch: it sets self.embed_dims.
        self._build_transformer(transformer, positional_encoding)
        # loss params
        self.loss_cls = build_loss(loss_cls)
        self.bg_cls_weight = 0.1
        if self.loss_cls.use_sigmoid:
            # No explicit background class with sigmoid, so no bg weighting.
            self.bg_cls_weight = 0.0
        self.sync_cls_avg_factor = sync_cls_avg_factor
        self.reg_loss = build_loss(loss_reg)
        # add reg, cls head for each decoder layer
        self._init_layers()
        self._init_branch()
        self.init_weights()
def _init_layers(self): def _init_layers(self):
"""Initialize some layer.""" """Initialize some layer."""
self.input_proj = Conv2d( self.input_proj = Conv2d(
self.in_channels, self.embed_dims, kernel_size=1) self.in_channels, self.embed_dims, kernel_size=1)
# query_pos_embed & query_embed # query_pos_embed & query_embed
self.query_embedding = nn.Embedding(self.num_query, self.query_embedding = nn.Embedding(self.num_query,
self.embed_dims) self.embed_dims)
def _build_transformer(self, transformer, positional_encoding): def _build_transformer(self, transformer, positional_encoding):
# transformer # transformer
self.act_cfg = transformer.get('act_cfg', self.act_cfg = transformer.get('act_cfg',
dict(type='ReLU', inplace=True)) dict(type='ReLU', inplace=True))
self.activate = build_activation_layer(self.act_cfg) self.activate = build_activation_layer(self.act_cfg)
self.positional_encoding = build_positional_encoding( self.positional_encoding = build_positional_encoding(
positional_encoding) positional_encoding)
self.transformer = build_transformer(transformer) self.transformer = build_transformer(transformer)
self.embed_dims = self.transformer.embed_dims self.embed_dims = self.transformer.embed_dims
def _init_branch(self,): def _init_branch(self,):
"""Initialize classification branch and regression branch of head.""" """Initialize classification branch and regression branch of head."""
fc_cls = Linear(self.embed_dims, self.cls_out_channels) fc_cls = Linear(self.embed_dims, self.cls_out_channels)
reg_branch = [] reg_branch = []
for _ in range(self.num_reg_fcs): for _ in range(self.num_reg_fcs):
reg_branch.append(Linear(self.embed_dims, self.embed_dims)) reg_branch.append(Linear(self.embed_dims, self.embed_dims))
reg_branch.append(nn.LayerNorm(self.embed_dims)) reg_branch.append(nn.LayerNorm(self.embed_dims))
reg_branch.append(nn.ReLU()) reg_branch.append(nn.ReLU())
reg_branch.append(Linear(self.embed_dims, self.num_points*2)) reg_branch.append(Linear(self.embed_dims, self.num_points*2))
reg_branch = nn.Sequential(*reg_branch) reg_branch = nn.Sequential(*reg_branch)
# add sigmoid or not # add sigmoid or not
def _get_clones(module, N): def _get_clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
num_pred = self.transformer.decoder.num_layers num_pred = self.transformer.decoder.num_layers
if self.iterative: if self.iterative:
fc_cls = _get_clones(fc_cls, num_pred) fc_cls = _get_clones(fc_cls, num_pred)
reg_branch = _get_clones(reg_branch, num_pred) reg_branch = _get_clones(reg_branch, num_pred)
self.pre_branches = nn.ModuleDict([ self.pre_branches = nn.ModuleDict([
('cls', fc_cls), ('cls', fc_cls),
('reg', reg_branch), ]) ('reg', reg_branch), ])
def init_weights(self): def init_weights(self):
"""Initialize weights of the DeformDETR head.""" """Initialize weights of the DeformDETR head."""
for p in self.input_proj.parameters(): for p in self.input_proj.parameters():
if p.dim() > 1: if p.dim() > 1:
nn.init.xavier_uniform_(p) nn.init.xavier_uniform_(p)
self.transformer.init_weights() self.transformer.init_weights()
# init prediction branch # init prediction branch
for k, v in self.pre_branches.items(): for k, v in self.pre_branches.items():
for param in v.parameters(): for param in v.parameters():
if param.dim() > 1: if param.dim() > 1:
nn.init.xavier_uniform_(param) nn.init.xavier_uniform_(param)
# focal loss init # focal loss init
if self.loss_cls.use_sigmoid: if self.loss_cls.use_sigmoid:
bias_init = bias_init_with_prob(0.01) bias_init = bias_init_with_prob(0.01)
# for last layer # for last layer
if isinstance(self.pre_branches['cls'], nn.ModuleList): if isinstance(self.pre_branches['cls'], nn.ModuleList):
for m in self.pre_branches['cls']: for m in self.pre_branches['cls']:
nn.init.constant_(m.bias, bias_init) nn.init.constant_(m.bias, bias_init)
else: else:
m = self.pre_branches['cls'] m = self.pre_branches['cls']
nn.init.constant_(m.bias, bias_init) nn.init.constant_(m.bias, bias_init)
    def forward(self, bev_feature, img_metas=None):
        '''
        Args:
            bev_feature (List[Tensor]): shape [B, C, H, W]
                feature in bev view; only the first level is used.
            img_metas: unused here; kept for interface compatibility.
        Outs:
            outputs (List[Dict]): one dict per decoder layer, each with
                lines (Tensor): [bs, num_query, num_points, 2],
                    coordinates squashed into (0, 1) by sigmoid.
                scores (Tensor): [bs, num_query, cls_out_channels] logits.
        '''
        x = bev_feature[0]
        x = self.input_proj(x)  # only change feature size
        B, C, H, W = x.shape
        # All-zero mask: every BEV location is treated as valid (no padding).
        masks = x.new_zeros((B, H, W))
        pos_embed = self.positional_encoding(masks)
        # outs_dec: [nb_dec, bs, num_query, embed_dim]
        outs_dec, _ = self.transformer(x, masks.type(torch.bool), self.query_embedding.weight,
                                       pos_embed)
        outputs = []
        for i, query_feat in enumerate(outs_dec):
            ocls = self.pre_branches['cls'](query_feat)
            oreg = self.pre_branches['reg'](query_feat)
            # [bs, num_query, num_points*2] -> [bs, num_query, num_points, 2]
            oreg = oreg.unflatten(dim=2, sizes=(self.num_points, 2))
            # Last dim is exactly 2, so this sigmoids the whole tensor.
            oreg[..., 0:2] = oreg[..., 0:2].sigmoid()  # normalized xyz
            outputs.append(
                dict(
                    lines=oreg,  # [bs, num_query, num_points, 2]
                    scores=ocls,  # [bs, num_query, num_class]
                )
            )
        return outputs
    @force_fp32(apply_to=('score_pred', 'lines_pred', 'gt_lines'))
    def _get_target_single(self,
                           score_pred,
                           lines_pred,
                           gt_lines,
                           gt_labels,
                           gt_bboxes_ignore=None):
        """
        Compute regression and classification targets for one image.
        Outputs from a single decoder layer of a single feature level are used.
        Args:
            score_pred (Tensor): Box score logits from a single decoder layer
                for one image. Shape [num_query, cls_out_channels].
            lines_pred (Tensor):
                shape [num_query, num_points, 2].
            gt_lines (Tensor):
                shape [num_gt, num_points, 2].
            gt_labels (torch.LongTensor)
                shape [num_gt, ]
        Returns:
            tuple[Tensor]: a tuple containing the following for one image.
                - labels (LongTensor): class target per query; unmatched
                    queries receive the background index ``self.num_classes``.
                    shape [num_query, ]
                - label_weights (Tensor): all-ones label weights.
                    shape [num_query, ]
                - lines_target (Tensor): matched gt lines (zeros elsewhere).
                    shape [num_query, num_points, 2]
                - lines_weights (Tensor): 1.0 at matched queries, else 0.
                    shape [num_query, num_points, 2]
                - pos_inds (Tensor): Sampled positive indices for each image.
                - neg_inds (Tensor): Sampled negative indices for each image.
        """
        num_pred_lines = lines_pred.size(0)
        # assigner and sampler: match predictions to ground truth (the exact
        # matching rule is defined by the configured self.assigner).
        assign_result = self.assigner.assign(preds=dict(lines=lines_pred, scores=score_pred,),
                                             gts=dict(lines=gt_lines,
                                                      labels=gt_labels, ),
                                             gt_bboxes_ignore=gt_bboxes_ignore)
        sampling_result = self.sampler.sample(
            assign_result, lines_pred, gt_lines)
        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds
        # label targets: default everything to background, then overwrite
        # the matched queries with their assigned gt classes.
        labels = gt_lines.new_full((num_pred_lines, ),
                                   self.num_classes,
                                   dtype=torch.long)
        labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
        label_weights = gt_lines.new_ones(num_pred_lines)
        # bbox targets: only matched queries contribute to regression loss.
        lines_target = torch.zeros_like(lines_pred)
        lines_target[pos_inds] = sampling_result.pos_gt_bboxes
        lines_weights = torch.zeros_like(lines_pred)
        lines_weights[pos_inds] = 1.0
        return (labels, label_weights, lines_target, lines_weights,
                pos_inds, neg_inds)
    @force_fp32(apply_to=('preds', 'gts'))
    def get_targets(self, preds, gts, gt_bboxes_ignore_list=None):
        """
        Compute regression and classification targets for a batch of images.
        Outputs from a single decoder layer of a single feature level are used.
        Args:
            preds (dict): per-image predictions.
                - 'scores': logits, [num_query, cls_out_channels] per image.
                - 'lines': [num_query, num_points, 2] per image.
            gts (dict): per-image ground truth.
                - 'lines': (num_gts, num_points, 2) per image.
                - 'labels': (num_gts, ) per image.
            gt_bboxes_ignore_list (list[Tensor], optional): Bounding
                boxes which can be ignored for each image. Must be None.
        Returns:
            tuple: a tuple containing the following targets.
                - new_gts (dict): per-image target lists under keys
                    'labels', 'label_weights', 'lines_targets',
                    'lines_weights'.
                - num_total_pos (int): Number of positive samples in all
                    images.
                - num_total_neg (int): Number of negative samples in all
                    images.
                - pos_inds_list (list[Tensor]): positive query indices per
                    image.
        """
        assert gt_bboxes_ignore_list is None, \
            'Only supports for gt_bboxes_ignore setting to None.'
        # Run assignment independently per image, then aggregate counts.
        (labels_list, label_weights_list,
         lines_targets_list, lines_weights_list,
         pos_inds_list, neg_inds_list) = multi_apply(
            self._get_target_single,
            preds['scores'], preds['lines'],
            gts['lines'], gts['labels'],
            gt_bboxes_ignore=gt_bboxes_ignore_list)
        num_total_pos = sum((inds.numel() for inds in pos_inds_list))
        num_total_neg = sum((inds.numel() for inds in neg_inds_list))
        new_gts = dict(
            labels=labels_list,
            label_weights=label_weights_list,
            lines_targets=lines_targets_list,
            lines_weights=lines_weights_list,
        )
        return new_gts, num_total_pos, num_total_neg, pos_inds_list
    @force_fp32(apply_to=('preds', 'gts'))
    def loss_single(self,
                    preds: dict,
                    gts: dict,
                    gt_bboxes_ignore_list=None,
                    reduction='none'):
        """
        Loss function for outputs from a single decoder layer of a single
        feature level.
        Args:
            preds (dict): 'scores' [bs, num_query, cls_out_channels] and
                'lines' line predictions reshaped to
                [-1, num_points, 2] below.
            gts (dict): 'lines' (list[Tensor], (num_gts, num_points, 2) per
                image) and 'labels' (list[Tensor], (num_gts,) per image).
            gt_bboxes_ignore_list (list[Tensor], optional): must be None.
            reduction (str): reduction override for the regression loss;
                'none' keeps per-element losses (for performance analysis).
        Returns:
            tuple: (loss_dict with keys 'cls' and 'reg', pos_inds_list).
                NOTE(review): ``loss`` unpacks three values per call of
                ``loss_single`` via multi_apply — confirm whether a
                pos_gt_inds list is missing from this return.
        """
        # get target for each sample
        new_gts, num_total_pos, num_total_neg, pos_inds_list =\
            self.get_targets(preds, gts, gt_bboxes_ignore_list)
        # batched all data: concatenate per-image targets across the batch.
        for k, v in new_gts.items():
            new_gts[k] = torch.cat(v, 0)
        # construct weighted avg_factor to match with the official DETR repo
        cls_avg_factor = num_total_pos * 1.0 + \
            num_total_neg * self.bg_cls_weight
        if self.sync_cls_avg_factor:
            # Average the normalizer across GPUs so every rank scales alike.
            cls_avg_factor = reduce_mean(
                preds['scores'].new_tensor([cls_avg_factor]))
        cls_avg_factor = max(cls_avg_factor, 1)
        # classification loss
        cls_scores = preds['scores'].reshape(-1, self.cls_out_channels)
        loss_cls = self.loss_cls(
            cls_scores, new_gts['labels'], new_gts['label_weights'], avg_factor=cls_avg_factor)
        # Compute the average number of gt boxes accross all gpus, for
        # normalization purposes
        num_total_pos = loss_cls.new_tensor([num_total_pos])
        num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1).item()
        # regression L1 loss
        lines_preds = preds['lines'].reshape(-1, self.num_points, 2)
        if reduction == 'none':  # For performance analysis
            loss_reg = self.reg_loss(
                lines_preds, new_gts['lines_targets'], new_gts['lines_weights'], reduction_override=reduction, avg_factor=num_total_pos)
        else:
            loss_reg = self.reg_loss(
                lines_preds, new_gts['lines_targets'], new_gts['lines_weights'], avg_factor=num_total_pos)
        loss_dict = dict(
            cls=loss_cls,
            reg=loss_reg,
        )
        return (loss_dict, pos_inds_list)
@force_fp32(apply_to=('gt_lines_list', 'preds_dicts'))
def loss(self,
         gts: dict,
         preds_dicts: dict,
         gt_bboxes_ignore=None,
         reduction='mean'):
    """Compute the loss over every decoder layer and aggregate the results.

    Args:
        gts (dict): Batched ground-truth targets forwarded unchanged to
            ``self.loss_single`` (keys presumably include 'lines'/'labels'
            per layer — confirm against ``loss_single``'s signature).
        preds_dicts (dict): Per-decoder-layer prediction dicts; each entry
            is passed to ``self.loss_single`` as one element of the
            ``multi_apply`` sweep.
        gt_bboxes_ignore (list[Tensor], optional): Must be None; ignoring
            boxes is not supported by this head.
        reduction (str): Loss reduction mode forwarded to ``loss_single``.
            Defaults to 'mean'.

    Returns:
        tuple: ``(loss_dict, pos_inds_lists, pos_gt_inds_lists)`` where
        ``loss_dict`` holds the last decoder layer's losses under their
        plain names and earlier layers' losses under ``d{i}.{name}``.

    NOTE(review): the ``apply_to`` names on the decorator do not match
    this signature (there is no ``gt_lines_list`` argument), and the
    visible ``return`` of ``loss_single`` yields a 2-tuple while three
    sequences are unpacked from ``multi_apply`` here — both look
    inconsistent; confirm against the full definition of ``loss_single``.
    """
    assert gt_bboxes_ignore is None, \
        f'{self.__class__.__name__} only supports ' \
        f'for gt_bboxes_ignore setting to None.'
    # One loss_single call per decoder layer.
    losses, pos_inds_lists, pos_gt_inds_lists = multi_apply(
        self.loss_single,
        preds_dicts,
        gts=gts,
        gt_bboxes_ignore_list=gt_bboxes_ignore,
        reduction=reduction)
    # Last layer's losses keep their bare names.
    loss_dict = dict(losses[-1])
    # Auxiliary layers are prefixed with their decoder index.
    for layer_idx, layer_losses in enumerate(losses[:-1]):
        for key, value in layer_losses.items():
            loss_dict[f'd{layer_idx}.{key}'] = value
    return loss_dict, pos_inds_lists, pos_gt_inds_lists
def post_process(self, preds_dict, tokens, gts):
    """Convert the last decoder layer's raw outputs into per-sample results.

    Fixes over the previous revision: removed the unused ``batch_size``
    and ``max_num`` locals and the commented-out ipdb debug hook; iterate
    ``tokens`` with ``enumerate`` instead of ``range(len(...))``.

    Args:
        preds_dict (list[dict]): Per-decoder-layer predictions; only the
            last layer is used. Each dict holds:
            - 'scores' (Tensor): [bs, num_query, cls_out_channels].
            - 'lines' (Tensor): [bs, num_query, num_points, 2].
        tokens (list): Sample identifiers, one per batch element.
        gts (dict | None): Optional ground truth with 'lines' and 'labels'
            indexable per sample; when given, a 'groundTruth' entry is
            attached to each result.

    Returns:
        list[dict]: One dict per sample with keys:
            'token': the sample token.
            'lines' (np.ndarray): [num_pred, num_points, 2], mapped from
                [0, 1] to [-1, 1] via ``x * 2 - 1``.
            'scores' (np.ndarray): [num_pred], post-sigmoid/softmax.
            'labels' (np.ndarray): [num_pred], integer class ids.
            'nline' (int): number of predicted lines.
    """
    preds = preds_dict[-1]
    batched_cls_scores = preds['scores']
    batched_lines_preds = preds['lines']
    ret_list = []
    for i, token in enumerate(tokens):
        cls_scores = batched_cls_scores[i]
        lines_preds = batched_lines_preds[i]
        if cls_scores.shape[-1] > self.num_classes:
            # Softmax head: the extra last channel is background — drop it,
            # then take the best foreground class per query.
            scores, labels = F.softmax(cls_scores, dim=-1)[..., :-1].max(-1)
            final_scores, bbox_index = scores.topk(self.max_lines)
            final_lines = lines_preds[bbox_index]
            final_labels = labels[bbox_index]
        else:
            # Sigmoid head: rank all (query, class) pairs jointly and
            # recover query/class indices from the flat top-k positions.
            cls_scores = cls_scores.sigmoid()
            final_scores, indexes = cls_scores.view(-1).topk(self.max_lines)
            final_labels = indexes % self.num_classes
            bbox_index = indexes // self.num_classes
            final_lines = lines_preds[bbox_index]
        ret_dict_single = {
            'token': token,
            # Map normalized coordinates from [0, 1] to [-1, 1].
            'lines': final_lines.detach().cpu().numpy() * 2 - 1,
            'scores': final_scores.detach().cpu().numpy(),
            'labels': final_labels.detach().cpu().numpy(),
            'nline': len(final_lines),
        }
        if gts is not None:
            lines_gt = gts['lines'][i].detach().cpu().numpy()
            labels_gt = gts['labels'][i].detach().cpu().numpy()
            ret_dict_single['groundTruth'] = {
                'token': token,
                'nline': lines_gt.shape[0],
                'labels': labels_gt,
                'lines': lines_gt * 2 - 1,
            }
        ret_list.append(ret_dict_single)
    return ret_list
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment