Commit 2f84fce1 authored by rusty1s's avatar rusty1s
Browse files

torch1.2

parent 13b7fbaa
...@@ -79,7 +79,7 @@ def knn(x, y, k, batch_x=None, batch_y=None, cosine=False): ...@@ -79,7 +79,7 @@ def knn(x, y, k, batch_x=None, batch_y=None, cosine=False):
dist = torch.from_numpy(dist).to(x.dtype) dist = torch.from_numpy(dist).to(x.dtype)
col = torch.from_numpy(col).to(torch.long) col = torch.from_numpy(col).to(torch.long)
row = torch.arange(col.size(0), dtype=torch.long).view(-1, 1).repeat(1, k) row = torch.arange(col.size(0), dtype=torch.long).view(-1, 1).repeat(1, k)
mask = 1 - torch.isinf(dist).view(-1) mask = ~torch.isinf(dist).view(-1)
row, col = row.view(-1)[mask], col.view(-1)[mask] row, col = row.view(-1)[mask], col.view(-1)[mask]
return torch.stack([row, col], dim=0) return torch.stack([row, col], dim=0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment