Commit 34f083a9 authored by rusty1s's avatar rusty1s
Browse files

potential windows/numpy conversion fix

parent 7aee6467
......@@ -67,10 +67,12 @@ def knn(x, y, k, batch_x=None, batch_y=None):
x = torch.cat([x, 2 * x.size(1) * batch_x.view(-1, 1).to(x.dtype)], dim=-1)
y = torch.cat([y, 2 * y.size(1) * batch_y.view(-1, 1).to(y.dtype)], dim=-1)
tree = scipy.spatial.cKDTree(x)
dist, col = tree.query(y, k=k, distance_upper_bound=x.size(1))
dist, col = torch.tensor(dist), torch.tensor(col)
row = torch.arange(col.size(0)).view(-1, 1).repeat(1, k)
tree = scipy.spatial.cKDTree(x.detach().numpy())
dist, col = tree.query(
y.detach().cpu(), k=k, distance_upper_bound=x.size(1))
dist = torch.from_numpy(dist).to(x.dtype)
col = torch.from_numpy(col).to(torch.long)
row = torch.arange(col.size(0), dtype=torch.long).view(-1, 1).repeat(1, k)
mask = 1 - torch.isinf(dist).view(-1)
row, col = row.view(-1)[mask], col.view(-1)[mask]
......
......@@ -64,4 +64,6 @@ def nearest(x, y, batch_x=None, batch_y=None):
x = torch.cat([x, 2 * x.size(1) * batch_x.view(-1, 1).to(x.dtype)], dim=-1)
y = torch.cat([y, 2 * y.size(1) * batch_y.view(-1, 1).to(y.dtype)], dim=-1)
return torch.from_numpy(scipy.cluster.vq.vq(x, y)[0]).to(torch.long)
return torch.from_numpy(
scipy.cluster.vq.vq(x.detach().cpu(),
y.detach().cpu())[0]).to(torch.long)
......@@ -64,12 +64,13 @@ def radius(x, y, r, batch_x=None, batch_y=None, max_num_neighbors=32):
x = torch.cat([x, 2 * r * batch_x.view(-1, 1).to(x.dtype)], dim=-1)
y = torch.cat([y, 2 * r * batch_y.view(-1, 1).to(y.dtype)], dim=-1)
tree = scipy.spatial.cKDTree(x)
_, col = tree.query(y, k=max_num_neighbors, distance_upper_bound=r + 1e-8)
col = [torch.tensor(c) for c in col]
tree = scipy.spatial.cKDTree(x.detach().numpy())
_, col = tree.query(
y.detach().numpy(), k=max_num_neighbors, distance_upper_bound=r + 1e-8)
col = [torch.from_numpy(c).to(torch.long) for c in col]
row = [torch.full_like(c, i) for i, c in enumerate(col)]
row, col = torch.cat(row, dim=0), torch.cat(col, dim=0)
mask = col < tree.n
mask = col < int(tree.n)
return torch.stack([row[mask], col[mask]], dim=0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment