Unverified Commit bcdb5a2c authored by Yezhen Cong's avatar Yezhen Cong Committed by GitHub
Browse files

[Fix] Fix a bug in KNN op (#371)

* support knn gpu op

* made it more robust and fixed comments

* bugfix and add warning in docstring
parent 28828eef
......@@ -25,7 +25,8 @@ class KNN(Function):
center_xyz (Tensor): (B, npoint, 3) if transposed == False,
else (B, 3, npoint). centers of the knn query.
transposed (bool): whether the input tensors are transposed.
defaults to False.
defaults to False. Should not expicitly use this keyword
when calling knn (=KNN.apply), just add the fourth param.
Returns:
Tensor: (B, k, npoint) tensor with the indicies of
......@@ -33,13 +34,13 @@ class KNN(Function):
"""
assert k > 0
B, npoint = center_xyz.shape[:2]
N = xyz.shape[1]
if not transposed:
xyz = xyz.transpose(2, 1).contiguous()
center_xyz = center_xyz.transpose(2, 1).contiguous()
B, _, npoint = center_xyz.shape
N = xyz.shape[2]
assert center_xyz.is_contiguous()
assert xyz.is_contiguous()
......
......@@ -110,6 +110,11 @@ def test_knn():
expected_idx = dist.topk(k=5, dim=2, largest=False)[1].transpose(2, 1)
assert torch.all(idx == expected_idx)
idx = knn(5,
xyz.transpose(1, 2).contiguous(),
new_xyz.transpose(1, 2).contiguous(), True)
assert torch.all(idx == expected_idx)
idx = knn(5, xyz, xyz)
xyz_ = xyz.unsqueeze(2).repeat(1, 1, xyz.shape[1], 1)
xyz__ = xyz.unsqueeze(1).repeat(1, xyz.shape[1], 1, 1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment