Commit a72001c5 authored by rusty1s's avatar rusty1s
Browse files

fix broadcasting

parent 19a5fc17
...@@ -15,7 +15,9 @@ def sample(src: SparseTensor, num_neighbors: int, ...@@ -15,7 +15,9 @@ def sample(src: SparseTensor, num_neighbors: int,
rowptr = rowptr[subset] rowptr = rowptr[subset]
rand = torch.rand((rowcount.size(0), num_neighbors), device=col.device) rand = torch.rand((rowcount.size(0), num_neighbors), device=col.device)
rand = rand.mul_(rowcount.to(rand.dtype)).to(torch.long).add_(rowptr) rand.mul_(rowcount.to(rand.dtype).view(-1, 1))
rand = rand.to(torch.long)
rand.add_(rowptr.view(-1, 1))
return col[rand] return col[rand]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment