Commit 6b634203 authored by limm's avatar limm
Browse files

support v1.6.3

parent c2dcc5fd
from typing import Optional from typing import Optional, Tuple, Union
import torch import torch
from torch import Tensor from torch import Tensor
@torch.jit.script def random_walk(
def random_walk(row: Tensor, col: Tensor, start: Tensor, walk_length: int, row: Tensor,
p: float = 1, q: float = 1, coalesced: bool = True, col: Tensor,
num_nodes: Optional[int] = None) -> Tensor: start: Tensor,
walk_length: int,
p: float = 1,
q: float = 1,
coalesced: bool = True,
num_nodes: Optional[int] = None,
return_edge_indices: bool = False,
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
"""Samples random walks of length :obj:`walk_length` from all node indices """Samples random walks of length :obj:`walk_length` from all node indices
in :obj:`start` in the graph given by :obj:`(row, col)` as described in the in :obj:`start` in the graph given by :obj:`(row, col)` as described in the
`"node2vec: Scalable Feature Learning for Networks" `"node2vec: Scalable Feature Learning for Networks"
...@@ -28,6 +35,9 @@ def random_walk(row: Tensor, col: Tensor, start: Tensor, walk_length: int, ...@@ -28,6 +35,9 @@ def random_walk(row: Tensor, col: Tensor, start: Tensor, walk_length: int,
the graph given by :obj:`(row, col)` according to :obj:`row`. the graph given by :obj:`(row, col)` according to :obj:`row`.
(default: :obj:`True`) (default: :obj:`True`)
num_nodes (int, optional): The number of nodes. (default: :obj:`None`) num_nodes (int, optional): The number of nodes. (default: :obj:`None`)
return_edge_indices (bool, optional): Whether to additionally return
the indices of edges traversed during the random walk.
(default: :obj:`False`)
:rtype: :class:`LongTensor` :rtype: :class:`LongTensor`
""" """
...@@ -43,5 +53,10 @@ def random_walk(row: Tensor, col: Tensor, start: Tensor, walk_length: int, ...@@ -43,5 +53,10 @@ def random_walk(row: Tensor, col: Tensor, start: Tensor, walk_length: int,
rowptr = row.new_zeros(num_nodes + 1) rowptr = row.new_zeros(num_nodes + 1)
torch.cumsum(deg, 0, out=rowptr[1:]) torch.cumsum(deg, 0, out=rowptr[1:])
return torch.ops.torch_cluster.random_walk(rowptr, col, start, walk_length, node_seq, edge_seq = torch.ops.torch_cluster.random_walk(
p, q)[0] rowptr, col, start, walk_length, p, q)
if return_edge_indices:
return node_seq, edge_seq
return node_seq
import torch import torch
@torch.jit.script
def neighbor_sampler(start: torch.Tensor, rowptr: torch.Tensor, size: float): def neighbor_sampler(start: torch.Tensor, rowptr: torch.Tensor, size: float):
assert not start.is_cuda assert not start.is_cuda
......
from typing import Any
import torch import torch
dtypes = [torch.half, torch.float, torch.double, torch.int, torch.long] dtypes = [
torch.half, torch.bfloat16, torch.float, torch.double, torch.int,
torch.long
]
grad_dtypes = [torch.half, torch.float, torch.double] grad_dtypes = [torch.half, torch.float, torch.double]
devices = [torch.device('cpu')] devices = [torch.device('cpu')]
if torch.cuda.is_available(): if torch.cuda.is_available():
devices += [torch.device(f'cuda:{torch.cuda.current_device()}')] devices += [torch.device('cuda:0')]
def tensor(x, dtype, device): def tensor(x: Any, dtype: torch.dtype, device: torch.device):
return None if x is None else torch.tensor(x, dtype=dtype, device=device) return None if x is None else torch.tensor(x, dtype=dtype, device=device)
import torch
try:
WITH_PTR_LIST = hasattr(torch.ops.torch_cluster, 'fps_ptr_list')
except Exception:
WITH_PTR_LIST = False
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment