"tools/git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "90944c5aa617ae97db143b73e3f763c4c18d8df9"
Commit 7cef0ed8 authored by Benjamin Thomas Graham's avatar Benjamin Thomas Graham
Browse files

update utils.voxelize_pointcloud

parent 6919d707
......@@ -28,8 +28,8 @@ class Model(nn.Module):
self.sparseModel = scn.Sequential().add(
scn.InputLayer(data.dimension,data.full_scale, mode=4)).add(
scn.SubmanifoldConvolution(data.dimension, 3, m, 3, False)).add(
scn.UNet(data.dimension, block_reps, [m, 2*m, 3*m, 4*m, 5*m, 6*m, 7*m], residual_blocks)).add(
scn.BatchNormReLU(m)).add(
scn.UNet(data.dimension, block_reps, [m, 2*m, 3*m, 4*m, 5*m, 6*m, 7*m], residual_blocks)
).add(
scn.OutputLayer(data.dimension))
self.linear = nn.Linear(m, 20)
def forward(self,x):
......
......@@ -190,23 +190,28 @@ def squareroot_rotation(a):
scipy.spatial.transform.Rotation.from_dcm(torch.stack([torch.eye(3),a])))([0.5]).as_dcm()
return torch.from_numpy(b).float()[0]
def voxelize_pointcloud(xyz,rgb,average=True,accumulate=False):
def voxelize_pointcloud(xyz,rgb,average=True,accumulate=False,return_inverse=False,return_counts=False):
if xyz.numel()==0:
return xyz, rgb
if average or accumulate:
xyz,inv,counts=np.unique(xyz.numpy(),axis=0,return_inverse=True,return_counts=True)
xyz=torch.from_numpy(xyz)
inv=torch.from_numpy(inv)
rgb_out=torch.zeros(xyz.size(0),rgb.size(1),dtype=torch.float32)
xyz,inv,counts=torch.unique(xyz,dim=0,return_inverse=True,return_counts=True)
rgb_out=torch.zeros(xyz.size(0),rgb.size(1),dtype=torch.float32,device=xyz.device)
rgb_out.index_add_(0,inv,rgb)
if average:
rgb=rgb_out/torch.from_numpy(counts[:,None]).float()
return xyz, rgb
rgb_out=rgb_out/counts[:,None].float()
rgb=rgb_out
else:
xyz,idxs=np.unique(xyz,axis=0,return_index=True)
xyz=torch.from_numpy(xyz)
rgb=rgb[idxs]
return xyz, rgb
#https://github.com/pytorch/pytorch/issues/36748
xyz, inv = torch.unique(xyz, sorted=True, return_inverse=True, dim=0)
perm = torch.arange(inv.size(0), dtype=inv.dtype, device=inv.device)
inv, perm = inv.flip([0]), perm.flip([0])
rgb=rgb[inv.new_empty(xyz.size(0)).scatter_(0, inv, perm)]
r=[xyz,rgb_out]
if return_inverse:
r+=[inv]
if return_counts:
r+=[counts]
return r
class checkpointFunction(torch.autograd.Function):
@staticmethod
......@@ -232,6 +237,7 @@ class checkpointFunction(torch.autograd.Function):
return None, x_features.grad, None, None
def checkpoint101(run_function, x, down=1):
x.features.requires_grad = True
f=checkpointFunction.apply(run_function, x.features, x.metadata, x.spatial_size)
s=x.spatial_size//down
return SparseConvNetTensor(f, x.metadata, s)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment