Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
5d12a68a
Commit
5d12a68a
authored
Sep 25, 2020
by
rusty1s
Browse files
fix rw isolated nodes bug
parent
2bf5e763
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
25 additions
and
9 deletions
+25
-9
csrc/cpu/rw_cpu.cpp
csrc/cpu/rw_cpu.cpp
+9
-5
csrc/cuda/rw_cuda.cu
csrc/cuda/rw_cuda.cu
+7
-3
test/test_rw.py
test/test_rw.py
+9
-1
No files found.
csrc/cpu/rw_cpu.cpp
View file @
5d12a68a
...
@@ -16,7 +16,7 @@ torch::Tensor random_walk_cpu(torch::Tensor rowptr, torch::Tensor col,
...
@@ -16,7 +16,7 @@ torch::Tensor random_walk_cpu(torch::Tensor rowptr, torch::Tensor col,
auto
rand
=
torch
::
rand
({
start
.
size
(
0
),
walk_length
},
auto
rand
=
torch
::
rand
({
start
.
size
(
0
),
walk_length
},
start
.
options
().
dtype
(
torch
::
kFloat
));
start
.
options
().
dtype
(
torch
::
kFloat
));
auto
out
=
torch
::
full
({
start
.
size
(
0
),
walk_length
+
1
},
-
1
,
start
.
options
());
auto
out
=
torch
::
empty
({
start
.
size
(
0
),
walk_length
+
1
},
start
.
options
());
auto
rowptr_data
=
rowptr
.
data_ptr
<
int64_t
>
();
auto
rowptr_data
=
rowptr
.
data_ptr
<
int64_t
>
();
auto
col_data
=
col
.
data_ptr
<
int64_t
>
();
auto
col_data
=
col
.
data_ptr
<
int64_t
>
();
...
@@ -29,12 +29,16 @@ torch::Tensor random_walk_cpu(torch::Tensor rowptr, torch::Tensor col,
...
@@ -29,12 +29,16 @@ torch::Tensor random_walk_cpu(torch::Tensor rowptr, torch::Tensor col,
auto
offset
=
n
*
(
walk_length
+
1
);
auto
offset
=
n
*
(
walk_length
+
1
);
out_data
[
offset
]
=
cur
;
out_data
[
offset
]
=
cur
;
int64_t
row_start
,
row_end
;
int64_t
row_start
,
row_end
,
rnd
;
for
(
auto
l
=
1
;
l
<=
walk_length
;
l
++
)
{
for
(
auto
l
=
1
;
l
<=
walk_length
;
l
++
)
{
row_start
=
rowptr_data
[
cur
],
row_end
=
rowptr_data
[
cur
+
1
];
row_start
=
rowptr_data
[
cur
],
row_end
=
rowptr_data
[
cur
+
1
];
if
(
row_end
-
row_start
==
0
)
{
cur
=
col_data
[
row_start
+
int64_t
(
rand_data
[
n
*
walk_length
+
(
l
-
1
)]
*
cur
=
n
;
(
row_end
-
row_start
))];
}
else
{
rnd
=
int64_t
(
rand_data
[
n
*
walk_length
+
(
l
-
1
)]
*
(
row_end
-
row_start
));
cur
=
col_data
[
row_start
+
rnd
];
}
out_data
[
offset
+
l
]
=
cur
;
out_data
[
offset
+
l
]
=
cur
;
}
}
}
}
...
...
csrc/cuda/rw_cuda.cu
View file @
5d12a68a
...
@@ -23,10 +23,14 @@ __global__ void uniform_random_walk_kernel(const int64_t *rowptr,
...
@@ -23,10 +23,14 @@ __global__ void uniform_random_walk_kernel(const int64_t *rowptr,
cur
=
out
[
i
];
cur
=
out
[
i
];
row_start
=
rowptr
[
cur
],
row_end
=
rowptr
[
cur
+
1
];
row_start
=
rowptr
[
cur
],
row_end
=
rowptr
[
cur
+
1
];
if
(
row_end
-
row_start
==
0
)
{
out
[
l
*
numel
+
thread_idx
]
=
cur
;
}
else
{
out
[
l
*
numel
+
thread_idx
]
=
out
[
l
*
numel
+
thread_idx
]
=
col
[
row_start
+
int64_t
(
rand
[
i
]
*
(
row_end
-
row_start
))];
col
[
row_start
+
int64_t
(
rand
[
i
]
*
(
row_end
-
row_start
))];
}
}
}
}
}
}
}
torch
::
Tensor
random_walk_cuda
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
Tensor
random_walk_cuda
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
...
@@ -43,7 +47,7 @@ torch::Tensor random_walk_cuda(torch::Tensor rowptr, torch::Tensor col,
...
@@ -43,7 +47,7 @@ torch::Tensor random_walk_cuda(torch::Tensor rowptr, torch::Tensor col,
auto
rand
=
torch
::
rand
({
start
.
size
(
0
),
walk_length
},
auto
rand
=
torch
::
rand
({
start
.
size
(
0
),
walk_length
},
start
.
options
().
dtype
(
torch
::
kFloat
));
start
.
options
().
dtype
(
torch
::
kFloat
));
auto
out
=
torch
::
full
({
walk_length
+
1
,
start
.
size
(
0
)},
-
1
,
start
.
options
());
auto
out
=
torch
::
empty
({
walk_length
+
1
,
start
.
size
(
0
)},
start
.
options
());
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
uniform_random_walk_kernel
<<<
BLOCKS
(
start
.
numel
()),
THREADS
,
0
,
stream
>>>
(
uniform_random_walk_kernel
<<<
BLOCKS
(
start
.
numel
()),
THREADS
,
0
,
stream
>>>
(
...
...
test/test_rw.py
View file @
5d12a68a
...
@@ -12,7 +12,7 @@ def test_rw(device):
...
@@ -12,7 +12,7 @@ def test_rw(device):
start
=
tensor
([
0
,
1
,
2
,
3
,
4
],
torch
.
long
,
device
)
start
=
tensor
([
0
,
1
,
2
,
3
,
4
],
torch
.
long
,
device
)
walk_length
=
10
walk_length
=
10
out
=
random_walk
(
row
,
col
,
start
,
walk_length
,
coalesced
=
True
)
out
=
random_walk
(
row
,
col
,
start
,
walk_length
)
assert
out
[:,
0
].
tolist
()
==
start
.
tolist
()
assert
out
[:,
0
].
tolist
()
==
start
.
tolist
()
for
n
in
range
(
start
.
size
(
0
)):
for
n
in
range
(
start
.
size
(
0
)):
...
@@ -20,3 +20,11 @@ def test_rw(device):
...
@@ -20,3 +20,11 @@ def test_rw(device):
for
i
in
range
(
1
,
walk_length
):
for
i
in
range
(
1
,
walk_length
):
assert
out
[
n
,
i
].
item
()
in
col
[
row
==
cur
].
tolist
()
assert
out
[
n
,
i
].
item
()
in
col
[
row
==
cur
].
tolist
()
cur
=
out
[
n
,
i
].
item
()
cur
=
out
[
n
,
i
].
item
()
row
=
tensor
([
0
,
1
],
torch
.
long
,
device
)
col
=
tensor
([
1
,
0
],
torch
.
long
,
device
)
start
=
tensor
([
0
,
1
,
2
],
torch
.
long
,
device
)
walk_length
=
4
out
=
random_walk
(
row
,
col
,
start
,
walk_length
,
num_nodes
=
3
)
assert
out
.
tolist
()
==
[[
0
,
1
,
0
,
1
],
[
1
,
0
,
1
,
0
],
[
2
,
2
,
2
,
2
]]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment