Commit 6b634203 authored by limm's avatar limm
Browse files

support v1.6.3

parent c2dcc5fd
from typing import Optional
from typing import Optional, Tuple, Union
import torch
from torch import Tensor
@torch.jit.script
def random_walk(row: Tensor, col: Tensor, start: Tensor, walk_length: int,
p: float = 1, q: float = 1, coalesced: bool = True,
num_nodes: Optional[int] = None) -> Tensor:
def random_walk(
row: Tensor,
col: Tensor,
start: Tensor,
walk_length: int,
p: float = 1,
q: float = 1,
coalesced: bool = True,
num_nodes: Optional[int] = None,
return_edge_indices: bool = False,
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
"""Samples random walks of length :obj:`walk_length` from all node indices
in :obj:`start` in the graph given by :obj:`(row, col)` as described in the
`"node2vec: Scalable Feature Learning for Networks"
......@@ -28,6 +35,9 @@ def random_walk(row: Tensor, col: Tensor, start: Tensor, walk_length: int,
the graph given by :obj:`(row, col)` according to :obj:`row`.
(default: :obj:`True`)
num_nodes (int, optional): The number of nodes. (default: :obj:`None`)
return_edge_indices (bool, optional): Whether to additionally return
the indices of edges traversed during the random walk.
(default: :obj:`False`)
:rtype: :class:`LongTensor`
"""
......@@ -43,5 +53,10 @@ def random_walk(row: Tensor, col: Tensor, start: Tensor, walk_length: int,
rowptr = row.new_zeros(num_nodes + 1)
torch.cumsum(deg, 0, out=rowptr[1:])
return torch.ops.torch_cluster.random_walk(rowptr, col, start, walk_length,
p, q)[0]
node_seq, edge_seq = torch.ops.torch_cluster.random_walk(
rowptr, col, start, walk_length, p, q)
if return_edge_indices:
return node_seq, edge_seq
return node_seq
import torch
@torch.jit.script
def neighbor_sampler(start: torch.Tensor, rowptr: torch.Tensor, size: float):
assert not start.is_cuda
......
from typing import Any
import torch
dtypes = [torch.half, torch.float, torch.double, torch.int, torch.long]
dtypes = [
torch.half, torch.bfloat16, torch.float, torch.double, torch.int,
torch.long
]
grad_dtypes = [torch.half, torch.float, torch.double]
devices = [torch.device('cpu')]
if torch.cuda.is_available():
devices += [torch.device(f'cuda:{torch.cuda.current_device()}')]
devices += [torch.device('cuda:0')]
def tensor(x, dtype, device):
def tensor(x: Any, dtype: torch.dtype, device: torch.device):
return None if x is None else torch.tensor(x, dtype=dtype, device=device)
import torch
try:
WITH_PTR_LIST = hasattr(torch.ops.torch_cluster, 'fps_ptr_list')
except Exception:
WITH_PTR_LIST = False
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment