Commit 61ef00a6 authored by rusty1s's avatar rusty1s
Browse files

fix knn and radius for unequal batch sizes

parent 10049daf
...@@ -31,6 +31,8 @@ def get_extensions(): ...@@ -31,6 +31,8 @@ def get_extensions():
for main, suffix in product(main_files, suffices): for main, suffix in product(main_files, suffices):
define_macros = [] define_macros = []
extra_compile_args = {'cxx': ['-O2']} extra_compile_args = {'cxx': ['-O2']}
if not os.name == 'nt': # Not on Windows:
extra_compile_args['cxx'] += ['-Wno-sign-compare']
extra_link_args = ['-s'] extra_link_args = ['-s']
info = parallel_info() info = parallel_info()
...@@ -49,6 +51,8 @@ def get_extensions(): ...@@ -49,6 +51,8 @@ def get_extensions():
nvcc_flags = os.getenv('NVCC_FLAGS', '') nvcc_flags = os.getenv('NVCC_FLAGS', '')
nvcc_flags = [] if nvcc_flags == '' else nvcc_flags.split(' ') nvcc_flags = [] if nvcc_flags == '' else nvcc_flags.split(' ')
nvcc_flags += ['--expt-relaxed-constexpr', '-O2'] nvcc_flags += ['--expt-relaxed-constexpr', '-O2']
if not os.name == 'nt': # Not on Windows:
nvcc_flags += ['-Wno-sign-compare']
extra_compile_args['nvcc'] = nvcc_flags extra_compile_args['nvcc'] = nvcc_flags
name = main.split(os.sep)[-1][:-4] name = main.split(os.sep)[-1][:-4]
......
...@@ -50,27 +50,22 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int, ...@@ -50,27 +50,22 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,
y = y.view(-1, 1) if y.dim() == 1 else y y = y.view(-1, 1) if y.dim() == 1 else y
x, y = x.contiguous(), y.contiguous() x, y = x.contiguous(), y.contiguous()
ptr_x: Optional[torch.Tensor] = None batch_size = 1
if batch_x is not None: if batch_x is not None:
assert x.size(0) == batch_x.numel() assert x.size(0) == batch_x.numel()
batch_size = int(batch_x.max()) + 1 batch_size = int(batch_x.max()) + 1
deg = x.new_zeros(batch_size, dtype=torch.long)
deg.scatter_add_(0, batch_x, torch.ones_like(batch_x))
ptr_x = deg.new_zeros(batch_size + 1)
torch.cumsum(deg, 0, out=ptr_x[1:])
ptr_y: Optional[torch.Tensor] = None
if batch_y is not None: if batch_y is not None:
assert y.size(0) == batch_y.numel() assert y.size(0) == batch_y.numel()
batch_size = int(batch_y.max()) + 1 batch_size = max(batch_size, int(batch_y.max()) + 1)
deg = y.new_zeros(batch_size, dtype=torch.long)
deg.scatter_add_(0, batch_y, torch.ones_like(batch_y))
ptr_y = deg.new_zeros(batch_size + 1) ptr_x: Optional[torch.Tensor] = None
torch.cumsum(deg, 0, out=ptr_y[1:]) ptr_y: Optional[torch.Tensor] = None
if batch_size > 1:
assert batch_x is not None
assert batch_y is not None
arange = torch.arange(batch_size + 1, device=x.device)
ptr_x = torch.bucketize(arange, batch_x)
ptr_y = torch.bucketize(arange, batch_y)
return torch.ops.torch_cluster.knn(x, y, ptr_x, ptr_y, k, cosine, return torch.ops.torch_cluster.knn(x, y, ptr_x, ptr_y, k, cosine,
num_workers) num_workers)
......
...@@ -50,27 +50,22 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float, ...@@ -50,27 +50,22 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
y = y.view(-1, 1) if y.dim() == 1 else y y = y.view(-1, 1) if y.dim() == 1 else y
x, y = x.contiguous(), y.contiguous() x, y = x.contiguous(), y.contiguous()
ptr_x: Optional[torch.Tensor] = None batch_size = 1
if batch_x is not None: if batch_x is not None:
assert x.size(0) == batch_x.numel() assert x.size(0) == batch_x.numel()
batch_size = int(batch_x.max()) + 1 batch_size = int(batch_x.max()) + 1
deg = x.new_zeros(batch_size, dtype=torch.long)
deg.scatter_add_(0, batch_x, torch.ones_like(batch_x))
ptr_x = deg.new_zeros(batch_size + 1)
torch.cumsum(deg, 0, out=ptr_x[1:])
ptr_y: Optional[torch.Tensor] = None
if batch_y is not None: if batch_y is not None:
assert y.size(0) == batch_y.numel() assert y.size(0) == batch_y.numel()
batch_size = int(batch_y.max()) + 1 batch_size = max(batch_size, int(batch_y.max()) + 1)
deg = y.new_zeros(batch_size, dtype=torch.long)
deg.scatter_add_(0, batch_y, torch.ones_like(batch_y))
ptr_y = deg.new_zeros(batch_size + 1) ptr_x: Optional[torch.Tensor] = None
torch.cumsum(deg, 0, out=ptr_y[1:]) ptr_y: Optional[torch.Tensor] = None
if batch_size > 1:
assert batch_x is not None
assert batch_y is not None
arange = torch.arange(batch_size + 1, device=x.device)
ptr_x = torch.bucketize(arange, batch_x)
ptr_y = torch.bucketize(arange, batch_y)
return torch.ops.torch_cluster.radius(x, y, ptr_x, ptr_y, r, return torch.ops.torch_cluster.radius(x, y, ptr_x, ptr_y, r,
max_num_neighbors, num_workers) max_num_neighbors, num_workers)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment