Commit 878e1193 authored by rusty1s's avatar rusty1s
Browse files

new api

parent cba1cdb0
...@@ -56,7 +56,7 @@ from torch_cluster import graclus_cluster ...@@ -56,7 +56,7 @@ from torch_cluster import graclus_cluster
row = torch.tensor([0, 1, 1, 2]) row = torch.tensor([0, 1, 1, 2])
col = torch.tensor([1, 0, 2, 1]) col = torch.tensor([1, 0, 2, 1])
weight = torch.tensor([1, 1, 1, 1]) # Optional edge weights. weight = torch.Tensor([1, 1, 1, 1]) # Optional edge weights.
cluster = graclus_cluster(row, col, weight) cluster = graclus_cluster(row, col, weight)
``` ```
...@@ -74,8 +74,8 @@ A clustering algorithm, which overlays a regular grid of user-defined size over ...@@ -74,8 +74,8 @@ A clustering algorithm, which overlays a regular grid of user-defined size over
import torch import torch
from torch_cluster import grid_cluster from torch_cluster import grid_cluster
pos = torch.tensor([[0, 0], [11, 9], [2, 8], [2, 2], [8, 3]]) pos = torch.Tensor([[0, 0], [11, 9], [2, 8], [2, 2], [8, 3]])
size = torch.tensor([5, 5]) size = torch.Tensor([5, 5])
cluster = grid_cluster(pos, size) cluster = grid_cluster(pos, size)
``` ```
......
from .utils.ffi import grid import torch
import grid_cpu
if torch.cuda.is_available():
import grid_cuda
def grid_cluster(pos, size, start=None, end=None): def grid_cluster(pos, size, start=None, end=None):
...@@ -8,9 +12,9 @@ def grid_cluster(pos, size, start=None, end=None): ...@@ -8,9 +12,9 @@ def grid_cluster(pos, size, start=None, end=None):
Args: Args:
pos (Tensor): D-dimensional position of points. pos (Tensor): D-dimensional position of points.
size (Tensor): Size of a voxel in each dimension. size (Tensor): Size of a voxel in each dimension.
start (Tensor or int, optional): Start position of the grid (in each start (Tensor, optional): Start position of the grid (in each
dimension). (default: :obj:`None`) dimension). (default: :obj:`None`)
end (Tensor or int, optional): End position of the grid (in each end (Tensor, optional): End position of the grid (in each
dimension). (default: :obj:`None`) dimension). (default: :obj:`None`)
Examples:: Examples::
...@@ -21,18 +25,12 @@ def grid_cluster(pos, size, start=None, end=None): ...@@ -21,18 +25,12 @@ def grid_cluster(pos, size, start=None, end=None):
""" """
pos = pos.unsqueeze(-1) if pos.dim() == 1 else pos pos = pos.unsqueeze(-1) if pos.dim() == 1 else pos
assert pos.size(1) == size.size(0), (
'Last dimension of position tensor must have same size as size tensor')
start = pos.t().min(dim=1)[0] if start is None else start start = pos.t().min(dim=1)[0] if start is None else start
end = pos.t().max(dim=1)[0] if end is None else end end = pos.t().max(dim=1)[0] if end is None else end
pos, end = pos - start, end - start
size = size.type_as(pos)
count = (end / size).long() + 1
cluster = count.new(pos.size(0)) if pos.is_cuda:
grid(cluster, pos, size, count) cluster = grid_cuda.grid(pos, size, start, end)
else:
cluster = grid_cpu.grid(pos, size, start, end)
return cluster return cluster
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment