Commit 41b18fd8 authored by zhe chen's avatar zhe chen
Browse files

Use pre-commit to reformat code


Use pre-commit to reformat code
parent ff20ea39
...@@ -2,6 +2,7 @@ import mmcv ...@@ -2,6 +2,7 @@ import mmcv
import numpy as np import numpy as np
from mmdet.datasets.builder import PIPELINES from mmdet.datasets.builder import PIPELINES
@PIPELINES.register_module(force=True) @PIPELINES.register_module(force=True)
class LoadMultiViewImagesFromFiles(object): class LoadMultiViewImagesFromFiles(object):
"""Load multi channel images from a list of separate channel files. """Load multi channel images from a list of separate channel files.
...@@ -56,5 +57,5 @@ class LoadMultiViewImagesFromFiles(object): ...@@ -56,5 +57,5 @@ class LoadMultiViewImagesFromFiles(object):
def __repr__(self): def __repr__(self):
"""str: Return a string that describes the module.""" """str: Return a string that describes the module."""
return f'{self.__class__.__name__} (to_float32={self.to_float32}, '\ return f'{self.__class__.__name__} (to_float32={self.to_float32}, ' \
f"color_type='{self.color_type}')" f"color_type='{self.color_type}')"
import numpy as np import numpy as np
from mmdet.datasets.builder import PIPELINES from mmdet.datasets.builder import PIPELINES
from shapely.geometry import LineString from shapely.geometry import LineString
@PIPELINES.register_module(force=True) @PIPELINES.register_module(force=True)
class PolygonizeLocalMapBbox(object): class PolygonizeLocalMapBbox(object):
"""Pre-Processing used by vectormapnet model. """Pre-Processing used by vectormapnet model.
...@@ -18,7 +18,7 @@ class PolygonizeLocalMapBbox(object): ...@@ -18,7 +18,7 @@ class PolygonizeLocalMapBbox(object):
canvas_size=(200, 100), canvas_size=(200, 100),
coord_dim=2, coord_dim=2,
num_class=3, num_class=3,
threshold=6/200, threshold=6 / 200,
): ):
self.canvas_size = np.array(canvas_size) self.canvas_size = np.array(canvas_size)
...@@ -47,7 +47,7 @@ class PolygonizeLocalMapBbox(object): ...@@ -47,7 +47,7 @@ class PolygonizeLocalMapBbox(object):
polyline_weight = np.ones_like(polyline).reshape(-1) polyline_weight = np.ones_like(polyline).reshape(-1)
polyline_weight = np.pad( polyline_weight = np.pad(
polyline_weight, ((0, 1),), constant_values=1.) polyline_weight, ((0, 1),), constant_values=1.)
polyline_weight = polyline_weight/polyline_weight.sum() polyline_weight = polyline_weight / polyline_weight.sum()
# flatten and quantilized # flatten and quantilized
fpolyline = quantize_verts( fpolyline = quantize_verts(
...@@ -58,7 +58,7 @@ class PolygonizeLocalMapBbox(object): ...@@ -58,7 +58,7 @@ class PolygonizeLocalMapBbox(object):
# reindex starting from 1, and add a zero stopping token(EOS), # reindex starting from 1, and add a zero stopping token(EOS),
fpolyline = \ fpolyline = \
np.pad(fpolyline + self.coord_dim_start_idx, ((0, 1),), np.pad(fpolyline + self.coord_dim_start_idx, ((0, 1),),
constant_values=0) constant_values=0)
fpolyline_msk = np.ones(fpolyline.shape, dtype=np.bool) fpolyline_msk = np.ones(fpolyline.shape, dtype=np.bool)
polyline_masks.append(fpolyline_msk) polyline_masks.append(fpolyline_msk)
...@@ -98,11 +98,11 @@ class PolygonizeLocalMapBbox(object): ...@@ -98,11 +98,11 @@ class PolygonizeLocalMapBbox(object):
qkp_msks = np.stack(qkp_masks) qkp_msks = np.stack(qkp_masks)
# format det # format det
kps = np.stack(kps, axis=0).astype(np.float32)*self.canvas_size kps = np.stack(kps, axis=0).astype(np.float32) * self.canvas_size
kp_labels = np.array(kp_labels) kp_labels = np.array(kp_labels)
# restrict the boundary # restrict the boundary
kps[..., 0] = np.clip(kps[..., 0], 0.1, self.canvas_size[0]-0.1) kps[..., 0] = np.clip(kps[..., 0], 0.1, self.canvas_size[0] - 0.1)
kps[..., 1] = np.clip(kps[..., 1], 0.1, self.canvas_size[1]-0.1) kps[..., 1] = np.clip(kps[..., 1], 0.1, self.canvas_size[1] - 0.1)
# nbox, boxsize(4)*coord_dim(2) # nbox, boxsize(4)*coord_dim(2)
kps = kps.reshape(kps.shape[0], -1) kps = kps.reshape(kps.shape[0], -1)
...@@ -114,7 +114,7 @@ class PolygonizeLocalMapBbox(object): ...@@ -114,7 +114,7 @@ class PolygonizeLocalMapBbox(object):
''' '''
Process vertices. Process vertices.
''' '''
vectors = input_dict['vectors'] vectors = input_dict['vectors']
n_lines = 0 n_lines = 0
...@@ -157,10 +157,9 @@ class PolygonizeLocalMapBbox(object): ...@@ -157,10 +157,9 @@ class PolygonizeLocalMapBbox(object):
def evaluate_line(polyline): def evaluate_line(polyline):
edge = np.linalg.norm(polyline[1:] - polyline[:-1], axis=-1) edge = np.linalg.norm(polyline[1:] - polyline[:-1], axis=-1)
start_end_weight = edge[(0, -1), ].copy() start_end_weight = edge[(0, -1),].copy()
mid_weight = (edge[:-1] + edge[1:]) * .5 mid_weight = (edge[:-1] + edge[1:]) * .5
pts_weight = np.concatenate( pts_weight = np.concatenate(
...@@ -172,16 +171,16 @@ def evaluate_line(polyline): ...@@ -172,16 +171,16 @@ def evaluate_line(polyline):
pts_weight /= denominator pts_weight /= denominator
# add weights for stop index # add weights for stop index
pts_weight = np.repeat(pts_weight, 2)/2 pts_weight = np.repeat(pts_weight, 2) / 2
pts_weight = np.pad(pts_weight, ((0, 1)), pts_weight = np.pad(pts_weight, ((0, 1)),
constant_values=1/(len(polyline)*2)) constant_values=1 / (len(polyline) * 2))
return pts_weight return pts_weight
def quantize_verts(verts, canvas_size, coord_dim): def quantize_verts(verts, canvas_size, coord_dim):
"""Convert vertices from its original range ([-1,1]) to discrete values in [0, n_bits**2 - 1]. """Convert vertices from its original range ([-1,1]) to discrete values in [0, n_bits**2 - 1].
Args: Args:
verts (array): vertices coordinates, shape (seqlen, coords_dim) verts (array): vertices coordinates, shape (seqlen, coords_dim)
canvas_size (tuple): bev feature size canvas_size (tuple): bev feature size
...@@ -196,7 +195,7 @@ def quantize_verts(verts, canvas_size, coord_dim): ...@@ -196,7 +195,7 @@ def quantize_verts(verts, canvas_size, coord_dim):
range_quantize = np.array(canvas_size) - 1 # (0-199) = 200 range_quantize = np.array(canvas_size) - 1 # (0-199) = 200
verts_ratio = (verts[:, :coord_dim] - min_range) / ( verts_ratio = (verts[:, :coord_dim] - min_range) / (
max_range - min_range) max_range - min_range)
verts_quantize = verts_ratio * range_quantize[:coord_dim] verts_quantize = verts_ratio * range_quantize[:coord_dim]
return verts_quantize.astype('int32') return verts_quantize.astype('int32')
...@@ -204,11 +203,11 @@ def quantize_verts(verts, canvas_size, coord_dim): ...@@ -204,11 +203,11 @@ def quantize_verts(verts, canvas_size, coord_dim):
def get_bbox(polyline, threshold): def get_bbox(polyline, threshold):
"""Convert vertices from its original range ([-1,1]) to discrete values in [0, n_bits**2 - 1]. """Convert vertices from its original range ([-1,1]) to discrete values in [0, n_bits**2 - 1].
Args: Args:
polyline (array): point coordinates, shape (seqlen, 2) polyline (array): point coordinates, shape (seqlen, 2)
threshold (float): threshold for minimum bbox size threshold (float): threshold for minimum bbox size
Returns: Returns:
bbox (array): bounding box in xyxy format, shape (2, 2) bbox (array): bounding box in xyxy format, shape (2, 2)
""" """
...@@ -216,14 +215,14 @@ def get_bbox(polyline, threshold): ...@@ -216,14 +215,14 @@ def get_bbox(polyline, threshold):
polyline = LineString(polyline) polyline = LineString(polyline)
bbox = polyline.bounds bbox = polyline.bounds
minx, miny, maxx, maxy = bbox minx, miny, maxx, maxy = bbox
W, H = maxx-minx, maxy-miny W, H = maxx - minx, maxy - miny
if W < threshold or H < threshold: if W < threshold or H < threshold:
remain = max((threshold - min(W, H))/2, eps) remain = max((threshold - min(W, H)) / 2, eps)
bbox = polyline.buffer(remain).envelope.bounds bbox = polyline.buffer(remain).envelope.bounds
minx, miny, maxx, maxy = bbox minx, miny, maxx, maxy = bbox
bbox_np = np.array([[minx, miny], [maxx, maxy]]) bbox_np = np.array([[minx, miny], [maxx, maxy]])
bbox_np = np.clip(bbox_np, 0., 1.) bbox_np = np.clip(bbox_np, 0., 1.)
return bbox_np return bbox_np
\ No newline at end of file
import numpy as np
import mmcv import mmcv
import numpy as np
from mmdet.datasets.builder import PIPELINES from mmdet.datasets.builder import PIPELINES
...@@ -82,26 +81,26 @@ class PadMultiViewImages(object): ...@@ -82,26 +81,26 @@ class PadMultiViewImages(object):
if self.change_intrinsics: if self.change_intrinsics:
post_intrinsics, post_ego2imgs = [], [] post_intrinsics, post_ego2imgs = [], []
for img, oshape, cam_intrinsic, ego2img in zip(results['img'], \ for img, oshape, cam_intrinsic, ego2img in zip(results['img'], \
original_shape, results['cam_intrinsics'], results['ego2img']): original_shape, results['cam_intrinsics'],
results['ego2img']):
scaleW = img.shape[1] / oshape[1] scaleW = img.shape[1] / oshape[1]
scaleH = img.shape[0] / oshape[0] scaleH = img.shape[0] / oshape[0]
rot_resize_matrix = np.array([ rot_resize_matrix = np.array([
[scaleW, 0, 0, 0], [scaleW, 0, 0, 0],
[0, scaleH, 0, 0], [0, scaleH, 0, 0],
[0, 0, 1, 0], [0, 0, 1, 0],
[0, 0, 0, 1]]) [0, 0, 0, 1]])
post_intrinsic = rot_resize_matrix[:3, :3] @ cam_intrinsic post_intrinsic = rot_resize_matrix[:3, :3] @ cam_intrinsic
post_ego2img = rot_resize_matrix @ ego2img post_ego2img = rot_resize_matrix @ ego2img
post_intrinsics.append(post_intrinsic) post_intrinsics.append(post_intrinsic)
post_ego2imgs.append(post_ego2img) post_ego2imgs.append(post_ego2img)
results.update({ results.update({
'cam_intrinsics': post_intrinsics, 'cam_intrinsics': post_intrinsics,
'ego2img': post_ego2imgs, 'ego2img': post_ego2imgs,
}) })
results['img_shape'] = [img.shape for img in padded_img] results['img_shape'] = [img.shape for img in padded_img]
results['img_fixed_size'] = self.size results['img_fixed_size'] = self.size
results['img_size_divisor'] = self.size_divisor results['img_size_divisor'] = self.size_divisor
...@@ -135,16 +134,17 @@ class ResizeMultiViewImages(object): ...@@ -135,16 +134,17 @@ class ResizeMultiViewImages(object):
size (tuple, optional): resize target size, (h, w). size (tuple, optional): resize target size, (h, w).
change_intrinsics (bool): whether to update intrinsics. change_intrinsics (bool): whether to update intrinsics.
""" """
def __init__(self, size, change_intrinsics=True): def __init__(self, size, change_intrinsics=True):
self.size = size self.size = size
self.change_intrinsics = change_intrinsics self.change_intrinsics = change_intrinsics
def __call__(self, results:dict): def __call__(self, results: dict):
new_imgs, post_intrinsics, post_ego2imgs = [], [], [] new_imgs, post_intrinsics, post_ego2imgs = [], [], []
for img, cam_intrinsic, ego2img in zip(results['img'], \ for img, cam_intrinsic, ego2img in zip(results['img'], \
results['cam_intrinsics'], results['ego2img']): results['cam_intrinsics'], results['ego2img']):
tmp, scaleW, scaleH = mmcv.imresize(img, tmp, scaleW, scaleH = mmcv.imresize(img,
# NOTE: mmcv.imresize expect (w, h) shape # NOTE: mmcv.imresize expect (w, h) shape
(self.size[1], self.size[0]), (self.size[1], self.size[0]),
...@@ -152,10 +152,10 @@ class ResizeMultiViewImages(object): ...@@ -152,10 +152,10 @@ class ResizeMultiViewImages(object):
new_imgs.append(tmp) new_imgs.append(tmp)
rot_resize_matrix = np.array([ rot_resize_matrix = np.array([
[scaleW, 0, 0, 0], [scaleW, 0, 0, 0],
[0, scaleH, 0, 0], [0, scaleH, 0, 0],
[0, 0, 1, 0], [0, 0, 1, 0],
[0, 0, 0, 1]]) [0, 0, 0, 1]])
post_intrinsic = rot_resize_matrix[:3, :3] @ cam_intrinsic post_intrinsic = rot_resize_matrix[:3, :3] @ cam_intrinsic
post_ego2img = rot_resize_matrix @ ego2img post_ego2img = rot_resize_matrix @ ego2img
post_intrinsics.append(post_intrinsic) post_intrinsics.append(post_intrinsic)
...@@ -170,10 +170,10 @@ class ResizeMultiViewImages(object): ...@@ -170,10 +170,10 @@ class ResizeMultiViewImages(object):
}) })
return results return results
def __repr__(self): def __repr__(self):
repr_str = self.__class__.__name__ repr_str = self.__class__.__name__
repr_str += f'(size={self.size}, ' repr_str += f'(size={self.size}, '
repr_str += f'change_intrinsics={self.change_intrinsics})' repr_str += f'change_intrinsics={self.change_intrinsics})'
return repr_str return repr_str
\ No newline at end of file
from typing import Dict, List, Tuple, Union
import numpy as np import numpy as np
from mmdet.datasets.builder import PIPELINES from mmdet.datasets.builder import PIPELINES
from shapely.geometry import LineString
from numpy.typing import NDArray from numpy.typing import NDArray
from typing import List, Tuple, Union, Dict from shapely.geometry import LineString
@PIPELINES.register_module(force=True) @PIPELINES.register_module(force=True)
class VectorizeMap(object): class VectorizeMap(object):
...@@ -20,14 +22,14 @@ class VectorizeMap(object): ...@@ -20,14 +22,14 @@ class VectorizeMap(object):
sample_dist (float): interpolate distance. Set to -1 to ignore. sample_dist (float): interpolate distance. Set to -1 to ignore.
""" """
def __init__(self, def __init__(self,
roi_size: Union[Tuple, List], roi_size: Union[Tuple, List],
normalize: bool, normalize: bool,
coords_dim: int, coords_dim: int,
simplify: bool=False, simplify: bool = False,
sample_num: int=-1, sample_num: int = -1,
sample_dist: float=-1, sample_dist: float = -1,
): ):
self.coords_dim = coords_dim self.coords_dim = coords_dim
self.sample_num = sample_num self.sample_num = sample_num
self.sample_dist = sample_dist self.sample_dist = sample_dist
...@@ -45,46 +47,46 @@ class VectorizeMap(object): ...@@ -45,46 +47,46 @@ class VectorizeMap(object):
def interp_fixed_num(self, line: LineString) -> NDArray: def interp_fixed_num(self, line: LineString) -> NDArray:
''' Interpolate a line to fixed number of points. ''' Interpolate a line to fixed number of points.
Args: Args:
line (LineString): line line (LineString): line
Returns: Returns:
points (array): interpolated points, shape (N, 2) points (array): interpolated points, shape (N, 2)
''' '''
distances = np.linspace(0, line.length, self.sample_num) distances = np.linspace(0, line.length, self.sample_num)
sampled_points = np.array([list(line.interpolate(distance).coords) sampled_points = np.array([list(line.interpolate(distance).coords)
for distance in distances]).squeeze() for distance in distances]).squeeze()
return sampled_points return sampled_points
def interp_fixed_dist(self, line: LineString) -> NDArray: def interp_fixed_dist(self, line: LineString) -> NDArray:
''' Interpolate a line at fixed interval. ''' Interpolate a line at fixed interval.
Args: Args:
line (LineString): line line (LineString): line
Returns: Returns:
points (array): interpolated points, shape (N, 2) points (array): interpolated points, shape (N, 2)
''' '''
distances = list(np.arange(self.sample_dist, line.length, self.sample_dist)) distances = list(np.arange(self.sample_dist, line.length, self.sample_dist))
# make sure to sample at least two points when sample_dist > line.length # make sure to sample at least two points when sample_dist > line.length
distances = [0,] + distances + [line.length,] distances = [0, ] + distances + [line.length, ]
sampled_points = np.array([list(line.interpolate(distance).coords) sampled_points = np.array([list(line.interpolate(distance).coords)
for distance in distances]).squeeze() for distance in distances]).squeeze()
return sampled_points return sampled_points
def get_vectorized_lines(self, map_geoms: Dict) -> Dict: def get_vectorized_lines(self, map_geoms: Dict) -> Dict:
''' Vectorize map elements. Iterate over the input dict and apply the ''' Vectorize map elements. Iterate over the input dict and apply the
specified sample funcion. specified sample funcion.
Args: Args:
line (LineString): line line (LineString): line
Returns: Returns:
vectors (array): dict of vectorized map elements. vectors (array): dict of vectorized map elements.
''' '''
...@@ -110,22 +112,22 @@ class VectorizeMap(object): ...@@ -110,22 +112,22 @@ class VectorizeMap(object):
elif geom.geom_type == 'Polygon': elif geom.geom_type == 'Polygon':
# polygon objects will not be vectorized # polygon objects will not be vectorized
continue continue
else: else:
raise ValueError('map geoms must be either LineString or Polygon!') raise ValueError('map geoms must be either LineString or Polygon!')
return vectors return vectors
def normalize_line(self, line: NDArray) -> NDArray: def normalize_line(self, line: NDArray) -> NDArray:
''' Convert points to range (0, 1). ''' Convert points to range (0, 1).
Args: Args:
line (LineString): line line (LineString): line
Returns: Returns:
normalized (array): normalized points. normalized (array): normalized points.
''' '''
origin = -np.array([self.roi_size[0]/2, self.roi_size[1]/2]) origin = -np.array([self.roi_size[0] / 2, self.roi_size[1] / 2])
line[:, :2] = line[:, :2] - origin line[:, :2] = line[:, :2] - origin
...@@ -134,7 +136,7 @@ class VectorizeMap(object): ...@@ -134,7 +136,7 @@ class VectorizeMap(object):
line[:, :2] = line[:, :2] / (self.roi_size + eps) line[:, :2] = line[:, :2] / (self.roi_size + eps)
return line return line
def __call__(self, input_dict): def __call__(self, input_dict):
map_geoms = input_dict['map_geoms'] map_geoms = input_dict['map_geoms']
...@@ -145,9 +147,9 @@ class VectorizeMap(object): ...@@ -145,9 +147,9 @@ class VectorizeMap(object):
repr_str = self.__class__.__name__ repr_str = self.__class__.__name__
repr_str += f'(simplify={self.simplify}, ' repr_str += f'(simplify={self.simplify}, '
repr_str += f'sample_num={self.sample_num}), ' repr_str += f'sample_num={self.sample_num}), '
repr_str += f'sample_dist={self.sample_dist}), ' repr_str += f'sample_dist={self.sample_dist}), '
repr_str += f'roi_size={self.roi_size})' repr_str += f'roi_size={self.roi_size})'
repr_str += f'normalize={self.normalize})' repr_str += f'normalize={self.normalize})'
repr_str += f'coords_dim={self.coords_dim})' repr_str += f'coords_dim={self.coords_dim})'
return repr_str return repr_str
\ No newline at end of file
from .backbones import *
from .heads import *
from .losses import *
from .mapers import *
from .transformer_utils import *
from .assigner import *
from .assigner import HungarianLinesAssigner
from .match_cost import MapQueriesCost, BBoxLogitsCost, DynamicLinesCost, IoUCostC, BBoxCostC, LinesCost, LinesFixNumChamferCost, ClsSigmoidCost
import torch import torch
from mmdet.core.bbox.assigners import AssignResult, BaseAssigner
from mmdet.core.bbox.builder import BBOX_ASSIGNERS from mmdet.core.bbox.builder import BBOX_ASSIGNERS
from mmdet.core.bbox.assigners import AssignResult
from mmdet.core.bbox.assigners import BaseAssigner
from mmdet.core.bbox.match_costs import build_match_cost from mmdet.core.bbox.match_costs import build_match_cost
try: try:
...@@ -36,8 +34,8 @@ class HungarianLinesAssigner(BaseAssigner): ...@@ -36,8 +34,8 @@ class HungarianLinesAssigner(BaseAssigner):
type='MapQueriesCost', type='MapQueriesCost',
cls_cost=dict(type='ClassificationCost', weight=1.), cls_cost=dict(type='ClassificationCost', weight=1.),
reg_cost=dict(type='LinesCost', weight=1.0), reg_cost=dict(type='LinesCost', weight=1.0),
), ),
pc_range=None, pc_range=None,
**kwargs): **kwargs):
self.pc_range = pc_range self.pc_range = pc_range
...@@ -110,7 +108,8 @@ class HungarianLinesAssigner(BaseAssigner): ...@@ -110,7 +108,8 @@ class HungarianLinesAssigner(BaseAssigner):
matched_row_inds, matched_col_inds = linear_sum_assignment(cost) matched_row_inds, matched_col_inds = linear_sum_assignment(cost)
except: except:
print('cost max{}, min{}'.format(cost.max(), cost.min())) print('cost max{}, min{}'.format(cost.max(), cost.min()))
import ipdb; ipdb.set_trace() import ipdb
ipdb.set_trace()
matched_row_inds = torch.from_numpy(matched_row_inds).to( matched_row_inds = torch.from_numpy(matched_row_inds).to(
preds['lines'].device) preds['lines'].device)
matched_col_inds = torch.from_numpy(matched_col_inds).to( matched_col_inds = torch.from_numpy(matched_col_inds).to(
...@@ -123,4 +122,4 @@ class HungarianLinesAssigner(BaseAssigner): ...@@ -123,4 +122,4 @@ class HungarianLinesAssigner(BaseAssigner):
assigned_gt_inds[matched_row_inds] = matched_col_inds + 1 assigned_gt_inds[matched_row_inds] = matched_col_inds + 1
assigned_labels[matched_row_inds] = gts['labels'][matched_col_inds] assigned_labels[matched_row_inds] = gts['labels'][matched_col_inds]
return AssignResult( return AssignResult(
num_gts, assigned_gt_inds, None, labels=assigned_labels) num_gts, assigned_gt_inds, None, labels=assigned_labels)
\ No newline at end of file
import torch import torch
from mmdet.core.bbox.match_costs.builder import MATCH_COST
from mmdet.core.bbox.match_costs import build_match_cost
from mmdet.core.bbox.iou_calculators import bbox_overlaps from mmdet.core.bbox.iou_calculators import bbox_overlaps
from mmdet.core.bbox.match_costs import build_match_cost
from mmdet.core.bbox.match_costs.builder import MATCH_COST
from mmdet.core.bbox.transforms import bbox_cxcywh_to_xyxy from mmdet.core.bbox.transforms import bbox_cxcywh_to_xyxy
...@@ -83,7 +82,7 @@ class LinesFixNumChamferCost(object): ...@@ -83,7 +82,7 @@ class LinesFixNumChamferCost(object):
num_gts, num_bboxes = gt_lines.size(0), lines_pred.size(0) num_gts, num_bboxes = gt_lines.size(0), lines_pred.size(0)
dist_mat = lines_pred.new_full((num_bboxes, num_gts), dist_mat = lines_pred.new_full((num_bboxes, num_gts),
1.0,) 1.0, )
for i in range(num_bboxes): for i in range(num_bboxes):
for j in range(num_gts): for j in range(num_gts):
...@@ -212,6 +211,7 @@ class IoUCostC: ...@@ -212,6 +211,7 @@ class IoUCostC:
iou_cost = -overlaps iou_cost = -overlaps
return iou_cost * self.weight return iou_cost * self.weight
@MATCH_COST.register_module() @MATCH_COST.register_module()
class DynamicLinesCost(object): class DynamicLinesCost(object):
"""LinesL1Cost. """LinesL1Cost.
...@@ -273,7 +273,7 @@ class DynamicLinesCost(object): ...@@ -273,7 +273,7 @@ class DynamicLinesCost(object):
m1 = m1.unsqueeze(1).sigmoid() > 0.5 m1 = m1.unsqueeze(1).sigmoid() > 0.5
m2 = m2.unsqueeze(0) m2 = m2.unsqueeze(0)
valid_points_mask = (m1 + m2)/2. valid_points_mask = (m1 + m2) / 2.
average_factor_mask = valid_points_mask.sum(-1) > 0 average_factor_mask = valid_points_mask.sum(-1) > 0
average_factor = average_factor_mask.masked_fill( average_factor = average_factor_mask.masked_fill(
...@@ -360,8 +360,7 @@ class MapQueriesCost(object): ...@@ -360,8 +360,7 @@ class MapQueriesCost(object):
# Iou # Iou
if self.iou_cost is not None: if self.iou_cost is not None:
iou_cost = self.iou_cost(preds['lines'],gts['lines']) iou_cost = self.iou_cost(preds['lines'], gts['lines'])
cost += iou_cost cost += iou_cost
return cost return cost
...@@ -5,13 +5,13 @@ import torch.nn.functional as F ...@@ -5,13 +5,13 @@ import torch.nn.functional as F
class NoiseSythesis(nn.Module): class NoiseSythesis(nn.Module):
def __init__(self, def __init__(self,
p, scale=0.01, shift_scale=(8,5), p, scale=0.01, shift_scale=(8, 5),
scaling_size=(0.1,0.1), canvas_size=(200, 100), scaling_size=(0.1, 0.1), canvas_size=(200, 100),
bbox_type='sce', bbox_type='sce',
poly_coord_dim=2, poly_coord_dim=2,
bbox_coord_dim=2, bbox_coord_dim=2,
quantify=True): quantify=True):
super(NoiseSythesis, self).__init__() super(NoiseSythesis, self).__init__()
self.p = p self.p = p
...@@ -37,7 +37,7 @@ class NoiseSythesis(nn.Module): ...@@ -37,7 +37,7 @@ class NoiseSythesis(nn.Module):
dtype = bbox.dtype dtype = bbox.dtype
B = bbox.shape[0] B = bbox.shape[0]
noise = (torch.rand(B, device=device)*2-1)[:,None,None] # [-1,1] noise = (torch.rand(B, device=device) * 2 - 1)[:, None, None] # [-1,1]
scale = self.scaling_size.to(device) scale = self.scaling_size.to(device)
scale = (noise * scale) + 1 scale = (noise * scale) + 1
...@@ -45,7 +45,7 @@ class NoiseSythesis(nn.Module): ...@@ -45,7 +45,7 @@ class NoiseSythesis(nn.Module):
# recenterization # recenterization
coffset = scaled_bbox.mean(-2) - bbox.float().mean(-2) coffset = scaled_bbox.mean(-2) - bbox.float().mean(-2)
scaled_bbox = scaled_bbox - coffset[:,None] scaled_bbox = scaled_bbox - coffset[:, None]
return scaled_bbox.round().type(dtype) return scaled_bbox.round().type(dtype)
...@@ -60,13 +60,13 @@ class NoiseSythesis(nn.Module): ...@@ -60,13 +60,13 @@ class NoiseSythesis(nn.Module):
scale = (bbox.max(1)[0] - bbox.min(1)[0]) * 0.1 scale = (bbox.max(1)[0] - bbox.min(1)[0]) * 0.1
scale = torch.where(scale < shift_scale, scale, shift_scale) scale = torch.where(scale < shift_scale, scale, shift_scale)
noise = (torch.rand(batch_size, 2, device=device)*2-1) # [-1,1] noise = (torch.rand(batch_size, 2, device=device) * 2 - 1) # [-1,1]
offset = (noise * scale).round().type(bbox.dtype) offset = (noise * scale).round().type(bbox.dtype)
shifted_bbox = bbox + offset[:, None] shifted_bbox = bbox + offset[:, None]
return shifted_bbox return shifted_bbox
def gaussian_noise_bbox(self, bbox): def gaussian_noise_bbox(self, bbox):
dtype = bbox.dtype dtype = bbox.dtype
...@@ -80,23 +80,23 @@ class NoiseSythesis(nn.Module): ...@@ -80,23 +80,23 @@ class NoiseSythesis(nn.Module):
noisy_bbox = noisy_bbox.round().type(dtype) noisy_bbox = noisy_bbox.round().type(dtype)
# prevent out of bound case # prevent out of bound case
for i in range(self.bbox_coord_dim): for i in range(self.bbox_coord_dim):
noisy_bbox[...,i] =\ noisy_bbox[..., i] = \
torch.clamp(noisy_bbox[...,0],1,self.canvas_size[i]) torch.clamp(noisy_bbox[..., 0], 1, self.canvas_size[i])
else: else:
noisy_bbox = noisy_bbox.type(torch.float) noisy_bbox = noisy_bbox.type(torch.float)
return noisy_bbox return noisy_bbox
def gaussian_noise_poly(self, polyline, polyline_mask): def gaussian_noise_poly(self, polyline, polyline_mask):
device = polyline.device device = polyline.device
batchsize = polyline.shape[0] batchsize = polyline.shape[0]
scale = self.canvas_size * self.scale scale = self.canvas_size * self.scale
polyline = F.pad(polyline,(0,self.poly_coord_dim-1)) polyline = F.pad(polyline, (0, self.poly_coord_dim - 1))
polyline = polyline.view(batchsize,-1, self.poly_coord_dim) polyline = polyline.view(batchsize, -1, self.poly_coord_dim)
mask = F.pad(polyline_mask[:,1:],(0,self.poly_coord_dim)) mask = F.pad(polyline_mask[:, 1:], (0, self.poly_coord_dim))
noisy_polyline = torch.normal(polyline.type(torch.float), scale) noisy_polyline = torch.normal(polyline.type(torch.float), scale)
if self.quantify: if self.quantify:
...@@ -104,14 +104,14 @@ class NoiseSythesis(nn.Module): ...@@ -104,14 +104,14 @@ class NoiseSythesis(nn.Module):
# prevent out of bound case # prevent out of bound case
for i in range(self.poly_coord_dim): for i in range(self.poly_coord_dim):
noisy_polyline[...,i] =\ noisy_polyline[..., i] = \
torch.clamp(noisy_polyline[...,i],0,self.canvas_size[i]) torch.clamp(noisy_polyline[..., i], 0, self.canvas_size[i])
else: else:
noisy_polyline = noisy_polyline.type(torch.float) noisy_polyline = noisy_polyline.type(torch.float)
noisy_polyline = noisy_polyline.view(batchsize,-1) * mask noisy_polyline = noisy_polyline.view(batchsize, -1) * mask
noisy_polyline = noisy_polyline[:,:-(self.poly_coord_dim-1)] noisy_polyline = noisy_polyline[:, :-(self.poly_coord_dim - 1)]
return noisy_polyline return noisy_polyline
...@@ -125,11 +125,11 @@ class NoiseSythesis(nn.Module): ...@@ -125,11 +125,11 @@ class NoiseSythesis(nn.Module):
bbox = t(bbox) bbox = t(bbox)
# prevent out of bound case # prevent out of bound case
bbox[...,0] =\ bbox[..., 0] = \
torch.clamp(bbox[...,0],0,self.canvas_size[0]) torch.clamp(bbox[..., 0], 0, self.canvas_size[0])
bbox[...,1] =\ bbox[..., 1] = \
torch.clamp(bbox[...,1],0,self.canvas_size[1]) torch.clamp(bbox[..., 1], 0, self.canvas_size[1])
return bbox return bbox
...@@ -143,8 +143,8 @@ class NoiseSythesis(nn.Module): ...@@ -143,8 +143,8 @@ class NoiseSythesis(nn.Module):
bbox = self.gaussian_noise_bbox(bbox) bbox = self.gaussian_noise_bbox(bbox)
fbbox_aug = bbox.view(seq_len, -1) fbbox_aug = bbox.view(seq_len, -1)
aug_mask = torch.rand(fbbox.shape,device=fbbox.device) aug_mask = torch.rand(fbbox.shape, device=fbbox.device)
fbbox = torch.where(aug_mask<self.p, fbbox_aug, fbbox) fbbox = torch.where(aug_mask < self.p, fbbox_aug, fbbox)
elif self.bbox_type == 'rxyxy': elif self.bbox_type == 'rxyxy':
fbbox = self.rbbox_aug(batch) fbbox = self.rbbox_aug(batch)
elif self.bbox_type == 'convex_hull': elif self.bbox_type == 'convex_hull':
...@@ -154,18 +154,18 @@ class NoiseSythesis(nn.Module): ...@@ -154,18 +154,18 @@ class NoiseSythesis(nn.Module):
polyline = batch['polylines'] polyline = batch['polylines']
polyline_mask = batch['polyline_masks'] polyline_mask = batch['polyline_masks']
polyline_aug = self.gaussian_noise_poly(polyline, polyline_mask) polyline_aug = self.gaussian_noise_poly(polyline, polyline_mask)
aug_mask = torch.rand(polyline.shape,device=polyline.device) aug_mask = torch.rand(polyline.shape, device=polyline.device)
polyline = torch.where(aug_mask<self.p, polyline_aug, polyline) polyline = torch.where(aug_mask < self.p, polyline_aug, polyline)
return polyline, fbbox return polyline, fbbox
def rbbox_aug(self, batch): def rbbox_aug(self, batch):
return None return None
def convex_hull_aug(self,batch): def convex_hull_aug(self, batch):
return None return None
def __call__(self, batch, simple_aug=False): def __call__(self, batch, simple_aug=False):
...@@ -183,5 +183,4 @@ class NoiseSythesis(nn.Module): ...@@ -183,5 +183,4 @@ class NoiseSythesis(nn.Module):
aug_bbox_flat = aug_bbox.view(seq_len, -1) aug_bbox_flat = aug_bbox.view(seq_len, -1)
return aug_bbox_flat return aug_bbox_flat
from .ipm_backbone import IPMEncoder from .ipm_backbone import IPMEncoder
__all__ = [ __all__ = [
'IPMEncoder' 'IPMEncoder'
] ]
...@@ -4,17 +4,19 @@ ...@@ -4,17 +4,19 @@
# Licensed under The MIT License [see LICENSE for details] # Licensed under The MIT License [see LICENSE for details]
# -------------------------------------------------------- # --------------------------------------------------------
from collections import OrderedDict
import torch import torch
import torch.nn as nn import torch.nn as nn
from collections import OrderedDict import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint import torch.utils.checkpoint as checkpoint
from timm.models.layers import trunc_normal_, DropPath
from mmcv.runner import _load_checkpoint
from mmcv.cnn import constant_init, trunc_normal_init from mmcv.cnn import constant_init, trunc_normal_init
from mmcv.runner import _load_checkpoint
from mmdet.models.builder import BACKBONES
from mmseg.utils import get_root_logger from mmseg.utils import get_root_logger
from ops_dcnv3 import modules as opsm from ops_dcnv3 import modules as opsm
import torch.nn.functional as F from timm.models.layers import DropPath, trunc_normal_
from mmdet.models.builder import BACKBONES
class to_channels_first(nn.Module): class to_channels_first(nn.Module):
...@@ -84,7 +86,7 @@ class CrossAttention(nn.Module): ...@@ -84,7 +86,7 @@ class CrossAttention(nn.Module):
attn_head_dim (int, optional): Dimension of attention head. attn_head_dim (int, optional): Dimension of attention head.
out_dim (int, optional): Dimension of output. out_dim (int, optional): Dimension of output.
""" """
def __init__(self, def __init__(self,
dim, dim,
num_heads=8, num_heads=8,
...@@ -176,7 +178,7 @@ class AttentiveBlock(nn.Module): ...@@ -176,7 +178,7 @@ class AttentiveBlock(nn.Module):
attn_head_dim (int, optional): Dimension of attention head. Default: None. attn_head_dim (int, optional): Dimension of attention head. Default: None.
out_dim (int, optional): Dimension of output. Default: None. out_dim (int, optional): Dimension of output. Default: None.
""" """
def __init__(self, def __init__(self,
dim, dim,
num_heads, num_heads,
...@@ -185,7 +187,7 @@ class AttentiveBlock(nn.Module): ...@@ -185,7 +187,7 @@ class AttentiveBlock(nn.Module):
drop=0., drop=0.,
attn_drop=0., attn_drop=0.,
drop_path=0., drop_path=0.,
norm_layer="LN", norm_layer='LN',
attn_head_dim=None, attn_head_dim=None,
out_dim=None): out_dim=None):
super().__init__() super().__init__()
...@@ -361,9 +363,9 @@ class InternImageLayer(nn.Module): ...@@ -361,9 +363,9 @@ class InternImageLayer(nn.Module):
layer_scale=None, layer_scale=None,
offset_scale=1.0, offset_scale=1.0,
with_cp=False, with_cp=False,
dw_kernel_size=None, # for InternImage-H/G dw_kernel_size=None, # for InternImage-H/G
res_post_norm=False, # for InternImage-H/G res_post_norm=False, # for InternImage-H/G
center_feature_scale=False): # for InternImage-H/G center_feature_scale=False): # for InternImage-H/G
super().__init__() super().__init__()
self.channels = channels self.channels = channels
self.groups = groups self.groups = groups
...@@ -382,8 +384,8 @@ class InternImageLayer(nn.Module): ...@@ -382,8 +384,8 @@ class InternImageLayer(nn.Module):
offset_scale=offset_scale, offset_scale=offset_scale,
act_layer=act_layer, act_layer=act_layer,
norm_layer=norm_layer, norm_layer=norm_layer,
dw_kernel_size=dw_kernel_size, # for InternImage-H/G dw_kernel_size=dw_kernel_size, # for InternImage-H/G
center_feature_scale=center_feature_scale) # for InternImage-H/G center_feature_scale=center_feature_scale) # for InternImage-H/G
self.drop_path = DropPath(drop_path) if drop_path > 0. \ self.drop_path = DropPath(drop_path) if drop_path > 0. \
else nn.Identity() else nn.Identity()
self.norm2 = build_norm_layer(channels, 'LN') self.norm2 = build_norm_layer(channels, 'LN')
...@@ -409,7 +411,7 @@ class InternImageLayer(nn.Module): ...@@ -409,7 +411,7 @@ class InternImageLayer(nn.Module):
if self.post_norm: if self.post_norm:
x = x + self.drop_path(self.norm1(self.dcn(x))) x = x + self.drop_path(self.norm1(self.dcn(x)))
x = x + self.drop_path(self.norm2(self.mlp(x))) x = x + self.drop_path(self.norm2(self.mlp(x)))
elif self.res_post_norm: # for InternImage-H/G elif self.res_post_norm: # for InternImage-H/G
x = x + self.drop_path(self.res_post_norm1(self.dcn(self.norm1(x)))) x = x + self.drop_path(self.res_post_norm1(self.dcn(self.norm1(x))))
x = x + self.drop_path(self.res_post_norm2(self.mlp(self.norm2(x)))) x = x + self.drop_path(self.res_post_norm2(self.mlp(self.norm2(x))))
else: else:
...@@ -464,10 +466,10 @@ class InternImageBlock(nn.Module): ...@@ -464,10 +466,10 @@ class InternImageBlock(nn.Module):
offset_scale=1.0, offset_scale=1.0,
layer_scale=None, layer_scale=None,
with_cp=False, with_cp=False,
dw_kernel_size=None, # for InternImage-H/G dw_kernel_size=None, # for InternImage-H/G
post_norm_block_ids=None, # for InternImage-H/G post_norm_block_ids=None, # for InternImage-H/G
res_post_norm=False, # for InternImage-H/G res_post_norm=False, # for InternImage-H/G
center_feature_scale=False): # for InternImage-H/G center_feature_scale=False): # for InternImage-H/G
super().__init__() super().__init__()
self.channels = channels self.channels = channels
self.depth = depth self.depth = depth
...@@ -489,15 +491,15 @@ class InternImageBlock(nn.Module): ...@@ -489,15 +491,15 @@ class InternImageBlock(nn.Module):
layer_scale=layer_scale, layer_scale=layer_scale,
offset_scale=offset_scale, offset_scale=offset_scale,
with_cp=with_cp, with_cp=with_cp,
dw_kernel_size=dw_kernel_size, # for InternImage-H/G dw_kernel_size=dw_kernel_size, # for InternImage-H/G
res_post_norm=res_post_norm, # for InternImage-H/G res_post_norm=res_post_norm, # for InternImage-H/G
center_feature_scale=center_feature_scale # for InternImage-H/G center_feature_scale=center_feature_scale # for InternImage-H/G
) for i in range(depth) ) for i in range(depth)
]) ])
if not self.post_norm or center_feature_scale: if not self.post_norm or center_feature_scale:
self.norm = build_norm_layer(channels, 'LN') self.norm = build_norm_layer(channels, 'LN')
self.post_norm_block_ids = post_norm_block_ids self.post_norm_block_ids = post_norm_block_ids
if post_norm_block_ids is not None: # for InternImage-H/G if post_norm_block_ids is not None: # for InternImage-H/G
self.post_norms = nn.ModuleList( self.post_norms = nn.ModuleList(
[build_norm_layer(channels, 'LN', eps=1e-6) for _ in post_norm_block_ids] [build_norm_layer(channels, 'LN', eps=1e-6) for _ in post_norm_block_ids]
) )
...@@ -509,7 +511,7 @@ class InternImageBlock(nn.Module): ...@@ -509,7 +511,7 @@ class InternImageBlock(nn.Module):
x = blk(x) x = blk(x)
if (self.post_norm_block_ids is not None) and (i in self.post_norm_block_ids): if (self.post_norm_block_ids is not None) and (i in self.post_norm_block_ids):
index = self.post_norm_block_ids.index(i) index = self.post_norm_block_ids.index(i)
x = self.post_norms[index](x) # for InternImage-H/G x = self.post_norms[index](x) # for InternImage-H/G
if not self.post_norm or self.center_feature_scale: if not self.post_norm or self.center_feature_scale:
x = self.norm(x) x = self.norm(x)
if return_wo_downsample: if return_wo_downsample:
...@@ -575,7 +577,7 @@ class InternImage(nn.Module): ...@@ -575,7 +577,7 @@ class InternImage(nn.Module):
self.num_levels = len(depths) self.num_levels = len(depths)
self.depths = depths self.depths = depths
self.channels = channels self.channels = channels
self.num_features = int(channels * 2**(self.num_levels - 1)) self.num_features = int(channels * 2 ** (self.num_levels - 1))
self.post_norm = post_norm self.post_norm = post_norm
self.mlp_ratio = mlp_ratio self.mlp_ratio = mlp_ratio
self.init_cfg = init_cfg self.init_cfg = init_cfg
...@@ -607,10 +609,10 @@ class InternImage(nn.Module): ...@@ -607,10 +609,10 @@ class InternImage(nn.Module):
self.levels = nn.ModuleList() self.levels = nn.ModuleList()
for i in range(self.num_levels): for i in range(self.num_levels):
post_norm_block_ids = level2_post_norm_block_ids if level2_post_norm and ( post_norm_block_ids = level2_post_norm_block_ids if level2_post_norm and (
i == 2) else None # for InternImage-H/G i == 2) else None # for InternImage-H/G
level = InternImageBlock( level = InternImageBlock(
core_op=getattr(opsm, core_op), core_op=getattr(opsm, core_op),
channels=int(channels * 2**i), channels=int(channels * 2 ** i),
depth=depths[i], depth=depths[i],
groups=groups[i], groups=groups[i],
mlp_ratio=self.mlp_ratio, mlp_ratio=self.mlp_ratio,
...@@ -624,9 +626,9 @@ class InternImage(nn.Module): ...@@ -624,9 +626,9 @@ class InternImage(nn.Module):
offset_scale=offset_scale, offset_scale=offset_scale,
with_cp=with_cp, with_cp=with_cp,
dw_kernel_size=dw_kernel_size, # for InternImage-H/G dw_kernel_size=dw_kernel_size, # for InternImage-H/G
post_norm_block_ids=post_norm_block_ids, # for InternImage-H/G post_norm_block_ids=post_norm_block_ids, # for InternImage-H/G
res_post_norm=res_post_norm, # for InternImage-H/G res_post_norm=res_post_norm, # for InternImage-H/G
center_feature_scale=center_feature_scale # for InternImage-H/G center_feature_scale=center_feature_scale # for InternImage-H/G
) )
self.levels.append(level) self.levels.append(level)
...@@ -697,4 +699,4 @@ class InternImage(nn.Module): ...@@ -697,4 +699,4 @@ class InternImage(nn.Module):
x, x_ = level(x, return_wo_downsample=True) x, x_ = level(x, return_wo_downsample=True)
if level_idx in self.out_indices: if level_idx in self.out_indices:
seq_out.append(x_.permute(0, 3, 1, 2).contiguous()) seq_out.append(x_.permute(0, 3, 1, 2).contiguous())
return seq_out return seq_out
\ No newline at end of file
import copy import copy
import math import math
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from mmdet3d.models.builder import BACKBONES from mmdet3d.models.builder import BACKBONES
from mmdet.models import build_backbone, build_neck from mmdet.models import build_backbone, build_neck
class UpsampleBlock(nn.Module): class UpsampleBlock(nn.Module):
def __init__(self, ins, outs): def __init__(self, ins, outs):
super(UpsampleBlock, self).__init__() super(UpsampleBlock, self).__init__()
...@@ -17,7 +18,6 @@ class UpsampleBlock(nn.Module): ...@@ -17,7 +18,6 @@ class UpsampleBlock(nn.Module):
self.relu = nn.ReLU(inplace=True) self.relu = nn.ReLU(inplace=True)
def forward(self, x): def forward(self, x):
x = self.conv(x) x = self.conv(x)
x = self.relu(self.gn(x)) x = self.relu(self.gn(x))
x = self.upsample2x(x) x = self.upsample2x(x)
...@@ -26,7 +26,7 @@ class UpsampleBlock(nn.Module): ...@@ -26,7 +26,7 @@ class UpsampleBlock(nn.Module):
def upsample2x(self, x): def upsample2x(self, x):
_, _, h, w = x.shape _, _, h, w = x.shape
x = F.interpolate(x, size=(h*2, w*2), x = F.interpolate(x, size=(h * 2, w * 2),
mode='bilinear', align_corners=True) mode='bilinear', align_corners=True)
return x return x
...@@ -54,7 +54,7 @@ class Upsample(nn.Module): ...@@ -54,7 +54,7 @@ class Upsample(nn.Module):
continue continue
tmp = [copy.deepcopy(input_conv), ] tmp = [copy.deepcopy(input_conv), ]
tmp += [copy.deepcopy(inter_conv) for i in range(layer_num-1)] tmp += [copy.deepcopy(inter_conv) for i in range(layer_num - 1)]
fscale.append(nn.Sequential(*tmp)) fscale.append(nn.Sequential(*tmp))
self.fscale = nn.ModuleList(fscale) self.fscale = nn.ModuleList(fscale)
...@@ -117,21 +117,21 @@ class IPMEncoder(nn.Module): ...@@ -117,21 +117,21 @@ class IPMEncoder(nn.Module):
if self.use_lidar: if self.use_lidar:
self.pp = PointPillarEncoder(lidar_dim, xbound, ybound, zbound) self.pp = PointPillarEncoder(lidar_dim, xbound, ybound, zbound)
self.outconvs =\ self.outconvs = \
nn.Conv2d((self.upsample.out_channels+3)*len(heights), out_channels//2, nn.Conv2d((self.upsample.out_channels + 3) * len(heights), out_channels // 2,
kernel_size=3, stride=1, padding=1) # same kernel_size=3, stride=1, padding=1) # same
if self.use_image: if self.use_image:
_out_channels = out_channels//2 _out_channels = out_channels // 2
else: else:
_out_channels = out_channels _out_channels = out_channels
self.outconvs_lidar =\ self.outconvs_lidar = \
nn.Conv2d(lidar_dim, _out_channels, nn.Conv2d(lidar_dim, _out_channels,
kernel_size=3, stride=1, padding=1) # same kernel_size=3, stride=1, padding=1) # same
else: else:
self.outconvs =\ self.outconvs = \
nn.Conv2d((self.upsample.out_channels+3)*len(heights), out_channels, nn.Conv2d((self.upsample.out_channels + 3) * len(heights), out_channels,
kernel_size=3, stride=1, padding=1) # same kernel_size=3, stride=1, padding=1) # same
self.init_weights(pretrained=pretrained) self.init_weights(pretrained=pretrained)
...@@ -139,11 +139,10 @@ class IPMEncoder(nn.Module): ...@@ -139,11 +139,10 @@ class IPMEncoder(nn.Module):
bev_planes = [construct_plane_grid( bev_planes = [construct_plane_grid(
xbound, ybound, h) for h in self.heights] xbound, ybound, h) for h in self.heights]
self.register_buffer('bev_planes', torch.stack( self.register_buffer('bev_planes', torch.stack(
bev_planes),) # nlvl,bH,bW,2 bev_planes), ) # nlvl,bH,bW,2
self.masked_embeds = nn.Embedding(len(heights), out_channels) self.masked_embeds = nn.Embedding(len(heights), out_channels)
def init_weights(self, pretrained=None): def init_weights(self, pretrained=None):
"""Initialize model weights.""" """Initialize model weights."""
...@@ -154,12 +153,12 @@ class IPMEncoder(nn.Module): ...@@ -154,12 +153,12 @@ class IPMEncoder(nn.Module):
for p in self.outconvs.parameters(): for p in self.outconvs.parameters():
if p.dim() > 1: if p.dim() > 1:
nn.init.xavier_uniform_(p) nn.init.xavier_uniform_(p)
if self.use_lidar: if self.use_lidar:
for p in self.outconvs_lidar.parameters(): for p in self.outconvs_lidar.parameters():
if p.dim() > 1: if p.dim() > 1:
nn.init.xavier_uniform_(p) nn.init.xavier_uniform_(p)
for p in self.pp.parameters(): for p in self.pp.parameters():
if p.dim() > 1: if p.dim() > 1:
nn.init.xavier_uniform_(p) nn.init.xavier_uniform_(p)
...@@ -169,7 +168,7 @@ class IPMEncoder(nn.Module): ...@@ -169,7 +168,7 @@ class IPMEncoder(nn.Module):
Extract image feaftures and sum up into one pic Extract image feaftures and sum up into one pic
Args: Args:
imgs: B, n_cam, C, iH, iW imgs: B, n_cam, C, iH, iW
Returns: Returns:
img_feat: B * n_cam, C, H, W img_feat: B * n_cam, C, H, W
''' '''
...@@ -188,12 +187,12 @@ class IPMEncoder(nn.Module): ...@@ -188,12 +187,12 @@ class IPMEncoder(nn.Module):
def forward(self, imgs, img_metas, *args, points=None, **kwargs): def forward(self, imgs, img_metas, *args, points=None, **kwargs):
''' '''
Args: Args:
imgs: torch.Tensor of shape [B, N, 3, H, W] imgs: torch.Tensor of shape [B, N, 3, H, W]
N: number of cams N: number of cams
img_metas: img_metas:
# N=6, ['CAM_FRONT', 'CAM_FRONT_RIGHT', 'CAM_FRONT_LEFT', 'CAM_BACK', 'CAM_BACK_LEFT', 'CAM_BACK_RIGHT'] # N=6, ['CAM_FRONT', 'CAM_FRONT_RIGHT', 'CAM_FRONT_LEFT', 'CAM_BACK', 'CAM_BACK_LEFT', 'CAM_BACK_RIGHT']
ego2cam: [B, N, 4, 4] ego2cam: [B, N, 4, 4]
cam_intrinsics: [B, N, 3, 3] cam_intrinsics: [B, N, 3, 3]
cam2ego_rotations: [B, N, 3, 3] cam2ego_rotations: [B, N, 3, 3]
cam2ego_translations: [B, N, 3] cam2ego_translations: [B, N, 3]
...@@ -225,7 +224,7 @@ class IPMEncoder(nn.Module): ...@@ -225,7 +224,7 @@ class IPMEncoder(nn.Module):
if self.use_lidar: if self.use_lidar:
lidar_feat = self.get_lidar_feature(points) lidar_feat = self.get_lidar_feature(points)
if self.use_image: if self.use_image:
bev_feat = torch.cat([bev_feat,lidar_feat],dim=1) bev_feat = torch.cat([bev_feat, lidar_feat], dim=1)
else: else:
bev_feat = lidar_feat bev_feat = lidar_feat
...@@ -233,7 +232,7 @@ class IPMEncoder(nn.Module): ...@@ -233,7 +232,7 @@ class IPMEncoder(nn.Module):
def ipm(self, cam_feat, ego2cam, img_shape): def ipm(self, cam_feat, ego2cam, img_shape):
''' '''
inverse project inverse project
Args: Args:
cam_feat: B*ncam, C, cH, cW cam_feat: B*ncam, C, cH, cW
img_shape: tuple(H, W) img_shape: tuple(H, W)
...@@ -250,7 +249,7 @@ class IPMEncoder(nn.Module): ...@@ -250,7 +249,7 @@ class IPMEncoder(nn.Module):
# bev_grid_pos: B*ncam, nlvl*bH*bW, 2 # bev_grid_pos: B*ncam, nlvl*bH*bW, 2
bev_grid_pos, bev_cam_mask = get_campos(bev_grid, ego2cam, img_shape) bev_grid_pos, bev_cam_mask = get_campos(bev_grid, ego2cam, img_shape)
# B*cam, nlvl*bH, bW, 2 # B*cam, nlvl*bH, bW, 2
bev_grid_pos = bev_grid_pos.unflatten(-2, (nlvl*bH, bW)) bev_grid_pos = bev_grid_pos.unflatten(-2, (nlvl * bH, bW))
# project feat from 2D to bev plane # project feat from 2D to bev plane
projected_feature = F.grid_sample( projected_feature = F.grid_sample(
...@@ -262,11 +261,11 @@ class IPMEncoder(nn.Module): ...@@ -262,11 +261,11 @@ class IPMEncoder(nn.Module):
# eliminate the ncam # eliminate the ncam
# The bev feature is the sum of the 6 cameras # The bev feature is the sum of the 6 cameras
bev_feat_mask = bev_feat_mask.unsqueeze(2) bev_feat_mask = bev_feat_mask.unsqueeze(2)
projected_feature = (projected_feature*bev_feat_mask).sum(1) projected_feature = (projected_feature * bev_feat_mask).sum(1)
num_feat = bev_feat_mask.sum(1) num_feat = bev_feat_mask.sum(1)
projected_feature = projected_feature / \ projected_feature = projected_feature / \
num_feat.masked_fill(num_feat == 0, 1) num_feat.masked_fill(num_feat == 0, 1)
# concatenate a position information # concatenate a position information
# projected_feature: B, bH, bW, nlvl, C+3 # projected_feature: B, bH, bW, nlvl, C+3
...@@ -287,7 +286,7 @@ class IPMEncoder(nn.Module): ...@@ -287,7 +286,7 @@ class IPMEncoder(nn.Module):
# bev_grid = bev_grid.permute(0, 3, 1, 2) # bev_grid = bev_grid.permute(0, 3, 1, 2)
# lidar_feature = torch.cat( # lidar_feature = torch.cat(
# (lidar_feature, bev_grid), dim=1) # (lidar_feature, bev_grid), dim=1)
lidar_feature = self.outconvs_lidar(lidar_feature) lidar_feature = self.outconvs_lidar(lidar_feature)
return lidar_feature return lidar_feature
...@@ -321,7 +320,7 @@ def construct_plane_grid(xbound, ybound, height: float, dtype=torch.float32): ...@@ -321,7 +320,7 @@ def construct_plane_grid(xbound, ybound, height: float, dtype=torch.float32):
def get_campos(reference_points, ego2cam, img_shape): def get_campos(reference_points, ego2cam, img_shape):
''' '''
Find the each refence point's corresponding pixel in each camera Find the each refence point's corresponding pixel in each camera
Args: Args:
reference_points: [B, num_query, 3] reference_points: [B, num_query, 3]
ego2cam: (B, num_cam, 4, 4) ego2cam: (B, num_cam, 4, 4)
Outs: Outs:
...@@ -351,7 +350,7 @@ def get_campos(reference_points, ego2cam, img_shape): ...@@ -351,7 +350,7 @@ def get_campos(reference_points, ego2cam, img_shape):
eps = 1e-9 eps = 1e-9
mask = (reference_points_cam[..., 2:3] > eps) mask = (reference_points_cam[..., 2:3] > eps)
reference_points_cam =\ reference_points_cam = \
reference_points_cam[..., 0:2] / \ reference_points_cam[..., 0:2] / \
reference_points_cam[..., 2:3] + eps reference_points_cam[..., 2:3] + eps
...@@ -362,13 +361,13 @@ def get_campos(reference_points, ego2cam, img_shape): ...@@ -362,13 +361,13 @@ def get_campos(reference_points, ego2cam, img_shape):
reference_points_cam = (reference_points_cam - 0.5) * 2 reference_points_cam = (reference_points_cam - 0.5) * 2
mask = (mask & (reference_points_cam[..., 0:1] > -1.0) mask = (mask & (reference_points_cam[..., 0:1] > -1.0)
& (reference_points_cam[..., 0:1] < 1.0) & (reference_points_cam[..., 0:1] < 1.0)
& (reference_points_cam[..., 1:2] > -1.0) & (reference_points_cam[..., 1:2] > -1.0)
& (reference_points_cam[..., 1:2] < 1.0)) & (reference_points_cam[..., 1:2] < 1.0))
# (B, num_cam, num_query) # (B, num_cam, num_query)
mask = mask.view(B, num_cam, num_query) mask = mask.view(B, num_cam, num_query)
reference_points_cam = reference_points_cam.view(B*num_cam, num_query, 2) reference_points_cam = reference_points_cam.view(B * num_cam, num_query, 2)
return reference_points_cam, mask return reference_points_cam, mask
......
from .base_map_head import BaseMapHead
from .dg_head import DGHead
from .map_element_detector import MapElementDetector
from .polyline_generator import PolylineGenerator
\ No newline at end of file
...@@ -3,7 +3,6 @@ from abc import ABCMeta, abstractmethod ...@@ -3,7 +3,6 @@ from abc import ABCMeta, abstractmethod
import torch.nn as nn import torch.nn as nn
from mmcv.runner import auto_fp16 from mmcv.runner import auto_fp16
from mmcv.utils import print_log from mmcv.utils import print_log
from mmdet.utils import get_root_logger from mmdet.utils import get_root_logger
...@@ -24,10 +23,10 @@ class BaseMapHead(nn.Module, metaclass=ABCMeta): ...@@ -24,10 +23,10 @@ class BaseMapHead(nn.Module, metaclass=ABCMeta):
logger = get_root_logger() logger = get_root_logger()
print_log(f'load model from: {pretrained}', logger=logger) print_log(f'load model from: {pretrained}', logger=logger)
@auto_fp16(apply_to=('img', )) @auto_fp16(apply_to=('img',))
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
pass pass
@abstractmethod @abstractmethod
def loss(self, pred, gt): def loss(self, pred, gt):
''' '''
...@@ -42,7 +41,7 @@ class BaseMapHead(nn.Module, metaclass=ABCMeta): ...@@ -42,7 +41,7 @@ class BaseMapHead(nn.Module, metaclass=ABCMeta):
) )
''' '''
return return
@abstractmethod @abstractmethod
def post_process(self, pred): def post_process(self, pred):
''' '''
......
# the causal layer is credited by the https://github.com/alexmt-scale/causal-transformer-decoder # the causal layer is credited by the https://github.com/alexmt-scale/causal-transformer-decoder
# we made some change to stick with the polygen. # we made some change to stick with the polygen.
import torch
import torch.nn as nn
from typing import Optional from typing import Optional
from torch import Tensor
import torch
import torch.nn as nn
from mmcv.cnn.bricks.registry import ATTENTION from mmcv.cnn.bricks.registry import ATTENTION
from mmcv.utils import build_from_cfg from mmcv.utils import build_from_cfg
from torch import Tensor
def build_attention(cfg, default_args=None): def build_attention(cfg, default_args=None):
...@@ -29,14 +29,14 @@ class CausalTransformerDecoder(nn.TransformerDecoder): ...@@ -29,14 +29,14 @@ class CausalTransformerDecoder(nn.TransformerDecoder):
""" """
def forward( def forward(
self, self,
tgt: Tensor, tgt: Tensor,
memory: Optional[Tensor] = None, memory: Optional[Tensor] = None,
cache: Optional[Tensor] = None, cache: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None,
causal_mask: Optional[Tensor] = None, causal_mask: Optional[Tensor] = None,
) -> Tensor: ) -> Tensor:
""" """
Args: Args:
...@@ -58,7 +58,7 @@ class CausalTransformerDecoder(nn.TransformerDecoder): ...@@ -58,7 +58,7 @@ class CausalTransformerDecoder(nn.TransformerDecoder):
if self.training: if self.training:
if cache is not None: if cache is not None:
raise ValueError( raise ValueError(
"cache parameter should be None in training mode") 'cache parameter should be None in training mode')
for mod in self.layers: for mod in self.layers:
output = mod( output = mod(
output, output,
...@@ -132,7 +132,7 @@ class CausalTransformerDecoderLayer(nn.TransformerDecoderLayer): ...@@ -132,7 +132,7 @@ class CausalTransformerDecoderLayer(nn.TransformerDecoderLayer):
""" """
Args: Args:
see CausalTransformerDecoder see CausalTransformerDecoder
query is not None model will perform query stream query is not None model will perform query stream
Returns: Returns:
Tensor: Tensor:
If training: embedding of the whole layer: seq_len x bsz x hidden_dim If training: embedding of the whole layer: seq_len x bsz x hidden_dim
...@@ -140,23 +140,23 @@ class CausalTransformerDecoderLayer(nn.TransformerDecoderLayer): ...@@ -140,23 +140,23 @@ class CausalTransformerDecoderLayer(nn.TransformerDecoderLayer):
""" """
if not self.norm_first: if not self.norm_first:
raise ValueError( raise ValueError(
"norm_first parameter should be True!") 'norm_first parameter should be True!')
if self.training: if self.training:
# the official Pytorch implementation # the official Pytorch implementation
x = tgt x = tgt
if query is not None: if query is not None:
x = query x = query
x = x + self.res_weight1 * \ x = x + self.res_weight1 * \
self._sa_block(self.norm1(x), self.norm1(tgt), causal_mask, self._sa_block(self.norm1(x), self.norm1(tgt), causal_mask,
tgt_key_padding_mask) tgt_key_padding_mask)
if memory is not None: if memory is not None:
x = x + self.res_weight2 * \ x = x + self.res_weight2 * \
self._mha_block(self.norm2(x), memory, self._mha_block(self.norm2(x), memory,
memory_mask, memory_key_padding_mask) memory_mask, memory_key_padding_mask)
x = x + self.res_weight3*self._ff_block(self.norm3(x)) x = x + self.res_weight3 * self._ff_block(self.norm3(x))
return x return x
# This part is adapted from the official Pytorch implementation # This part is adapted from the official Pytorch implementation
...@@ -169,14 +169,14 @@ class CausalTransformerDecoderLayer(nn.TransformerDecoderLayer): ...@@ -169,14 +169,14 @@ class CausalTransformerDecoderLayer(nn.TransformerDecoderLayer):
if only_last: if only_last:
x = x[-1:] x = x[-1:]
if causal_mask is not None: if causal_mask is not None:
attn_mask = causal_mask attn_mask = causal_mask
if only_last: if only_last:
attn_mask = attn_mask[-1:] # XXX attn_mask = attn_mask[-1:] # XXX
else: else:
attn_mask = None attn_mask = None
# efficient self attention # efficient self attention
x = x + self.res_weight1 * \ x = x + self.res_weight1 * \
self._sa_block(self.norm1(x), self.norm1(tgt), attn_mask, self._sa_block(self.norm1(x), self.norm1(tgt), attn_mask,
...@@ -189,7 +189,7 @@ class CausalTransformerDecoderLayer(nn.TransformerDecoderLayer): ...@@ -189,7 +189,7 @@ class CausalTransformerDecoderLayer(nn.TransformerDecoderLayer):
memory_mask, memory_key_padding_mask) memory_mask, memory_key_padding_mask)
# final feed-forward network # final feed-forward network
x = x + self.res_weight3*self._ff_block(self.norm3(x)) x = x + self.res_weight3 * self._ff_block(self.norm3(x))
return x return x
...@@ -235,7 +235,8 @@ class PolygenTransformerEncoderLayer(nn.TransformerEncoderLayer): ...@@ -235,7 +235,8 @@ class PolygenTransformerEncoderLayer(nn.TransformerEncoderLayer):
self.norm_first = norm_first self.norm_first = norm_first
def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor: def forward(self, src: Tensor, src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
r"""Pass the input through the encoder layer. r"""Pass the input through the encoder layer.
Args: Args:
src: the sequence to the encoder layer (required). src: the sequence to the encoder layer (required).
...@@ -249,13 +250,13 @@ class PolygenTransformerEncoderLayer(nn.TransformerEncoderLayer): ...@@ -249,13 +250,13 @@ class PolygenTransformerEncoderLayer(nn.TransformerEncoderLayer):
x = src x = src
if self.norm_first: if self.norm_first:
x = x + self.res_weight1*self._sa_block(self.norm1(x), src_mask, x = x + self.res_weight1 * self._sa_block(self.norm1(x), src_mask,
src_key_padding_mask) src_key_padding_mask)
x = x + self.res_weight2*self._ff_block(self.norm2(x)) x = x + self.res_weight2 * self._ff_block(self.norm2(x))
else: else:
x = self.norm1( x = self.norm1(
x + self.res_weight1*self._sa_block(x, src_mask, src_key_padding_mask)) x + self.res_weight1 * self._sa_block(x, src_mask, src_key_padding_mask))
x = self.norm2(x + self.res_weight2*self._ff_block(x)) x = self.norm2(x + self.res_weight2 * self._ff_block(x))
return x return x
...@@ -274,12 +275,12 @@ class PolygenTransformerEncoderLayer(nn.TransformerEncoderLayer): ...@@ -274,12 +275,12 @@ class PolygenTransformerEncoderLayer(nn.TransformerEncoderLayer):
return self.dropout2(x) return self.dropout2(x)
def generate_square_subsequent_mask(sz: int, device: str = "cpu") -> torch.Tensor: def generate_square_subsequent_mask(sz: int, device: str = 'cpu') -> torch.Tensor:
""" Generate the attention mask for causal decoding """ """ Generate the attention mask for causal decoding """
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
mask = ( mask = (
mask.float() mask.float()
.masked_fill(mask == 0, float("-inf")) .masked_fill(mask == 0, float('-inf'))
.masked_fill(mask == 1, float(0.0)) .masked_fill(mask == 1, float(0.0))
).to(device=device) ).to(device=device)
return mask return mask
\ No newline at end of file
...@@ -2,18 +2,20 @@ import torch ...@@ -2,18 +2,20 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import Tensor from torch import Tensor
def generate_square_subsequent_mask(sz: int, condition_len: int = 1, bool_out=False, device: str = "cpu") -> torch.Tensor:
def generate_square_subsequent_mask(sz: int, condition_len: int = 1, bool_out=False,
device: str = 'cpu') -> torch.Tensor:
""" Generate the attention mask for causal decoding """ """ Generate the attention mask for causal decoding """
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
if condition_len > 1: if condition_len > 1:
mask[:condition_len,:condition_len] = 1 mask[:condition_len, :condition_len] = 1
if not bool_out: if not bool_out:
mask = ( mask = (
mask.float() mask.float()
.masked_fill(mask == 0, float("-inf")) .masked_fill(mask == 0, float('-inf'))
.masked_fill(mask == 1, float(0.0))) .masked_fill(mask == 1, float(0.0)))
return mask.to(device=device) return mask.to(device=device)
...@@ -39,10 +41,10 @@ def quantize_verts( ...@@ -39,10 +41,10 @@ def quantize_verts(
""" """
min_range = -1 min_range = -1
max_range = 1 max_range = 1
range_quantize = canvas_size-1 range_quantize = canvas_size - 1
verts_ratio = (verts - min_range) / ( verts_ratio = (verts - min_range) / (
max_range - min_range) max_range - min_range)
verts_quantize = verts_ratio * range_quantize verts_quantize = verts_ratio * range_quantize
return verts_quantize.type(torch.int32) return verts_quantize.type(torch.int32)
...@@ -56,7 +58,7 @@ def top_k_logits(logits, k): ...@@ -56,7 +58,7 @@ def top_k_logits(logits, k):
values, _ = torch.topk(logits, k=k) values, _ = torch.topk(logits, k=k)
k_largest = torch.min(values) k_largest = torch.min(values)
logits = torch.where(logits < k_largest, logits = torch.where(logits < k_largest,
torch.ones_like(logits)*-1e9, logits) torch.ones_like(logits) * -1e9, logits)
return logits return logits
......
import copy import copy
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from mmcv.cnn import Conv2d, Linear from mmcv.cnn import Linear
from mmcv.runner import force_fp32 from mmcv.runner import force_fp32
from torch.distributions.categorical import Categorical
from mmdet.core import multi_apply, reduce_mean from mmdet.core import multi_apply, reduce_mean
from mmdet.models import HEADS from mmdet.models import HEADS
from torch.distributions.categorical import Categorical
from .detr_head import DETRMapFixedNumHead from .detr_head import DETRMapFixedNumHead
@HEADS.register_module(force=True) @HEADS.register_module(force=True)
class DETRBboxHead(DETRMapFixedNumHead): class DETRBboxHead(DETRMapFixedNumHead):
def __init__(self, *args, canvas_size=(400, 200), discrete_output=True, separate_detect=True, def __init__(self, *args, canvas_size=(400, 200), discrete_output=True, separate_detect=True,
mode='xyxy', bbox_size=None, coord_dim=2, kp_coord_dim=2, mode='xyxy', bbox_size=None, coord_dim=2, kp_coord_dim=2,
**kwargs): **kwargs):
self.canvas_size = canvas_size # hard code self.canvas_size = canvas_size # hard code
self.separate_detect = separate_detect self.separate_detect = separate_detect
self.discrete_output = discrete_output self.discrete_output = discrete_output
self.bbox_size = 3 if mode=='sce' else 2 self.bbox_size = 3 if mode == 'sce' else 2
if bbox_size is not None: if bbox_size is not None:
self.bbox_size = bbox_size self.bbox_size = bbox_size
self.coord_dim = coord_dim # for xyz self.coord_dim = coord_dim # for xyz
...@@ -31,7 +32,7 @@ class DETRBboxHead(DETRMapFixedNumHead): ...@@ -31,7 +32,7 @@ class DETRBboxHead(DETRMapFixedNumHead):
del self.canvas_size del self.canvas_size
self.register_buffer('canvas_size', torch.tensor(canvas_size)) self.register_buffer('canvas_size', torch.tensor(canvas_size))
self._init_embedding() self._init_embedding()
def _init_embedding(self): def _init_embedding(self):
# for bbox parameter xstart, ystart, xend, yend # for bbox parameter xstart, ystart, xend, yend
...@@ -42,12 +43,12 @@ class DETRBboxHead(DETRMapFixedNumHead): ...@@ -42,12 +43,12 @@ class DETRBboxHead(DETRMapFixedNumHead):
self.img_coord_embed = nn.Linear(2, self.embed_dims) self.img_coord_embed = nn.Linear(2, self.embed_dims)
def _init_branch(self,): def _init_branch(self, ):
"""Initialize classification branch and regression branch of head.""" """Initialize classification branch and regression branch of head."""
# add sigmoid or not # add sigmoid or not
if self.separate_detect: if self.separate_detect:
if self.cls_out_channels == self.num_classes+1: if self.cls_out_channels == self.num_classes + 1:
self.cls_out_channels = 2 self.cls_out_channels = 2
else: else:
self.cls_out_channels = 1 self.cls_out_channels = 1
...@@ -62,10 +63,10 @@ class DETRBboxHead(DETRMapFixedNumHead): ...@@ -62,10 +63,10 @@ class DETRBboxHead(DETRMapFixedNumHead):
if self.discrete_output: if self.discrete_output:
reg_branch.append(nn.Linear( reg_branch.append(nn.Linear(
self.embed_dims, max(self.canvas_size), bias=True,)) self.embed_dims, max(self.canvas_size), bias=True, ))
else: else:
reg_branch.append(nn.Linear( reg_branch.append(nn.Linear(
self.embed_dims, self.bbox_size*self.coord_dim, bias=True,)) self.embed_dims, self.bbox_size * self.coord_dim, bias=True, ))
reg_branch = nn.Sequential(*reg_branch) reg_branch = nn.Sequential(*reg_branch)
...@@ -133,12 +134,12 @@ class DETRBboxHead(DETRMapFixedNumHead): ...@@ -133,12 +134,12 @@ class DETRBboxHead(DETRMapFixedNumHead):
[nb_dec, bs, num_query, num_points, 2]. [nb_dec, bs, num_query, num_points, 2].
''' '''
(global_context_embedding, sequential_context_embeddings) =\ (global_context_embedding, sequential_context_embeddings) = \
self._prepare_context(batch, context) self._prepare_context(batch, context)
if self.separate_detect: if self.separate_detect:
query_embedding = self.query_embedding.weight[None] + \ query_embedding = self.query_embedding.weight[None] + \
global_context_embedding[:, None] global_context_embedding[:, None]
else: else:
B = sequential_context_embeddings.shape[0] B = sequential_context_embeddings.shape[0]
query_embedding = self.query_embedding.weight[None].repeat(B, 1, 1) query_embedding = self.query_embedding.weight[None].repeat(B, 1, 1)
...@@ -166,18 +167,18 @@ class DETRBboxHead(DETRMapFixedNumHead): ...@@ -166,18 +167,18 @@ class DETRBboxHead(DETRMapFixedNumHead):
pos = [] pos = []
for i in range(4): for i in range(4):
pos_embeds = self.bbox_embedding.weight[i] pos_embeds = self.bbox_embedding.weight[i]
_pos = self.pre_branches['reg'](query_feat+pos_embeds) _pos = self.pre_branches['reg'](query_feat + pos_embeds)
pos.append(_pos) pos.append(_pos)
# # y mask # # y mask
# _vert_mask = torch.arange(logits.shape[-1], device=logits.device) # _vert_mask = torch.arange(logits.shape[-1], device=logits.device)
# vertices_mask_y = (_vert_mask < self.canvas_size[1]+1) # vertices_mask_y = (_vert_mask < self.canvas_size[1]+1)
# logits[:,1::2] = logits[:,1::2]*vertices_mask_y - ~vertices_mask_y*1e9 # logits[:,1::2] = logits[:,1::2]*vertices_mask_y - ~vertices_mask_y*1e9
logits = torch.stack(pos, dim=-2)/1. logits = torch.stack(pos, dim=-2) / 1.
lines = Categorical(logits=logits) lines = Categorical(logits=logits)
else: else:
lines = self.pre_branches['reg'](query_feat).sigmoid() lines = self.pre_branches['reg'](query_feat).sigmoid()
lines = lines.unflatten(-1, (self.bbox_size, self.coord_dim))*self.canvas_size lines = lines.unflatten(-1, (self.bbox_size, self.coord_dim)) * self.canvas_size
lines = lines.flatten(-2) lines = lines.flatten(-2)
return dict( return dict(
...@@ -220,7 +221,7 @@ class DETRBboxHead(DETRMapFixedNumHead): ...@@ -220,7 +221,7 @@ class DETRBboxHead(DETRMapFixedNumHead):
num_pred_lines = len(lines_pred) num_pred_lines = len(lines_pred)
# assigner and sampler # assigner and sampler
assign_result = self.assigner.assign(preds=dict(lines=lines_pred, scores=score_pred,), assign_result = self.assigner.assign(preds=dict(lines=lines_pred, scores=score_pred, ),
gts=dict(lines=gt_lines, gts=dict(lines=gt_lines,
labels=gt_labels, ), labels=gt_labels, ),
gt_bboxes_ignore=gt_bboxes_ignore) gt_bboxes_ignore=gt_bboxes_ignore)
...@@ -232,10 +233,10 @@ class DETRBboxHead(DETRMapFixedNumHead): ...@@ -232,10 +233,10 @@ class DETRBboxHead(DETRMapFixedNumHead):
# label targets 0: foreground, 1: background # label targets 0: foreground, 1: background
if self.separate_detect: if self.separate_detect:
labels = gt_lines.new_full((num_pred_lines, ), 1, dtype=torch.long) labels = gt_lines.new_full((num_pred_lines,), 1, dtype=torch.long)
else: else:
labels = gt_lines.new_full( labels = gt_lines.new_full(
(num_pred_lines, ), self.num_classes, dtype=torch.long) (num_pred_lines,), self.num_classes, dtype=torch.long)
labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds] labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
label_weights = gt_lines.new_ones(num_pred_lines) label_weights = gt_lines.new_ones(num_pred_lines)
...@@ -308,11 +309,11 @@ class DETRBboxHead(DETRMapFixedNumHead): ...@@ -308,11 +309,11 @@ class DETRBboxHead(DETRMapFixedNumHead):
(labels_list, label_weights_list, (labels_list, label_weights_list,
lines_targets_list, lines_weights_list, lines_targets_list, lines_weights_list,
pos_inds_list, neg_inds_list,pos_gt_inds_list) = multi_apply( pos_inds_list, neg_inds_list, pos_gt_inds_list) = multi_apply(
self._get_target_single, self._get_target_single,
preds['scores'], lines_pred, preds['scores'], lines_pred,
class_label, bbox, class_label, bbox,
gt_bboxes_ignore=gt_bboxes_ignore_list) gt_bboxes_ignore=gt_bboxes_ignore_list)
num_total_pos = sum((inds.numel() for inds in pos_inds_list)) num_total_pos = sum((inds.numel() for inds in pos_inds_list))
num_total_neg = sum((inds.numel() for inds in neg_inds_list)) num_total_neg = sum((inds.numel() for inds in neg_inds_list))
...@@ -351,7 +352,7 @@ class DETRBboxHead(DETRMapFixedNumHead): ...@@ -351,7 +352,7 @@ class DETRBboxHead(DETRMapFixedNumHead):
""" """
# Get target for each sample # Get target for each sample
new_gts, num_total_pos, num_total_neg, pos_inds_list, pos_gt_inds_list =\ new_gts, num_total_pos, num_total_neg, pos_inds_list, pos_gt_inds_list = \
self.get_targets(preds, gts, gt_bboxes_ignore_list) self.get_targets(preds, gts, gt_bboxes_ignore_list)
# Batched all data # Batched all data
...@@ -360,7 +361,7 @@ class DETRBboxHead(DETRMapFixedNumHead): ...@@ -360,7 +361,7 @@ class DETRBboxHead(DETRMapFixedNumHead):
# construct weighted avg_factor to match with the official DETR repo # construct weighted avg_factor to match with the official DETR repo
cls_avg_factor = num_total_pos * 1.0 + \ cls_avg_factor = num_total_pos * 1.0 + \
num_total_neg * self.bg_cls_weight num_total_neg * self.bg_cls_weight
if self.sync_cls_avg_factor: if self.sync_cls_avg_factor:
cls_avg_factor = reduce_mean( cls_avg_factor = reduce_mean(
preds['scores'].new_tensor([cls_avg_factor])) preds['scores'].new_tensor([cls_avg_factor]))
...@@ -386,7 +387,7 @@ class DETRBboxHead(DETRMapFixedNumHead): ...@@ -386,7 +387,7 @@ class DETRBboxHead(DETRMapFixedNumHead):
# position NLL loss # position NLL loss
if self.discrete_output: if self.discrete_output:
loss_reg = -(preds['lines'].log_prob(new_gts['bboxs']) * loss_reg = -(preds['lines'].log_prob(new_gts['bboxs']) *
new_gts['bboxs_weights']).sum()/(num_total_pos) new_gts['bboxs_weights']).sum() / (num_total_pos)
else: else:
loss_reg = self.reg_loss( loss_reg = self.reg_loss(
preds['lines'], new_gts['bboxs'], new_gts['bboxs_weights'], avg_factor=num_total_pos) preds['lines'], new_gts['bboxs'], new_gts['bboxs_weights'], avg_factor=num_total_pos)
...@@ -408,9 +409,9 @@ class DETRBboxHead(DETRMapFixedNumHead): ...@@ -408,9 +409,9 @@ class DETRBboxHead(DETRMapFixedNumHead):
pos_msk = label == 0 pos_msk = label == 0
neg_msk = ~pos_msk neg_msk = ~pos_msk
loss_cls = -(p.log()*pos_msk + (1-p).log()*neg_msk) loss_cls = -(p.log() * pos_msk + (1 - p).log() * neg_msk)
loss_cls = (loss_cls * weights).sum()/cls_avg_factor loss_cls = (loss_cls * weights).sum() / cls_avg_factor
return loss_cls return loss_cls
...@@ -465,7 +466,7 @@ class DETRBboxHead(DETRMapFixedNumHead): ...@@ -465,7 +466,7 @@ class DETRBboxHead(DETRMapFixedNumHead):
result_dict['bbox'].append(det_preds) result_dict['bbox'].append(det_preds)
result_dict['scores'].append(scores) result_dict['scores'].append(scores)
result_dict['labels'].append(det_labels) result_dict['labels'].append(det_labels)
result_dict['lines_bs_idx'].extend([i]*nline) result_dict['lines_bs_idx'].extend([i] * nline)
# for down stream polyline # for down stream polyline
_bboxs = torch.cat(result_dict['bbox'], dim=0) _bboxs = torch.cat(result_dict['bbox'], dim=0)
...@@ -481,4 +482,4 @@ class DETRBboxHead(DETRMapFixedNumHead): ...@@ -481,4 +482,4 @@ class DETRBboxHead(DETRMapFixedNumHead):
def assign_bev(feat, idx): def assign_bev(feat, idx):
return feat[idx] return feat[idx]
\ No newline at end of file
import copy
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import copy from mmcv.cnn import (Conv2d, Linear, bias_init_with_prob,
from mmdet.models import HEADS build_activation_layer)
from mmcv.cnn import Conv2d
from mmcv.cnn import Linear, build_activation_layer, bias_init_with_prob
from mmcv.cnn.bricks.transformer import build_positional_encoding from mmcv.cnn.bricks.transformer import build_positional_encoding
from mmdet.models.utils import build_transformer
from mmcv.runner import force_fp32 from mmcv.runner import force_fp32
from mmdet.core import build_assigner, build_sampler, multi_apply, reduce_mean
from mmdet.core import (multi_apply, build_assigner, build_sampler, from mmdet.models import HEADS, build_loss
reduce_mean) from mmdet.models.utils import build_transformer
from mmdet.models.utils.transformer import inverse_sigmoid
from mmdet.models import build_loss
from .base_map_head import BaseMapHead from .base_map_head import BaseMapHead
...@@ -60,14 +57,14 @@ class DETRMapFixedNumHead(BaseMapHead): ...@@ -60,14 +57,14 @@ class DETRMapFixedNumHead(BaseMapHead):
if loss_cls['use_sigmoid']: if loss_cls['use_sigmoid']:
self.cls_out_channels = num_classes self.cls_out_channels = num_classes
else: else:
self.cls_out_channels = num_classes+1 self.cls_out_channels = num_classes + 1
self.iterative = iterative self.iterative = iterative
self.num_reg_fcs = num_reg_fcs self.num_reg_fcs = num_reg_fcs
if patch_size is not None: if patch_size is not None:
self.register_buffer('patch_size', torch.tensor( self.register_buffer('patch_size', torch.tensor(
(patch_size[1], patch_size[0])),) (patch_size[1], patch_size[0])), )
self._build_transformer(transformer, positional_encoding) self._build_transformer(transformer, positional_encoding)
...@@ -104,7 +101,7 @@ class DETRMapFixedNumHead(BaseMapHead): ...@@ -104,7 +101,7 @@ class DETRMapFixedNumHead(BaseMapHead):
self.transformer = build_transformer(transformer) self.transformer = build_transformer(transformer)
self.embed_dims = self.transformer.embed_dims self.embed_dims = self.transformer.embed_dims
def _init_branch(self,): def _init_branch(self, ):
"""Initialize classification branch and regression branch of head.""" """Initialize classification branch and regression branch of head."""
fc_cls = Linear(self.embed_dims, self.cls_out_channels) fc_cls = Linear(self.embed_dims, self.cls_out_channels)
...@@ -114,8 +111,9 @@ class DETRMapFixedNumHead(BaseMapHead): ...@@ -114,8 +111,9 @@ class DETRMapFixedNumHead(BaseMapHead):
reg_branch.append(Linear(self.embed_dims, self.embed_dims)) reg_branch.append(Linear(self.embed_dims, self.embed_dims))
reg_branch.append(nn.LayerNorm(self.embed_dims)) reg_branch.append(nn.LayerNorm(self.embed_dims))
reg_branch.append(nn.ReLU()) reg_branch.append(nn.ReLU())
reg_branch.append(Linear(self.embed_dims, self.num_points*2)) reg_branch.append(Linear(self.embed_dims, self.num_points * 2))
reg_branch = nn.Sequential(*reg_branch) reg_branch = nn.Sequential(*reg_branch)
# add sigmoid or not # add sigmoid or not
def _get_clones(module, N): def _get_clones(module, N):
...@@ -185,7 +183,6 @@ class DETRMapFixedNumHead(BaseMapHead): ...@@ -185,7 +183,6 @@ class DETRMapFixedNumHead(BaseMapHead):
outputs = [] outputs = []
for i, query_feat in enumerate(outs_dec): for i, query_feat in enumerate(outs_dec):
ocls = self.pre_branches['cls'](query_feat) ocls = self.pre_branches['cls'](query_feat)
oreg = self.pre_branches['reg'](query_feat) oreg = self.pre_branches['reg'](query_feat)
oreg = oreg.unflatten(dim=2, sizes=(self.num_points, 2)) oreg = oreg.unflatten(dim=2, sizes=(self.num_points, 2))
...@@ -235,7 +232,7 @@ class DETRMapFixedNumHead(BaseMapHead): ...@@ -235,7 +232,7 @@ class DETRMapFixedNumHead(BaseMapHead):
num_pred_lines = lines_pred.size(0) num_pred_lines = lines_pred.size(0)
# assigner and sampler # assigner and sampler
assign_result = self.assigner.assign(preds=dict(lines=lines_pred, scores=score_pred,), assign_result = self.assigner.assign(preds=dict(lines=lines_pred, scores=score_pred, ),
gts=dict(lines=gt_lines, gts=dict(lines=gt_lines,
labels=gt_labels, ), labels=gt_labels, ),
gt_bboxes_ignore=gt_bboxes_ignore) gt_bboxes_ignore=gt_bboxes_ignore)
...@@ -245,7 +242,7 @@ class DETRMapFixedNumHead(BaseMapHead): ...@@ -245,7 +242,7 @@ class DETRMapFixedNumHead(BaseMapHead):
neg_inds = sampling_result.neg_inds neg_inds = sampling_result.neg_inds
# label targets # label targets
labels = gt_lines.new_full((num_pred_lines, ), labels = gt_lines.new_full((num_pred_lines,),
self.num_classes, self.num_classes,
dtype=torch.long) dtype=torch.long)
labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds] labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
...@@ -297,10 +294,10 @@ class DETRMapFixedNumHead(BaseMapHead): ...@@ -297,10 +294,10 @@ class DETRMapFixedNumHead(BaseMapHead):
(labels_list, label_weights_list, (labels_list, label_weights_list,
lines_targets_list, lines_weights_list, lines_targets_list, lines_weights_list,
pos_inds_list, neg_inds_list) = multi_apply( pos_inds_list, neg_inds_list) = multi_apply(
self._get_target_single, self._get_target_single,
preds['scores'], preds['lines'], preds['scores'], preds['lines'],
gts['lines'], gts['labels'], gts['lines'], gts['labels'],
gt_bboxes_ignore=gt_bboxes_ignore_list) gt_bboxes_ignore=gt_bboxes_ignore_list)
num_total_pos = sum((inds.numel() for inds in pos_inds_list)) num_total_pos = sum((inds.numel() for inds in pos_inds_list))
num_total_neg = sum((inds.numel() for inds in neg_inds_list)) num_total_neg = sum((inds.numel() for inds in neg_inds_list))
...@@ -319,7 +316,7 @@ class DETRMapFixedNumHead(BaseMapHead): ...@@ -319,7 +316,7 @@ class DETRMapFixedNumHead(BaseMapHead):
gts: dict, gts: dict,
gt_bboxes_ignore_list=None, gt_bboxes_ignore_list=None,
reduction='none'): reduction='none'):
""" """
Loss function for outputs from a single decoder layer of a single Loss function for outputs from a single decoder layer of a single
feature level. feature level.
Args: Args:
...@@ -327,7 +324,7 @@ class DETRMapFixedNumHead(BaseMapHead): ...@@ -327,7 +324,7 @@ class DETRMapFixedNumHead(BaseMapHead):
for all images. Shape [bs, num_query, cls_out_channels]. for all images. Shape [bs, num_query, cls_out_channels].
lines_preds (Tensor): lines_preds (Tensor):
shape [bs, num_query, num_points, 2]. shape [bs, num_query, num_points, 2].
gt_lines_list (list[Tensor]): gt_lines_list (list[Tensor]):
with shape (num_gts, num_points, 2) with shape (num_gts, num_points, 2)
gt_labels_list (list[Tensor]): Ground truth class indices for each gt_labels_list (list[Tensor]): Ground truth class indices for each
image with shape (num_gts, ). image with shape (num_gts, ).
...@@ -339,7 +336,7 @@ class DETRMapFixedNumHead(BaseMapHead): ...@@ -339,7 +336,7 @@ class DETRMapFixedNumHead(BaseMapHead):
""" """
# get target for each sample # get target for each sample
new_gts, num_total_pos, num_total_neg, pos_inds_list =\ new_gts, num_total_pos, num_total_neg, pos_inds_list = \
self.get_targets(preds, gts, gt_bboxes_ignore_list) self.get_targets(preds, gts, gt_bboxes_ignore_list)
# batched all data # batched all data
...@@ -348,7 +345,7 @@ class DETRMapFixedNumHead(BaseMapHead): ...@@ -348,7 +345,7 @@ class DETRMapFixedNumHead(BaseMapHead):
# construct weighted avg_factor to match with the official DETR repo # construct weighted avg_factor to match with the official DETR repo
cls_avg_factor = num_total_pos * 1.0 + \ cls_avg_factor = num_total_pos * 1.0 + \
num_total_neg * self.bg_cls_weight num_total_neg * self.bg_cls_weight
if self.sync_cls_avg_factor: if self.sync_cls_avg_factor:
cls_avg_factor = reduce_mean( cls_avg_factor = reduce_mean(
preds['scores'].new_tensor([cls_avg_factor])) preds['scores'].new_tensor([cls_avg_factor]))
...@@ -368,7 +365,8 @@ class DETRMapFixedNumHead(BaseMapHead): ...@@ -368,7 +365,8 @@ class DETRMapFixedNumHead(BaseMapHead):
lines_preds = preds['lines'].reshape(-1, self.num_points, 2) lines_preds = preds['lines'].reshape(-1, self.num_points, 2)
if reduction == 'none': # For performance analysis if reduction == 'none': # For performance analysis
loss_reg = self.reg_loss( loss_reg = self.reg_loss(
lines_preds, new_gts['lines_targets'], new_gts['lines_weights'], reduction_override=reduction, avg_factor=num_total_pos) lines_preds, new_gts['lines_targets'], new_gts['lines_weights'], reduction_override=reduction,
avg_factor=num_total_pos)
else: else:
loss_reg = self.reg_loss( loss_reg = self.reg_loss(
lines_preds, new_gts['lines_targets'], new_gts['lines_weights'], avg_factor=num_total_pos) lines_preds, new_gts['lines_targets'], new_gts['lines_weights'], avg_factor=num_total_pos)
......
import copy import numpy as np
import torch import torch
import torch.nn as nn from mmdet.models import HEADS, build_head
from mmcv.cnn import Linear, bias_init_with_prob, build_activation_layer
from mmcv.cnn.bricks.transformer import build_positional_encoding
from mmcv.runner import force_fp32
from mmdet.models import HEADS, build_head, build_loss
from mmdet.models.utils import build_transformer from mmdet.models.utils import build_transformer
from mmdet.models.utils.transformer import inverse_sigmoid from mmdet.models.utils.transformer import inverse_sigmoid
from .base_map_head import BaseMapHead
import numpy as np
from ..augmentation.sythesis_det import NoiseSythesis from ..augmentation.sythesis_det import NoiseSythesis
from .base_map_head import BaseMapHead
@HEADS.register_module(force=True) @HEADS.register_module(force=True)
class DGHead(BaseMapHead): class DGHead(BaseMapHead):
...@@ -46,16 +41,16 @@ class DGHead(BaseMapHead): ...@@ -46,16 +41,16 @@ class DGHead(BaseMapHead):
self.augmentation = None self.augmentation = None
if augmentation: if augmentation:
augmentation_kwargs.update({'canvas_size':gen_net_cfg.canvas_size}) augmentation_kwargs.update({'canvas_size': gen_net_cfg.canvas_size})
self.augmentation = NoiseSythesis(**augmentation_kwargs) self.augmentation = NoiseSythesis(**augmentation_kwargs)
self.joint_training = joint_training self.joint_training = joint_training
def forward(self, batch, img_metas=None, **kwargs): def forward(self, batch, img_metas=None, **kwargs):
''' '''
Args: Args:
Returns: Returns:
outs (Dict): outs (Dict):
''' '''
if self.training: if self.training:
...@@ -68,8 +63,8 @@ class DGHead(BaseMapHead): ...@@ -68,8 +63,8 @@ class DGHead(BaseMapHead):
bbox_dict = self.det_net(context=context) bbox_dict = self.det_net(context=context)
outs = dict( outs = dict(
bbox=bbox_dict, bbox=bbox_dict,
) )
losses_dict, det_match_idxs, det_match_gt_idxs = \ losses_dict, det_match_idxs, det_match_gt_idxs = \
self.loss_det(batch, outs) self.loss_det(batch, outs)
...@@ -77,12 +72,12 @@ class DGHead(BaseMapHead): ...@@ -77,12 +72,12 @@ class DGHead(BaseMapHead):
if only_det: return outs, losses_dict if only_det: return outs, losses_dict
if self.augmentation is not None: if self.augmentation is not None:
polylines, bbox_flat =\ polylines, bbox_flat = \
self.augmentation(batch['gen'],simple_aug=True) self.augmentation(batch['gen'], simple_aug=True)
if bbox_flat is None: if bbox_flat is None:
bbox_flat = batch['gen']['bbox_flat'] bbox_flat = batch['gen']['bbox_flat']
gen_input = dict( gen_input = dict(
lines_bs_idx=batch['gen']['lines_bs_idx'], lines_bs_idx=batch['gen']['lines_bs_idx'],
lines_cls=batch['gen']['lines_cls'], lines_cls=batch['gen']['lines_cls'],
...@@ -104,32 +99,32 @@ class DGHead(BaseMapHead): ...@@ -104,32 +99,32 @@ class DGHead(BaseMapHead):
pred_bbox = bbox_dict[-1]['bboxs'].detach() pred_bbox = bbox_dict[-1]['bboxs'].detach()
else: else:
raise NotImplementedError raise NotImplementedError
# changed to original gt order. # changed to original gt order.
det_match_idx = det_match_idxs[-1] det_match_idx = det_match_idxs[-1]
det_match_gt_idx = det_match_gt_idxs[-1] det_match_gt_idx = det_match_gt_idxs[-1]
_bboxs = [] _bboxs = []
for i, (match_idx, bbox) in enumerate(zip(det_match_idx,pred_bbox)): for i, (match_idx, bbox) in enumerate(zip(det_match_idx, pred_bbox)):
_bboxs.append(bbox[match_idx]) _bboxs.append(bbox[match_idx])
_bboxs[-1] = _bboxs[-1][torch.argsort(det_match_gt_idx[i])] _bboxs[-1] = _bboxs[-1][torch.argsort(det_match_gt_idx[i])]
_bboxs = torch.cat(_bboxs, dim=0) _bboxs = torch.cat(_bboxs, dim=0)
# quantize the data # quantize the data
_bboxs = \ _bboxs = \
torch.round(_bboxs).type(torch.int32) torch.round(_bboxs).type(torch.int32)
# gen_input['bbox_flat'] = _bboxs # gen_input['bbox_flat'] = _bboxs
remain_idx = torch.randperm(_bboxs.shape[0])[:int(_bboxs.shape[0]*0.2)] remain_idx = torch.randperm(_bboxs.shape[0])[:int(_bboxs.shape[0] * 0.2)]
# for data efficient # for data efficient
for k in gen_input.keys(): for k in gen_input.keys():
if k == 'bbox_flat': if k == 'bbox_flat':
gen_input[k] = torch.cat((_bboxs,gen_input[k][remain_idx]),dim=0) gen_input[k] = torch.cat((_bboxs, gen_input[k][remain_idx]), dim=0)
else: else:
gen_input[k] = torch.cat((gen_input[k],gen_input[k][remain_idx]),dim=0) gen_input[k] = torch.cat((gen_input[k], gen_input[k][remain_idx]), dim=0)
if isinstance(context['bev_embeddings'],tuple): if isinstance(context['bev_embeddings'], tuple):
context['bev_embeddings'] = context['bev_embeddings'][0] context['bev_embeddings'] = context['bev_embeddings'][0]
poly_dict = self.gen_net(gen_input, context=context) poly_dict = self.gen_net(gen_input, context=context)
...@@ -141,17 +136,17 @@ class DGHead(BaseMapHead): ...@@ -141,17 +136,17 @@ class DGHead(BaseMapHead):
if self.joint_training: if self.joint_training:
for k in batch['gen'].keys(): for k in batch['gen'].keys():
batch['gen'][k] = \ batch['gen'][k] = \
torch.cat((batch['gen'][k],batch['gen'][k][remain_idx]),dim=0) torch.cat((batch['gen'][k], batch['gen'][k][remain_idx]), dim=0)
gen_losses_dict = \ gen_losses_dict = \
self.loss_gen(batch, outs) self.loss_gen(batch, outs)
losses_dict.update(gen_losses_dict) losses_dict.update(gen_losses_dict)
return outs, losses_dict return outs, losses_dict
def loss_det(self, gt: dict, pred: dict): def loss_det(self, gt: dict, pred: dict):
loss_dict = {} loss_dict = {}
# det # det
...@@ -159,8 +154,8 @@ class DGHead(BaseMapHead): ...@@ -159,8 +154,8 @@ class DGHead(BaseMapHead):
self.det_net.loss(gt['det'], pred['bbox']) self.det_net.loss(gt['det'], pred['bbox'])
for k, v in det_loss_dict.items(): for k, v in det_loss_dict.items():
loss_dict['det_'+k] = v loss_dict['det_' + k] = v
return loss_dict, det_match_idx, det_match_gt_idx return loss_dict, det_match_idx, det_match_gt_idx
def loss_gen(self, gt: dict, pred: dict): def loss_gen(self, gt: dict, pred: dict):
...@@ -171,34 +166,34 @@ class DGHead(BaseMapHead): ...@@ -171,34 +166,34 @@ class DGHead(BaseMapHead):
gen_loss_dict = self.gen_net.loss(gt['gen'], pred['polylines']) gen_loss_dict = self.gen_net.loss(gt['gen'], pred['polylines'])
for k, v in gen_loss_dict.items(): for k, v in gen_loss_dict.items():
loss_dict['gen_'+k] = v loss_dict['gen_' + k] = v
return loss_dict return loss_dict
def loss(self, gt: dict, pred: dict): def loss(self, gt: dict, pred: dict):
pass pass
@torch.no_grad() @torch.no_grad()
def inference(self, batch: dict={}, context: dict={}, gt_condition=False, **kwargs): def inference(self, batch: dict = {}, context: dict = {}, gt_condition=False, **kwargs):
''' '''
num_samples_batch: number of sample per batch (batch size) num_samples_batch: number of sample per batch (batch size)
''' '''
outs = {} outs = {}
bbox_dict = self.det_net(context=context) bbox_dict = self.det_net(context=context)
bbox_dict = self.det_net.post_process(bbox_dict) bbox_dict = self.det_net.post_process(bbox_dict)
outs.update(bbox_dict) outs.update(bbox_dict)
if len(outs['lines_bs_idx']) == 0: if len(outs['lines_bs_idx']) == 0:
return None return None
if isinstance(context['bev_embeddings'],tuple): if isinstance(context['bev_embeddings'], tuple):
context['bev_embeddings'] = context['bev_embeddings'][0] context['bev_embeddings'] = context['bev_embeddings'][0]
poly_dict = self.gen_net(outs, poly_dict = self.gen_net(outs,
context=context, context=context,
# max_sample_length=self.max_num_vertices, # max_sample_length=self.max_num_vertices,
max_sample_length=64, max_sample_length=64,
top_p=self.top_p_gen_model, top_p=self.top_p_gen_model,
gt_condition=gt_condition) gt_condition=gt_condition)
...@@ -206,7 +201,7 @@ class DGHead(BaseMapHead): ...@@ -206,7 +201,7 @@ class DGHead(BaseMapHead):
return outs return outs
def post_process(self, preds: dict, tokens, gts:dict=None, **kwargs): def post_process(self, preds: dict, tokens, gts: dict = None, **kwargs):
''' '''
Args: Args:
XXX XXX
...@@ -215,8 +210,8 @@ class DGHead(BaseMapHead): ...@@ -215,8 +210,8 @@ class DGHead(BaseMapHead):
''' '''
range_size = self.gen_net.canvas_size.cpu().numpy() range_size = self.gen_net.canvas_size.cpu().numpy()
coord_dim = self.gen_net.coord_dim coord_dim = self.gen_net.coord_dim
gen_net_name = self.gen_net.name if hasattr(self.gen_net,'name') else 'gen' gen_net_name = self.gen_net.name if hasattr(self.gen_net, 'name') else 'gen'
ret_list = [] ret_list = []
for batch_idx in range(len(tokens)): for batch_idx in range(len(tokens)):
...@@ -227,8 +222,8 @@ class DGHead(BaseMapHead): ...@@ -227,8 +222,8 @@ class DGHead(BaseMapHead):
det_gt = None det_gt = None
if gts is not None: if gts is not None:
det_gt, rec_groundtruth = pack_groundtruth( det_gt, rec_groundtruth = pack_groundtruth(
batch_idx,gts,tokens,range_size,gen_net_name,coord_dim=coord_dim) batch_idx, gts, tokens, range_size, gen_net_name, coord_dim=coord_dim)
bbox_res = { bbox_res = {
# 'bboxes': preds['bbox'][batch_idx].detach().cpu().numpy(), # 'bboxes': preds['bbox'][batch_idx].detach().cpu().numpy(),
# 'det_gt': det_gt, # 'det_gt': det_gt,
...@@ -238,7 +233,6 @@ class DGHead(BaseMapHead): ...@@ -238,7 +233,6 @@ class DGHead(BaseMapHead):
} }
ret_dict_single.update(bbox_res) ret_dict_single.update(bbox_res)
# for gen results. # for gen results.
batch2seq = np.nonzero( batch2seq = np.nonzero(
preds['lines_bs_idx'].cpu().numpy() == batch_idx)[0] preds['lines_bs_idx'].cpu().numpy() == batch_idx)[0]
...@@ -249,16 +243,15 @@ class DGHead(BaseMapHead): ...@@ -249,16 +243,15 @@ class DGHead(BaseMapHead):
}) })
for i in batch2seq: for i in batch2seq:
pre = preds['polylines'][i].detach().cpu().numpy() pre = preds['polylines'][i].detach().cpu().numpy()
pre_msk = preds['polyline_masks'][i].detach().cpu().numpy() pre_msk = preds['polyline_masks'][i].detach().cpu().numpy()
valid_idx = np.nonzero(pre_msk)[0][:-1] valid_idx = np.nonzero(pre_msk)[0][:-1]
# From [200,1] to [199,0] to (1,0) # From [200,1] to [199,0] to (1,0)
line = (pre[valid_idx].reshape(-1, coord_dim) - 1) / (range_size-1) line = (pre[valid_idx].reshape(-1, coord_dim) - 1) / (range_size - 1)
ret_dict_single['vectors'].append(line) ret_dict_single['vectors'].append(line)
# if gts is not None: # if gts is not None:
# ret_dict_single['groundTruth'] = rec_groundtruth # ret_dict_single['groundTruth'] = rec_groundtruth
...@@ -266,8 +259,8 @@ class DGHead(BaseMapHead): ...@@ -266,8 +259,8 @@ class DGHead(BaseMapHead):
return ret_list return ret_list
def pack_groundtruth(batch_idx,gts,tokens,range_size,gen_net_name='gen',coord_dim=2):
def pack_groundtruth(batch_idx, gts, tokens, range_size, gen_net_name='gen', coord_dim=2):
if 'keypoints' in gts['det']: if 'keypoints' in gts['det']:
gt_bbox = \ gt_bbox = \
gts['det']['keypoints'][batch_idx].detach().cpu().numpy() gts['det']['keypoints'][batch_idx].detach().cpu().numpy()
...@@ -281,7 +274,7 @@ def pack_groundtruth(batch_idx,gts,tokens,range_size,gen_net_name='gen',coord_di ...@@ -281,7 +274,7 @@ def pack_groundtruth(batch_idx,gts,tokens,range_size,gen_net_name='gen',coord_di
batch2seq = np.nonzero( batch2seq = np.nonzero(
gts['gen']['lines_bs_idx'].cpu().numpy() == batch_idx)[0] gts['gen']['lines_bs_idx'].cpu().numpy() == batch_idx)[0]
ret_groundtruth = { ret_groundtruth = {
'token': tokens[batch_idx], 'token': tokens[batch_idx],
'nline': len(batch2seq), 'nline': len(batch2seq),
...@@ -290,16 +283,16 @@ def pack_groundtruth(batch_idx,gts,tokens,range_size,gen_net_name='gen',coord_di ...@@ -290,16 +283,16 @@ def pack_groundtruth(batch_idx,gts,tokens,range_size,gen_net_name='gen',coord_di
} }
for i in batch2seq: for i in batch2seq:
gt_line =\ gt_line = \
gts['gen']['polylines'].detach().cpu().numpy()[i] gts['gen']['polylines'].detach().cpu().numpy()[i]
gt_msk = gts['gen']['polyline_masks'].detach().cpu().numpy()[i] gt_msk = gts['gen']['polyline_masks'].detach().cpu().numpy()[i]
if gen_net_name == 'gen_gmm': if gen_net_name == 'gen_gmm':
valid_idx = np.nonzero(gt_msk)[0] valid_idx = np.nonzero(gt_msk)[0]
else: else:
valid_idx = np.nonzero(gt_msk)[0][:-1] valid_idx = np.nonzero(gt_msk)[0][:-1]
# From [200,1] to [199,0] to (1,0) # From [200,1] to [199,0] to (1,0)
line = (gt_line[valid_idx].reshape(-1, coord_dim) - 1) / (range_size-1) line = (gt_line[valid_idx].reshape(-1, coord_dim) - 1) / (range_size - 1)
ret_groundtruth['lines'].append(line) ret_groundtruth['lines'].append(line)
return det_gt, ret_groundtruth return det_gt, ret_groundtruth
import copy import copy
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from mmcv.cnn import Conv2d, Linear from mmcv.cnn import (Conv2d, Linear, bias_init_with_prob,
from mmcv.runner import force_fp32 build_activation_layer)
from torch.distributions.categorical import Categorical
from mmdet.core import (multi_apply, build_assigner, build_sampler,
reduce_mean)
from mmdet.models import HEADS
from .detr_bbox import DETRBboxHead
from mmdet.models.utils.transformer import inverse_sigmoid
from mmdet.models import build_loss
from mmcv.cnn import Linear, build_activation_layer, bias_init_with_prob
from mmcv.cnn.bricks.transformer import build_positional_encoding from mmcv.cnn.bricks.transformer import build_positional_encoding
from mmcv.runner import force_fp32
from mmdet.core import build_assigner, build_sampler, multi_apply, reduce_mean
from mmdet.models import HEADS, build_loss
from mmdet.models.utils import build_transformer from mmdet.models.utils import build_transformer
from mmdet.models.utils.transformer import inverse_sigmoid
@HEADS.register_module(force=True) @HEADS.register_module(force=True)
class MapElementDetector(nn.Module): class MapElementDetector(nn.Module):
def __init__(self, def __init__(self,
canvas_size=(400, 200), canvas_size=(400, 200),
discrete_output=False, discrete_output=False,
separate_detect=False, separate_detect=False,
mode='xyxy', mode='xyxy',
bbox_size=None, bbox_size=None,
coord_dim=2, coord_dim=2,
kp_coord_dim=2, kp_coord_dim=2,
num_classes=3, num_classes=3,
in_channels=128, in_channels=128,
...@@ -41,8 +38,8 @@ class MapElementDetector(nn.Module): ...@@ -41,8 +38,8 @@ class MapElementDetector(nn.Module):
positional_encoding: dict = None, positional_encoding: dict = None,
loss_cls: dict = None, loss_cls: dict = None,
loss_reg: dict = None, loss_reg: dict = None,
train_cfg: dict = None,): train_cfg: dict = None, ):
super().__init__() super().__init__()
assigner = train_cfg['assigner'] assigner = train_cfg['assigner']
...@@ -65,7 +62,7 @@ class MapElementDetector(nn.Module): ...@@ -65,7 +62,7 @@ class MapElementDetector(nn.Module):
if loss_cls['use_sigmoid']: if loss_cls['use_sigmoid']:
self.cls_out_channels = num_classes self.cls_out_channels = num_classes
else: else:
self.cls_out_channels = num_classes+1 self.cls_out_channels = num_classes + 1
self.iterative = iterative self.iterative = iterative
self.num_reg_fcs = num_reg_fcs self.num_reg_fcs = num_reg_fcs
...@@ -82,7 +79,7 @@ class MapElementDetector(nn.Module): ...@@ -82,7 +79,7 @@ class MapElementDetector(nn.Module):
self.separate_detect = separate_detect self.separate_detect = separate_detect
self.discrete_output = discrete_output self.discrete_output = discrete_output
self.bbox_size = 3 if mode=='sce' else 2 self.bbox_size = 3 if mode == 'sce' else 2
if bbox_size is not None: if bbox_size is not None:
self.bbox_size = bbox_size self.bbox_size = bbox_size
self.coord_dim = coord_dim # for xyz self.coord_dim = coord_dim # for xyz
...@@ -115,16 +112,16 @@ class MapElementDetector(nn.Module): ...@@ -115,16 +112,16 @@ class MapElementDetector(nn.Module):
# query_pos_embed & query_embed # query_pos_embed & query_embed
self.query_embedding = nn.Embedding(self.num_query, self.query_embedding = nn.Embedding(self.num_query,
self.embed_dims*2) self.embed_dims * 2)
# for bbox parameter xstart, ystart, xend, yend # for bbox parameter xstart, ystart, xend, yend
self.bbox_embedding = nn.Embedding( self.bbox_size, self.bbox_embedding = nn.Embedding(self.bbox_size,
self.embed_dims*2) self.embed_dims * 2)
def _init_branch(self,): def _init_branch(self, ):
"""Initialize classification branch and regression branch of head.""" """Initialize classification branch and regression branch of head."""
fc_cls = Linear(self.embed_dims*self.bbox_size, self.cls_out_channels) fc_cls = Linear(self.embed_dims * self.bbox_size, self.cls_out_channels)
# fc_cls = Linear(self.embed_dims, self.cls_out_channels) # fc_cls = Linear(self.embed_dims, self.cls_out_channels)
reg_branch = [] reg_branch = []
...@@ -135,12 +132,13 @@ class MapElementDetector(nn.Module): ...@@ -135,12 +132,13 @@ class MapElementDetector(nn.Module):
if self.discrete_output: if self.discrete_output:
reg_branch.append(nn.Linear( reg_branch.append(nn.Linear(
self.embed_dims, max(self.canvas_size), bias=True,)) self.embed_dims, max(self.canvas_size), bias=True, ))
else: else:
reg_branch.append(nn.Linear( reg_branch.append(nn.Linear(
self.embed_dims, self.coord_dim, bias=True,)) self.embed_dims, self.coord_dim, bias=True, ))
reg_branch = nn.Sequential(*reg_branch) reg_branch = nn.Sequential(*reg_branch)
# add sigmoid or not # add sigmoid or not
def _get_clones(module, N): def _get_clones(module, N):
...@@ -240,29 +238,29 @@ class MapElementDetector(nn.Module): ...@@ -240,29 +238,29 @@ class MapElementDetector(nn.Module):
[nb_dec, bs, num_query, num_points, 2]. [nb_dec, bs, num_query, num_points, 2].
''' '''
(global_context_embedding, sequential_context_embeddings) =\ (global_context_embedding, sequential_context_embeddings) = \
self._prepare_context(context) self._prepare_context(context)
x = sequential_context_embeddings x = sequential_context_embeddings
B, C, H, W = x.shape B, C, H, W = x.shape
query_embedding = self.query_embedding.weight[None,:,None].repeat(B, 1, self.bbox_size, 1) query_embedding = self.query_embedding.weight[None, :, None].repeat(B, 1, self.bbox_size, 1)
bbox_embed = self.bbox_embedding.weight bbox_embed = self.bbox_embedding.weight
query_embedding = query_embedding + bbox_embed[None,None] query_embedding = query_embedding + bbox_embed[None, None]
query_embedding = query_embedding.view(B, -1, C*2) query_embedding = query_embedding.view(B, -1, C * 2)
img_masks = x.new_zeros((B, H, W)) img_masks = x.new_zeros((B, H, W))
pos_embed = self.positional_encoding(img_masks) pos_embed = self.positional_encoding(img_masks)
# outs_dec: [nb_dec, bs, num_query, embed_dim] # outs_dec: [nb_dec, bs, num_query, embed_dim]
hs, init_reference, inter_references = self.transformer( hs, init_reference, inter_references = self.transformer(
[x,], [x, ],
[img_masks.type(torch.bool)], [img_masks.type(torch.bool)],
query_embedding, query_embedding,
[pos_embed], [pos_embed],
reg_branches= self.reg_branches if self.iterative else None, # noqa:E501 reg_branches=self.reg_branches if self.iterative else None, # noqa:E501
cls_branches= None, # noqa:E501 cls_branches=None, # noqa:E501
) )
outs_dec = hs.permute(0, 2, 1, 3) outs_dec = hs.permute(0, 2, 1, 3)
outputs = [] outputs = []
...@@ -271,23 +269,23 @@ class MapElementDetector(nn.Module): ...@@ -271,23 +269,23 @@ class MapElementDetector(nn.Module):
reference = init_reference reference = init_reference
else: else:
reference = inter_references[i - 1] reference = inter_references[i - 1]
outputs.append(self.get_prediction(i,query_feat,reference)) outputs.append(self.get_prediction(i, query_feat, reference))
return outputs return outputs
def get_prediction(self, level, query_feat, reference): def get_prediction(self, level, query_feat, reference):
bs, num_query, h = query_feat.shape bs, num_query, h = query_feat.shape
query_feat = query_feat.view(bs, -1, self.bbox_size,h) query_feat = query_feat.view(bs, -1, self.bbox_size, h)
ocls = self.pre_branches['cls'][level](query_feat.flatten(-2)) ocls = self.pre_branches['cls'][level](query_feat.flatten(-2))
# ocls = ocls.mean(-2) # ocls = ocls.mean(-2)
reference = inverse_sigmoid(reference) reference = inverse_sigmoid(reference)
reference = reference.view(bs, -1, self.bbox_size,self.coord_dim) reference = reference.view(bs, -1, self.bbox_size, self.coord_dim)
tmp = self.pre_branches['reg'][level](query_feat) tmp = self.pre_branches['reg'][level](query_feat)
tmp[...,:self.kp_coord_dim] = tmp[...,:self.kp_coord_dim] + reference[...,:self.kp_coord_dim] tmp[..., :self.kp_coord_dim] = tmp[..., :self.kp_coord_dim] + reference[..., :self.kp_coord_dim]
lines = tmp.sigmoid() # bs, num_query, self.bbox_size,2 lines = tmp.sigmoid() # bs, num_query, self.bbox_size,2
lines = lines * self.canvas_size[:self.coord_dim] lines = lines * self.canvas_size[:self.coord_dim]
lines = lines.flatten(-2) lines = lines.flatten(-2)
...@@ -295,7 +293,7 @@ class MapElementDetector(nn.Module): ...@@ -295,7 +293,7 @@ class MapElementDetector(nn.Module):
return dict( return dict(
lines=lines, # [bs, num_query, bboxsize*2] lines=lines, # [bs, num_query, bboxsize*2]
scores=ocls, # [bs, num_query, num_class] scores=ocls, # [bs, num_query, num_class]
embeddings= query_feat, # [bs, num_query, bbox_size, h] embeddings=query_feat, # [bs, num_query, bbox_size, h]
) )
@force_fp32(apply_to=('score_pred', 'lines_pred', 'gt_lines')) @force_fp32(apply_to=('score_pred', 'lines_pred', 'gt_lines'))
...@@ -333,7 +331,7 @@ class MapElementDetector(nn.Module): ...@@ -333,7 +331,7 @@ class MapElementDetector(nn.Module):
num_pred_lines = len(lines_pred) num_pred_lines = len(lines_pred)
# assigner and sampler # assigner and sampler
assign_result = self.assigner.assign(preds=dict(lines=lines_pred, scores=score_pred,), assign_result = self.assigner.assign(preds=dict(lines=lines_pred, scores=score_pred, ),
gts=dict(lines=gt_lines, gts=dict(lines=gt_lines,
labels=gt_labels, ), labels=gt_labels, ),
gt_bboxes_ignore=gt_bboxes_ignore) gt_bboxes_ignore=gt_bboxes_ignore)
...@@ -345,10 +343,10 @@ class MapElementDetector(nn.Module): ...@@ -345,10 +343,10 @@ class MapElementDetector(nn.Module):
# label targets 0: foreground, 1: background # label targets 0: foreground, 1: background
if self.separate_detect: if self.separate_detect:
labels = gt_lines.new_full((num_pred_lines, ), 1, dtype=torch.long) labels = gt_lines.new_full((num_pred_lines,), 1, dtype=torch.long)
else: else:
labels = gt_lines.new_full( labels = gt_lines.new_full(
(num_pred_lines, ), self.num_classes, dtype=torch.long) (num_pred_lines,), self.num_classes, dtype=torch.long)
labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds] labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
label_weights = gt_lines.new_ones(num_pred_lines) label_weights = gt_lines.new_ones(num_pred_lines)
...@@ -421,11 +419,11 @@ class MapElementDetector(nn.Module): ...@@ -421,11 +419,11 @@ class MapElementDetector(nn.Module):
(labels_list, label_weights_list, (labels_list, label_weights_list,
lines_targets_list, lines_weights_list, lines_targets_list, lines_weights_list,
pos_inds_list, neg_inds_list,pos_gt_inds_list) = multi_apply( pos_inds_list, neg_inds_list, pos_gt_inds_list) = multi_apply(
self._get_target_single, self._get_target_single,
preds['scores'], lines_pred, preds['scores'], lines_pred,
class_label, bbox, class_label, bbox,
gt_bboxes_ignore=gt_bboxes_ignore_list) gt_bboxes_ignore=gt_bboxes_ignore_list)
num_total_pos = sum((inds.numel() for inds in pos_inds_list)) num_total_pos = sum((inds.numel() for inds in pos_inds_list))
num_total_neg = sum((inds.numel() for inds in neg_inds_list)) num_total_neg = sum((inds.numel() for inds in neg_inds_list))
...@@ -464,7 +462,7 @@ class MapElementDetector(nn.Module): ...@@ -464,7 +462,7 @@ class MapElementDetector(nn.Module):
""" """
# Get target for each sample # Get target for each sample
new_gts, num_total_pos, num_total_neg, pos_inds_list, pos_gt_inds_list =\ new_gts, num_total_pos, num_total_neg, pos_inds_list, pos_gt_inds_list = \
self.get_targets(preds, gts, gt_bboxes_ignore_list) self.get_targets(preds, gts, gt_bboxes_ignore_list)
# Batched all data # Batched all data
...@@ -473,7 +471,7 @@ class MapElementDetector(nn.Module): ...@@ -473,7 +471,7 @@ class MapElementDetector(nn.Module):
# construct weighted avg_factor to match with the official DETR repo # construct weighted avg_factor to match with the official DETR repo
cls_avg_factor = num_total_pos * 1.0 + \ cls_avg_factor = num_total_pos * 1.0 + \
num_total_neg * self.bg_cls_weight num_total_neg * self.bg_cls_weight
if self.sync_cls_avg_factor: if self.sync_cls_avg_factor:
cls_avg_factor = reduce_mean( cls_avg_factor = reduce_mean(
preds['scores'].new_tensor([cls_avg_factor])) preds['scores'].new_tensor([cls_avg_factor]))
...@@ -499,7 +497,7 @@ class MapElementDetector(nn.Module): ...@@ -499,7 +497,7 @@ class MapElementDetector(nn.Module):
# position NLL loss # position NLL loss
if self.discrete_output: if self.discrete_output:
loss_reg = -(preds['lines'].log_prob(new_gts['bboxs']) * loss_reg = -(preds['lines'].log_prob(new_gts['bboxs']) *
new_gts['bboxs_weights']).sum()/(num_total_pos) new_gts['bboxs_weights']).sum() / (num_total_pos)
else: else:
loss_reg = self.reg_loss( loss_reg = self.reg_loss(
preds['lines'], new_gts['bboxs'], new_gts['bboxs_weights'], avg_factor=num_total_pos) preds['lines'], new_gts['bboxs'], new_gts['bboxs_weights'], avg_factor=num_total_pos)
...@@ -613,7 +611,7 @@ class MapElementDetector(nn.Module): ...@@ -613,7 +611,7 @@ class MapElementDetector(nn.Module):
result_dict['bbox'].append(det_preds) result_dict['bbox'].append(det_preds)
result_dict['scores'].append(scores) result_dict['scores'].append(scores)
result_dict['labels'].append(det_labels) result_dict['labels'].append(det_labels)
result_dict['lines_bs_idx'].extend([i]*nline) result_dict['lines_bs_idx'].extend([i] * nline)
# for down stream polyline # for down stream polyline
_bboxs = torch.cat(result_dict['bbox'], dim=0) _bboxs = torch.cat(result_dict['bbox'], dim=0)
...@@ -625,4 +623,4 @@ class MapElementDetector(nn.Module): ...@@ -625,4 +623,4 @@ class MapElementDetector(nn.Module):
result_dict['lines_bs_idx'] = torch.tensor( result_dict['lines_bs_idx'] = torch.tensor(
result_dict['lines_bs_idx'], device=device).long() result_dict['lines_bs_idx'], device=device).long()
return result_dict return result_dict
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment