Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
d7a16f9a
Commit
d7a16f9a
authored
Oct 27, 2020
by
rusty1s
Browse files
return both node and edge ids
parent
2dd14df1
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
25 additions
and
24 deletions
+25
-24
csrc/cluster.h
csrc/cluster.h
+3
-3
csrc/cpu/rw_cpu.cpp
csrc/cpu/rw_cpu.cpp
+4
-4
csrc/cpu/rw_cpu.h
csrc/cpu/rw_cpu.h
+3
-3
csrc/cuda/rw_cuda.cu
csrc/cuda/rw_cuda.cu
+4
-4
csrc/cuda/rw_cuda.h
csrc/cuda/rw_cuda.h
+3
-3
csrc/rw.cpp
csrc/rw.cpp
+3
-3
torch_cluster/rw.py
torch_cluster/rw.py
+5
-4
No files found.
csrc/cluster.h
View file @
d7a16f9a
...
@@ -23,9 +23,9 @@ torch::Tensor nearest(torch::Tensor x, torch::Tensor y, torch::Tensor ptr_x,
...
@@ -23,9 +23,9 @@ torch::Tensor nearest(torch::Tensor x, torch::Tensor y, torch::Tensor ptr_x,
torch
::
Tensor
radius
(
torch
::
Tensor
x
,
torch
::
Tensor
y
,
torch
::
Tensor
ptr_x
,
torch
::
Tensor
radius
(
torch
::
Tensor
x
,
torch
::
Tensor
y
,
torch
::
Tensor
ptr_x
,
torch
::
Tensor
ptr_y
,
double
r
,
int64_t
max_num_neighbors
);
torch
::
Tensor
ptr_y
,
double
r
,
int64_t
max_num_neighbors
);
torch
::
Tensor
random_walk
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
torch
::
Tensor
start
,
int64_t
walk_length
,
double
p
,
random_walk
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
Tensor
start
,
double
q
);
int64_t
walk_length
,
double
p
,
double
q
);
torch
::
Tensor
neighbor_sampler
(
torch
::
Tensor
start
,
torch
::
Tensor
rowptr
,
torch
::
Tensor
neighbor_sampler
(
torch
::
Tensor
start
,
torch
::
Tensor
rowptr
,
int64_t
count
,
double
factor
);
int64_t
count
,
double
factor
);
csrc/cpu/rw_cpu.cpp
View file @
d7a16f9a
...
@@ -4,9 +4,9 @@
...
@@ -4,9 +4,9 @@
#include "utils.h"
#include "utils.h"
torch
::
Tensor
random_walk_cpu
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
torch
::
Tensor
start
,
int64_t
walk_length
,
random_walk_cpu
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
Tensor
start
,
double
p
,
double
q
)
{
int64_t
walk_length
,
double
p
,
double
q
)
{
CHECK_CPU
(
rowptr
);
CHECK_CPU
(
rowptr
);
CHECK_CPU
(
col
);
CHECK_CPU
(
col
);
CHECK_CPU
(
start
);
CHECK_CPU
(
start
);
...
@@ -50,5 +50,5 @@ torch::Tensor random_walk_cpu(torch::Tensor rowptr, torch::Tensor col,
...
@@ -50,5 +50,5 @@ torch::Tensor random_walk_cpu(torch::Tensor rowptr, torch::Tensor col,
}
}
});
});
return
n_out
;
return
std
::
make_tuple
(
n_out
,
e_out
)
;
}
}
csrc/cpu/rw_cpu.h
View file @
d7a16f9a
...
@@ -2,6 +2,6 @@
...
@@ -2,6 +2,6 @@
#include <torch/extension.h>
#include <torch/extension.h>
torch
::
Tensor
random_walk_cpu
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
torch
::
Tensor
start
,
int64_t
walk_length
,
random_walk_cpu
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
Tensor
start
,
double
p
,
double
q
);
int64_t
walk_length
,
double
p
,
double
q
);
csrc/cuda/rw_cuda.cu
View file @
d7a16f9a
...
@@ -35,9 +35,9 @@ __global__ void uniform_random_walk_kernel(const int64_t *rowptr,
...
@@ -35,9 +35,9 @@ __global__ void uniform_random_walk_kernel(const int64_t *rowptr,
}
}
}
}
torch
::
Tensor
random_walk_cuda
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
torch
::
Tensor
start
,
int64_t
walk_length
,
random_walk_cuda
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
Tensor
start
,
double
p
,
double
q
)
{
int64_t
walk_length
,
double
p
,
double
q
)
{
CHECK_CUDA
(
rowptr
);
CHECK_CUDA
(
rowptr
);
CHECK_CUDA
(
col
);
CHECK_CUDA
(
col
);
CHECK_CUDA
(
start
);
CHECK_CUDA
(
start
);
...
@@ -60,5 +60,5 @@ torch::Tensor random_walk_cuda(torch::Tensor rowptr, torch::Tensor col,
...
@@ -60,5 +60,5 @@ torch::Tensor random_walk_cuda(torch::Tensor rowptr, torch::Tensor col,
n_out
.
data_ptr
<
int64_t
>
(),
e_out
.
data_ptr
<
int64_t
>
(),
walk_length
,
n_out
.
data_ptr
<
int64_t
>
(),
e_out
.
data_ptr
<
int64_t
>
(),
walk_length
,
start
.
numel
());
start
.
numel
());
return
n_out
.
t
().
contiguous
();
return
std
::
make_tuple
(
n_out
.
t
().
contiguous
()
,
e_out
.
t
().
contiguous
())
;
}
}
csrc/cuda/rw_cuda.h
View file @
d7a16f9a
...
@@ -2,6 +2,6 @@
...
@@ -2,6 +2,6 @@
#include <torch/extension.h>
#include <torch/extension.h>
torch
::
Tensor
random_walk_cuda
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
torch
::
Tensor
start
,
int64_t
walk_length
,
random_walk_cuda
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
Tensor
start
,
double
p
,
double
q
);
int64_t
walk_length
,
double
p
,
double
q
);
csrc/rw.cpp
View file @
d7a16f9a
...
@@ -11,9 +11,9 @@
...
@@ -11,9 +11,9 @@
PyMODINIT_FUNC
PyInit__rw
(
void
)
{
return
NULL
;
}
PyMODINIT_FUNC
PyInit__rw
(
void
)
{
return
NULL
;
}
#endif
#endif
torch
::
Tensor
random_walk
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
torch
::
Tensor
start
,
int64_t
walk_length
,
double
p
,
random_walk
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
Tensor
start
,
double
q
)
{
int64_t
walk_length
,
double
p
,
double
q
)
{
if
(
rowptr
.
device
().
is_cuda
())
{
if
(
rowptr
.
device
().
is_cuda
())
{
#ifdef WITH_CUDA
#ifdef WITH_CUDA
return
random_walk_cuda
(
rowptr
,
col
,
start
,
walk_length
,
p
,
q
);
return
random_walk_cuda
(
rowptr
,
col
,
start
,
walk_length
,
p
,
q
);
...
...
torch_cluster/rw.py
View file @
d7a16f9a
...
@@ -2,12 +2,13 @@ import warnings
...
@@ -2,12 +2,13 @@ import warnings
from
typing
import
Optional
from
typing
import
Optional
import
torch
import
torch
from
torch
import
Tensor
@
torch
.
jit
.
script
@
torch
.
jit
.
script
def
random_walk
(
row
:
torch
.
Tensor
,
col
:
torch
.
Tensor
,
start
:
torch
.
Tensor
,
def
random_walk
(
row
:
Tensor
,
col
:
Tensor
,
start
:
Tensor
,
walk_length
:
int
,
walk_length
:
int
,
p
:
float
=
1
,
q
:
float
=
1
,
p
:
float
=
1
,
q
:
float
=
1
,
coalesced
:
bool
=
True
,
coalesced
:
bool
=
True
,
num_nodes
:
Optional
[
int
]
=
None
):
num_nodes
:
Optional
[
int
]
=
None
)
->
Tensor
:
"""Samples random walks of length :obj:`walk_length` from all node indices
"""Samples random walks of length :obj:`walk_length` from all node indices
in :obj:`start` in the graph given by :obj:`(row, col)` as described in the
in :obj:`start` in the graph given by :obj:`(row, col)` as described in the
`"node2vec: Scalable Feature Learning for Networks"
`"node2vec: Scalable Feature Learning for Networks"
...
@@ -49,4 +50,4 @@ def random_walk(row: torch.Tensor, col: torch.Tensor, start: torch.Tensor,
...
@@ -49,4 +50,4 @@ def random_walk(row: torch.Tensor, col: torch.Tensor, start: torch.Tensor,
p
=
q
=
1.
p
=
q
=
1.
return
torch
.
ops
.
torch_cluster
.
random_walk
(
rowptr
,
col
,
start
,
walk_length
,
return
torch
.
ops
.
torch_cluster
.
random_walk
(
rowptr
,
col
,
start
,
walk_length
,
p
,
q
)
p
,
q
)
[
0
]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment