Commit 8447a475 authored by Shaoshuai Shi's avatar Shaoshuai Shi
Browse files

support PointResidualCoder

parent a2e7d474
import torch
import numpy as np
class ResidualCoder(object):
......@@ -123,3 +124,84 @@ class PreviousResidualRoIDecoder(object):
cgs = [t + a for t, a in zip(cts, cas)]
return torch.cat([xg, yg, zg, dxg, dyg, dzg, rg, *cgs], dim=-1)
class PointResidualCoder(object):
def __init__(self, code_size=8, use_mean_size=True, **kwargs):
super().__init__()
self.code_size = code_size
self.use_mean_size = use_mean_size
if self.use_mean_size:
self.mean_size = torch.from_numpy(np.array(kwargs['mean_size'])).cuda().float()
assert self.mean_size.min() > 0
def encode_torch(self, gt_boxes, points, gt_classes=None):
"""
Args:
gt_boxes: (N, 7 + C) [x, y, z, dx, dy, dz, heading, ...]
points: (N, 3) [x, y, z]
gt_classes: (N) [1, num_classes]
Returns:
box_coding: (N, 8 + C)
"""
gt_boxes[:, 3:6] = torch.clamp_min(gt_boxes[:, 3:6], min=1e-5)
xg, yg, zg, dxg, dyg, dzg, rg, *cgs = torch.split(gt_boxes, 1, dim=-1)
xa, ya, za = torch.split(points, 1, dim=-1)
if self.use_mean_size:
assert gt_classes.max() <= self.mean_size.shape[0]
point_anchor_size = self.mean_size[gt_classes - 1]
dxa, dya, dza = torch.split(point_anchor_size, 1, dim=-1)
diagonal = torch.sqrt(dxa ** 2 + dya ** 2)
xt = (xg - xa) / diagonal
yt = (yg - ya) / diagonal
zt = (zg - za) / dza
dxt = torch.log(dxg / dxa)
dyt = torch.log(dyg / dya)
dzt = torch.log(dzg / dza)
else:
xt = (xg - xa)
yt = (yg - ya)
zt = (zg - za)
dxt = torch.log(dxg)
dyt = torch.log(dyg)
dzt = torch.log(dzg)
cts = [g for g in cgs]
return torch.cat([xt, yt, zt, dxt, dyt, dzt, torch.cos(rg), torch.sin(rg), *cts], dim=-1)
def decode_torch(self, box_encodings, points, pred_classes=None):
"""
Args:
box_encodings: (N, 8 + C) [x, y, z, dx, dy, dz, cos, sin, ...]
points: [x, y, z]
pred_classes: (N) [1, num_classes]
Returns:
"""
xt, yt, zt, dxt, dyt, dzt, cost, sint, *cts = torch.split(box_encodings, 1, dim=-1)
xa, ya, za = torch.split(points, 1, dim=-1)
if self.use_mean_size:
assert pred_classes.max() <= self.mean_size.shape[0]
point_anchor_size = self.mean_size[pred_classes - 1]
dxa, dya, dza = torch.split(point_anchor_size, 1, dim=-1)
diagonal = torch.sqrt(dxa ** 2 + dya ** 2)
xg = xt * diagonal + xa
yg = yt * diagonal + ya
zg = zt * dza + za
dxg = torch.exp(dxt) * dxa
dyg = torch.exp(dyt) * dya
dzg = torch.exp(dzt) * dza
else:
xg = xt + xa
yg = yt + ya
zg = zt + za
dxg, dyg, dzg = torch.split(torch.exp(box_encodings[..., 3:6]), 1, dim=-1)
rg = torch.atan2(sint, cost)
cgs = [t for t in cts]
return torch.cat([xg, yg, zg, dxg, dyg, dzg, rg, *cgs], dim=-1)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment