Commit 79f67548 authored by rusty1s's avatar rusty1s
Browse files

mask

parent 7391824f
...@@ -69,9 +69,8 @@ def radius(x, y, r, batch_x=None, batch_y=None, max_num_neighbors=32): ...@@ -69,9 +69,8 @@ def radius(x, y, r, batch_x=None, batch_y=None, max_num_neighbors=32):
col = [torch.tensor(c) for c in col] col = [torch.tensor(c) for c in col]
row = [torch.full_like(c, i) for i, c in enumerate(col)] row = [torch.full_like(c, i) for i, c in enumerate(col)]
row, col = torch.cat(row, dim=0), torch.cat(col, dim=0) row, col = torch.cat(row, dim=0), torch.cat(col, dim=0)
row = row[col<tree.n] mask = col < tree.n
col = col[col<tree.n] return torch.stack([row[mask], col[mask]], dim=0)
return torch.stack([row, col], dim=0)
def radius_graph(x, r, batch=None, loop=False, max_num_neighbors=32): def radius_graph(x, r, batch=None, loop=False, max_num_neighbors=32):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment