Unverified Commit 80dea956 authored by g-morishita's avatar g-morishita Committed by GitHub
Browse files

improve infererence the number of nodes (#112)

parent 3521c0d9
......@@ -32,7 +32,7 @@ def random_walk(row: Tensor, col: Tensor, start: Tensor, walk_length: int,
:rtype: :class:`LongTensor`
"""
if num_nodes is None:
num_nodes = max(int(row.max()), int(col.max())) + 1
num_nodes = max(int(row.max()), int(col.max()), int(start.max())) + 1
if coalesced:
perm = torch.argsort(row * num_nodes + col)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment