Commit ae73ea73 authored by rusty1s's avatar rusty1s
Browse files

bugfix batch == None

parent 199be867
...@@ -26,6 +26,9 @@ def fps(x, batch=None, ratio=0.5, random_start=True): ...@@ -26,6 +26,9 @@ def fps(x, batch=None, ratio=0.5, random_start=True):
>>> sample = fps(pos, batch) >>> sample = fps(pos, batch)
""" """
if batch is None:
batch = x.new_zeros(x.size(0), dtype=torch.long)
assert x.is_cuda assert x.is_cuda
assert x.dim() <= 2 and batch.dim() == 1 assert x.dim() <= 2 and batch.dim() == 1
assert x.size(0) == batch.size(0) assert x.size(0) == batch.size(0)
...@@ -33,9 +36,6 @@ def fps(x, batch=None, ratio=0.5, random_start=True): ...@@ -33,9 +36,6 @@ def fps(x, batch=None, ratio=0.5, random_start=True):
x = x.view(-1, 1) if x.dim() == 1 else x x = x.view(-1, 1) if x.dim() == 1 else x
if batch is None:
batch = x.new_zeros(x.size(0), dtype=torch.long)
op = fps_cuda.fps if x.is_cuda else None op = fps_cuda.fps if x.is_cuda else None
out = op(x, batch, ratio, random_start) out = op(x, batch, ratio, random_start)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment