Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
6b634203
Commit
6b634203
authored
May 27, 2025
by
limm
Browse files
support v1.6.3
parent
c2dcc5fd
Changes
64
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
45 additions
and
8 deletions
+45
-8
torch_cluster/rw.py
torch_cluster/rw.py
+22
-7
torch_cluster/sampler.py
torch_cluster/sampler.py
+0
-1
torch_cluster/testing.py
torch_cluster/testing.py
+17
-0
torch_cluster/typing.py
torch_cluster/typing.py
+6
-0
No files found.
torch_cluster/rw.py
View file @
6b634203
from
typing
import
Optional
from
typing
import
Optional
,
Tuple
,
Union
import
torch
from
torch
import
Tensor
@
torch
.
jit
.
script
def
random_walk
(
row
:
Tensor
,
col
:
Tensor
,
start
:
Tensor
,
walk_length
:
int
,
p
:
float
=
1
,
q
:
float
=
1
,
coalesced
:
bool
=
True
,
num_nodes
:
Optional
[
int
]
=
None
)
->
Tensor
:
def
random_walk
(
row
:
Tensor
,
col
:
Tensor
,
start
:
Tensor
,
walk_length
:
int
,
p
:
float
=
1
,
q
:
float
=
1
,
coalesced
:
bool
=
True
,
num_nodes
:
Optional
[
int
]
=
None
,
return_edge_indices
:
bool
=
False
,
)
->
Union
[
Tensor
,
Tuple
[
Tensor
,
Tensor
]]:
"""Samples random walks of length :obj:`walk_length` from all node indices
in :obj:`start` in the graph given by :obj:`(row, col)` as described in the
`"node2vec: Scalable Feature Learning for Networks"
...
...
@@ -28,6 +35,9 @@ def random_walk(row: Tensor, col: Tensor, start: Tensor, walk_length: int,
the graph given by :obj:`(row, col)` according to :obj:`row`.
(default: :obj:`True`)
num_nodes (int, optional): The number of nodes. (default: :obj:`None`)
return_edge_indices (bool, optional): Whether to additionally return
the indices of edges traversed during the random walk.
(default: :obj:`False`)
:rtype: :class:`LongTensor`
"""
...
...
@@ -43,5 +53,10 @@ def random_walk(row: Tensor, col: Tensor, start: Tensor, walk_length: int,
rowptr
=
row
.
new_zeros
(
num_nodes
+
1
)
torch
.
cumsum
(
deg
,
0
,
out
=
rowptr
[
1
:])
return
torch
.
ops
.
torch_cluster
.
random_walk
(
rowptr
,
col
,
start
,
walk_length
,
p
,
q
)[
0
]
node_seq
,
edge_seq
=
torch
.
ops
.
torch_cluster
.
random_walk
(
rowptr
,
col
,
start
,
walk_length
,
p
,
q
)
if
return_edge_indices
:
return
node_seq
,
edge_seq
return
node_seq
torch_cluster/sampler.py
View file @
6b634203
import
torch
@
torch
.
jit
.
script
def
neighbor_sampler
(
start
:
torch
.
Tensor
,
rowptr
:
torch
.
Tensor
,
size
:
float
):
assert
not
start
.
is_cuda
...
...
t
est/utils
.py
→
t
orch_cluster/testing
.py
View file @
6b634203
from
typing
import
Any
import
torch
dtypes
=
[
torch
.
half
,
torch
.
float
,
torch
.
double
,
torch
.
int
,
torch
.
long
]
dtypes
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
,
torch
.
double
,
torch
.
int
,
torch
.
long
]
grad_dtypes
=
[
torch
.
half
,
torch
.
float
,
torch
.
double
]
devices
=
[
torch
.
device
(
'cpu'
)]
if
torch
.
cuda
.
is_available
():
devices
+=
[
torch
.
device
(
f
'cuda:
{
torch
.
cuda
.
current_device
()
}
'
)]
devices
+=
[
torch
.
device
(
'cuda:
0
'
)]
def
tensor
(
x
,
dtype
,
device
):
def
tensor
(
x
:
Any
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
):
return
None
if
x
is
None
else
torch
.
tensor
(
x
,
dtype
=
dtype
,
device
=
device
)
torch_cluster/typing.py
0 → 100644
View file @
6b634203
import
torch
try
:
WITH_PTR_LIST
=
hasattr
(
torch
.
ops
.
torch_cluster
,
'fps_ptr_list'
)
except
Exception
:
WITH_PTR_LIST
=
False
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment