Commit 61ef00a6 authored by rusty1s

fix knn and radius for unequal batch sizes

parent 10049daf
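
Note (not part of the commit): "unequal batch sizes" refers to batch_x and batch_y implying a different number of examples, e.g. the query set y having no points in the last example. A minimal sketch of such a call, assuming the package's public Python wrapper torch_cluster.knn; the names and shapes are illustrative only:

import torch
from torch_cluster import knn  # Python wrapper around torch.ops.torch_cluster.knn

# Three examples on the x side, but only two on the y side, so
# int(batch_x.max()) + 1 != int(batch_y.max()) + 1.
x = torch.rand(6, 3)
batch_x = torch.tensor([0, 0, 1, 1, 2, 2])
y = torch.rand(3, 3)
batch_y = torch.tensor([0, 0, 1])

# Previously, ptr_x and ptr_y were each sized from their own batch vector and
# could disagree on the number of examples; with this fix a shared
# batch_size = max(...) is used for both.
assign_index = knn(x, y, 2, batch_x, batch_y)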
@@ -31,6 +31,8 @@ def get_extensions():
     for main, suffix in product(main_files, suffices):
         define_macros = []
         extra_compile_args = {'cxx': ['-O2']}
+        if not os.name == 'nt':  # Not on Windows:
+            extra_compile_args['cxx'] += ['-Wno-sign-compare']
         extra_link_args = ['-s']
         info = parallel_info()
@@ -49,6 +49,8 @@ def get_extensions():
             nvcc_flags = os.getenv('NVCC_FLAGS', '')
             nvcc_flags = [] if nvcc_flags == '' else nvcc_flags.split(' ')
             nvcc_flags += ['--expt-relaxed-constexpr', '-O2']
+            if not os.name == 'nt':  # Not on Windows:
+                nvcc_flags += ['-Wno-sign-compare']
             extra_compile_args['nvcc'] = nvcc_flags
         name = main.split(os.sep)[-1][:-4]
@@ -50,27 +50,22 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,
     y = y.view(-1, 1) if y.dim() == 1 else y
     x, y = x.contiguous(), y.contiguous()
-    ptr_x: Optional[torch.Tensor] = None
+    batch_size = 1
     if batch_x is not None:
         assert x.size(0) == batch_x.numel()
         batch_size = int(batch_x.max()) + 1
-        deg = x.new_zeros(batch_size, dtype=torch.long)
-        deg.scatter_add_(0, batch_x, torch.ones_like(batch_x))
-        ptr_x = deg.new_zeros(batch_size + 1)
-        torch.cumsum(deg, 0, out=ptr_x[1:])
-    ptr_y: Optional[torch.Tensor] = None
     if batch_y is not None:
         assert y.size(0) == batch_y.numel()
-        batch_size = int(batch_y.max()) + 1
-        deg = y.new_zeros(batch_size, dtype=torch.long)
-        deg.scatter_add_(0, batch_y, torch.ones_like(batch_y))
+        batch_size = max(batch_size, int(batch_y.max()) + 1)
-        ptr_y = deg.new_zeros(batch_size + 1)
-        torch.cumsum(deg, 0, out=ptr_y[1:])
+    ptr_x: Optional[torch.Tensor] = None
+    ptr_y: Optional[torch.Tensor] = None
+    if batch_size > 1:
+        assert batch_x is not None
+        assert batch_y is not None
+        arange = torch.arange(batch_size + 1, device=x.device)
+        ptr_x = torch.bucketize(arange, batch_x)
+        ptr_y = torch.bucketize(arange, batch_y)
     return torch.ops.torch_cluster.knn(x, y, ptr_x, ptr_y, k, cosine,
                                        num_workers)
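
For reference, a small equivalence sketch (not part of the commit): for a sorted batch vector, torch.bucketize over arange(batch_size + 1) produces the same CSR-style pointer vector that the removed scatter_add_/cumsum code built.

import torch

batch = torch.tensor([0, 0, 0, 1, 1])  # sorted batch assignment, 2 examples
batch_size = int(batch.max()) + 1

# Old construction: per-example counts followed by a cumulative sum.
deg = batch.new_zeros(batch_size)
deg.scatter_add_(0, batch, torch.ones_like(batch))
ptr_old = deg.new_zeros(batch_size + 1)
torch.cumsum(deg, 0, out=ptr_old[1:])

# New construction: left insertion points of 0..batch_size into the sorted batch vector.
ptr_new = torch.bucketize(torch.arange(batch_size + 1), batch)

print(ptr_old.tolist(), ptr_new.tolist())  # [0, 3, 5] [0, 3, 5]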
@@ -50,27 +50,22 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
     y = y.view(-1, 1) if y.dim() == 1 else y
     x, y = x.contiguous(), y.contiguous()
-    ptr_x: Optional[torch.Tensor] = None
+    batch_size = 1
     if batch_x is not None:
         assert x.size(0) == batch_x.numel()
         batch_size = int(batch_x.max()) + 1
-        deg = x.new_zeros(batch_size, dtype=torch.long)
-        deg.scatter_add_(0, batch_x, torch.ones_like(batch_x))
-        ptr_x = deg.new_zeros(batch_size + 1)
-        torch.cumsum(deg, 0, out=ptr_x[1:])
-    ptr_y: Optional[torch.Tensor] = None
     if batch_y is not None:
         assert y.size(0) == batch_y.numel()
-        batch_size = int(batch_y.max()) + 1
-        deg = y.new_zeros(batch_size, dtype=torch.long)
-        deg.scatter_add_(0, batch_y, torch.ones_like(batch_y))
+        batch_size = max(batch_size, int(batch_y.max()) + 1)
-        ptr_y = deg.new_zeros(batch_size + 1)
-        torch.cumsum(deg, 0, out=ptr_y[1:])
+    ptr_x: Optional[torch.Tensor] = None
+    ptr_y: Optional[torch.Tensor] = None
+    if batch_size > 1:
+        assert batch_x is not None
+        assert batch_y is not None
+        arange = torch.arange(batch_size + 1, device=x.device)
+        ptr_x = torch.bucketize(arange, batch_x)
+        ptr_y = torch.bucketize(arange, batch_y)
     return torch.ops.torch_cluster.radius(x, y, ptr_x, ptr_y, r,
                                           max_num_neighbors, num_workers)
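
The same construction, applied to the unequal case this commit fixes (illustrative values, not from the commit): an example index missing from one batch vector simply becomes an empty slice in its pointer vector, and both pointers share the same length.

import torch

batch_x = torch.tensor([0, 0, 1, 1, 2, 2])
batch_y = torch.tensor([0, 0, 1])            # example 2 has no y points
batch_size = max(int(batch_x.max()) + 1, int(batch_y.max()) + 1)  # = 3

arange = torch.arange(batch_size + 1)
ptr_x = torch.bucketize(arange, batch_x)     # tensor([0, 2, 4, 6])
ptr_y = torch.bucketize(arange, batch_y)     # tensor([0, 2, 3, 3]) -> empty last slice

# Both pointers now have batch_size + 1 entries, so the kernel sees a
# consistent number of examples even though the batch vectors differ.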