Commit 60836e2e authored by rusty1s's avatar rusty1s
Browse files

remove node id

parent 3369b5f0
...@@ -3,10 +3,8 @@ ...@@ -3,10 +3,8 @@
#include <TH/THGenerator.hpp> #include <TH/THGenerator.hpp>
std::tuple<at::Tensor, at::Tensor> neighbor_sampler(at::Tensor start, at::Tensor neighbor_sampler(at::Tensor start, at::Tensor cumdeg, size_t size,
at::Tensor cumdeg, float factor) {
at::Tensor col, size_t size,
float factor) {
THGenerator *generator = THGenerator_new(); THGenerator *generator = THGenerator_new();
auto start_ptr = start.data<int64_t>(); auto start_ptr = start.data<int64_t>();
...@@ -46,9 +44,7 @@ std::tuple<at::Tensor, at::Tensor> neighbor_sampler(at::Tensor start, ...@@ -46,9 +44,7 @@ std::tuple<at::Tensor, at::Tensor> neighbor_sampler(at::Tensor start,
int64_t len = e_ids.size(); int64_t len = e_ids.size();
auto e_id = torch::from_blob(e_ids.data(), {len}, start.options()).clone(); auto e_id = torch::from_blob(e_ids.data(), {len}, start.options()).clone();
auto n_id = std::get<0>(at::_unique(col.index_select(0, e_id))); return e_id;
return std::make_tuple(n_id, e_id);
} }
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
......
...@@ -8,12 +8,9 @@ def test_neighbor_sampler(): ...@@ -8,12 +8,9 @@ def test_neighbor_sampler():
start = torch.tensor([0, 1]) start = torch.tensor([0, 1])
cumdeg = torch.tensor([0, 3, 7]) cumdeg = torch.tensor([0, 3, 7])
col = torch.tensor([1, 2, 3, 0, 2, 3, 4])
n_id, e_id = neighbor_sampler(start, cumdeg, col, size=1.0) e_id = neighbor_sampler(start, cumdeg, size=1.0)
assert n_id.tolist() == [0, 1, 2, 3, 4]
assert e_id.tolist() == [0, 2, 1, 5, 6, 3, 4] assert e_id.tolist() == [0, 2, 1, 5, 6, 3, 4]
n_id, e_id = neighbor_sampler(start, cumdeg, col, size=3) e_id = neighbor_sampler(start, cumdeg, size=3)
assert n_id.tolist() == [1, 2, 3, 4]
assert e_id.tolist() == [1, 0, 2, 4, 5, 6] assert e_id.tolist() == [1, 0, 2, 4, 5, 6]
import torch_cluster.sampler_cpu import torch_cluster.sampler_cpu
def neighbor_sampler(start, cumdeg, col, size): def neighbor_sampler(start, cumdeg, size):
assert not start.is_cuda assert not start.is_cuda
factor = 1 factor = 1
...@@ -10,4 +10,4 @@ def neighbor_sampler(start, cumdeg, col, size): ...@@ -10,4 +10,4 @@ def neighbor_sampler(start, cumdeg, col, size):
size = 2147483647 size = 2147483647
op = torch_cluster.sampler_cpu.neighbor_sampler op = torch_cluster.sampler_cpu.neighbor_sampler
return op(start, cumdeg, col, size, factor) return op(start, cumdeg, size, factor)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment