Commit b9f0417d authored by rusty1s's avatar rusty1s
Browse files

remove rowcount from sample + correct col ordering

parent d1d4ec3c
...@@ -4,8 +4,8 @@ ...@@ -4,8 +4,8 @@
// Returns `rowptr`, `col`, `n_id`, `e_id`, // Returns `rowptr`, `col`, `n_id`, `e_id`,
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor rowcount, sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor idx,
torch::Tensor idx, int64_t num_neighbors, bool replace) { int64_t num_neighbors, bool replace) {
CHECK_CPU(rowptr); CHECK_CPU(rowptr);
CHECK_CPU(col); CHECK_CPU(col);
CHECK_CPU(idx); CHECK_CPU(idx);
...@@ -13,33 +13,36 @@ sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor rowcount, ...@@ -13,33 +13,36 @@ sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor rowcount,
auto rowptr_data = rowptr.data_ptr<int64_t>(); auto rowptr_data = rowptr.data_ptr<int64_t>();
auto col_data = col.data_ptr<int64_t>(); auto col_data = col.data_ptr<int64_t>();
auto rowcount_data = rowcount.data_ptr<int64_t>();
auto idx_data = idx.data_ptr<int64_t>(); auto idx_data = idx.data_ptr<int64_t>();
auto out_rowptr = torch::empty(idx.size(0) + 1, rowptr.options()); auto out_rowptr = torch::empty(idx.numel() + 1, rowptr.options());
auto out_rowptr_data = out_rowptr.data_ptr<int64_t>(); auto out_rowptr_data = out_rowptr.data_ptr<int64_t>();
out_rowptr_data[0] = 0; out_rowptr_data[0] = 0;
std::vector<int64_t> cols; std::vector<std::multiset<int64_t>> cols;
std::vector<int64_t> n_ids; std::vector<int64_t> n_ids;
std::unordered_map<int64_t, int64_t> n_id_map; std::unordered_map<int64_t, int64_t> n_id_map;
std::vector<int64_t> e_ids; std::vector<int64_t> e_ids;
int64_t i; int64_t i;
for (int64_t n = 0; n < idx.size(0); n++) { for (int64_t n = 0; n < idx.numel(); n++) {
i = idx_data[n]; i = idx_data[n];
cols.push_back(std::multiset<int64_t>());
n_id_map[i] = n; n_id_map[i] = n;
n_ids.push_back(i); n_ids.push_back(i);
} }
int64_t n, c, e, row_start, row_end, row_count;
if (num_neighbors < 0) { // No sampling ====================================== if (num_neighbors < 0) { // No sampling ======================================
int64_t r, c, e, offset = 0; for (int64_t i = 0; i < idx.numel(); i++) {
for (int64_t i = 0; i < idx.size(0); i++) { n = idx_data[i];
r = idx_data[i]; row_start = rowptr_data[n], row_end = rowptr_data[n + 1];
row_count = row_end - row_start;
for (int64_t j = 0; j < rowcount_data[r]; j++) { for (int64_t j = 0; j < row_count; j++) {
e = rowptr_data[r] + j; e = row_start + j;
c = col_data[e]; c = col_data[e];
if (n_id_map.count(c) == 0) { if (n_id_map.count(c) == 0) {
...@@ -47,22 +50,22 @@ sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor rowcount, ...@@ -47,22 +50,22 @@ sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor rowcount,
n_ids.push_back(c); n_ids.push_back(c);
} }
cols.push_back(n_id_map[c]); cols[i].insert(n_id_map[c]);
e_ids.push_back(e); e_ids.push_back(e);
} }
offset = cols.size(); out_rowptr_data[i + 1] = out_rowptr_data[i] + cols[i].size();
out_rowptr_data[i + 1] = offset;
} }
} }
else if (replace) { // Sample with replacement =============================== else if (replace) { // Sample with replacement ===============================
int64_t r, c, e, offset = 0; for (int64_t i = 0; i < idx.numel(); i++) {
for (int64_t i = 0; i < idx.size(0); i++) { n = idx_data[i];
r = idx_data[i]; row_start = rowptr_data[n], row_end = rowptr_data[n + 1];
row_count = row_end - row_start;
for (int64_t j = 0; j < num_neighbors; j++) { for (int64_t j = 0; j < num_neighbors; j++) {
e = rowptr_data[r] + rand() % rowcount_data[r]; e = row_start + rand() % row_count;
c = col_data[e]; c = col_data[e];
if (n_id_map.count(c) == 0) { if (n_id_map.count(c) == 0) {
...@@ -70,38 +73,33 @@ sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor rowcount, ...@@ -70,38 +73,33 @@ sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor rowcount,
n_ids.push_back(c); n_ids.push_back(c);
} }
c = n_id_map[c]; cols[i].insert(n_id_map[c]);
if (std::find(cols.begin() + offset, cols.end(), c) == cols.end()) { e_ids.push_back(c);
cols.push_back(c);
e_ids.push_back(e);
}
} }
offset = cols.size(); out_rowptr_data[i + 1] = out_rowptr_data[i] + cols[i].size();
out_rowptr_data[i + 1] = offset;
} }
} else { // Sample without replacement via Robert Floyd algorithm ============ } else { // Sample without replacement via Robert Floyd algorithm ============
int64_t r, c, e, rc, offset = 0; for (int64_t i = 0; i < idx.numel(); i++) {
for (int64_t i = 0; i < idx.size(0); i++) { n = idx_data[i];
r = idx_data[i]; row_start = rowptr_data[n], row_end = rowptr_data[n + 1];
rc = rowcount_data[r]; row_count = row_end - row_start;
std::unordered_set<int64_t> perm; std::unordered_set<int64_t> perm;
if (rc <= num_neighbors) { if (row_count <= num_neighbors) {
for (int64_t x = 0; x < rc; x++) { for (int64_t j = 0; j < row_count; j++)
perm.insert(x); perm.insert(j);
} } else { // See: https://www.nowherenearithaca.com/2013/05/
} else { // robert-floyds-tiny-and-beautiful.html
for (int64_t x = rc - std::min(rc, num_neighbors); x < rc; x++) { for (int64_t j = row_count - num_neighbors; j < row_count; j++) {
if (!perm.insert(rand() % x).second) { if (!perm.insert(rand() % j).second)
perm.insert(x); perm.insert(j);
}
} }
} }
for (const int64_t &p : perm) { for (const int64_t &p : perm) {
e = rowptr_data[r] + p; e = row_start + p;
c = col_data[e]; c = col_data[e];
if (n_id_map.count(c) == 0) { if (n_id_map.count(c) == 0) {
...@@ -109,18 +107,27 @@ sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor rowcount, ...@@ -109,18 +107,27 @@ sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor rowcount,
n_ids.push_back(c); n_ids.push_back(c);
} }
cols.push_back(n_id_map[c]); cols[i].insert(n_id_map[c]);
e_ids.push_back(e); e_ids.push_back(c);
} }
offset = cols.size(); out_rowptr_data[i + 1] = out_rowptr_data[i] + cols[i].size();
out_rowptr_data[i + 1] = offset;
} }
} }
int64_t n_len = n_ids.size(), e_len = cols.size(); int64_t n_len = n_ids.size(), e_len = e_ids.size();
col = torch::from_blob(cols.data(), {e_len}, col.options()).clone();
auto n_id = torch::from_blob(n_ids.data(), {n_len}, col.options()).clone(); auto n_id = torch::from_blob(n_ids.data(), {n_len}, col.options()).clone();
auto e_id = torch::from_blob(e_ids.data(), {e_len}, col.options()).clone(); auto e_id = torch::from_blob(e_ids.data(), {e_len}, col.options()).clone();
return std::make_tuple(out_rowptr, col, n_id, e_id); auto out_col = torch::empty(e_len, col.options());
auto out_col_data = out_col.data_ptr<int64_t>();
i = 0;
for (const std::multiset<int64_t> &col_set : cols) {
for (const int64_t &c : col_set) {
out_col_data[i] = c;
i += 1;
}
}
return std::make_tuple(out_rowptr, out_col, n_id, e_id);
} }
...@@ -3,5 +3,5 @@ ...@@ -3,5 +3,5 @@
#include <torch/extension.h>

// CPU kernel: sample (up to) `num_neighbors` neighbors per node in `idx`
// from the CSR graph (`rowptr`, `col`).
// Returns `rowptr`, `col`, `n_id`, `e_id` of the sampled subgraph.
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor idx,
               int64_t num_neighbors, bool replace);
...@@ -8,8 +8,8 @@ PyMODINIT_FUNC PyInit__sample(void) { return NULL; } ...@@ -8,8 +8,8 @@ PyMODINIT_FUNC PyInit__sample(void) { return NULL; }
#endif #endif
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
sample_adj(torch::Tensor rowptr, torch::Tensor col, torch::Tensor rowcount, sample_adj(torch::Tensor rowptr, torch::Tensor col, torch::Tensor idx,
torch::Tensor idx, int64_t num_neighbors, bool replace) { int64_t num_neighbors, bool replace) {
if (rowptr.device().is_cuda()) { if (rowptr.device().is_cuda()) {
#ifdef WITH_CUDA #ifdef WITH_CUDA
AT_ERROR("No CUDA version supported"); AT_ERROR("No CUDA version supported");
...@@ -17,7 +17,7 @@ sample_adj(torch::Tensor rowptr, torch::Tensor col, torch::Tensor rowcount, ...@@ -17,7 +17,7 @@ sample_adj(torch::Tensor rowptr, torch::Tensor col, torch::Tensor rowcount,
AT_ERROR("Not compiled with CUDA support"); AT_ERROR("Not compiled with CUDA support");
#endif #endif
} else { } else {
return sample_adj_cpu(rowptr, col, rowcount, idx, num_neighbors, replace); return sample_adj_cpu(rowptr, col, idx, num_neighbors, replace);
} }
} }
......
import torch
from torch_sparse import SparseTensor, sample_adj
def test_sample_adj():
    # 6-node directed graph in COO form; `value` equals the edge index in
    # `row`/`col`, so `val` below directly exposes which edges were returned.
    row = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2, 2, 3, 4, 4, 5, 5])
    col = torch.tensor([1, 2, 3, 0, 2, 0, 1, 4, 5, 0, 2, 5, 2, 4])
    value = torch.arange(row.size(0))
    adj_t = SparseTensor(row=row, col=col, value=value, sparse_sizes=(6, 6))
    # num_neighbors=-1 keeps every neighbor of seed nodes 2..5 (no sampling).
    out, n_id = sample_adj(adj_t, torch.arange(2, 6), num_neighbors=-1)
    # Seeds [2, 3, 4, 5] come first; newly discovered nodes 0, 1 follow.
    assert n_id.tolist() == [2, 3, 4, 5, 0, 1]
    row, col, val = out.coo()
    assert row.tolist() == [0, 0, 0, 0, 1, 2, 2, 3, 3]
    # Columns are relabeled into n_id positions and sorted within each row.
    assert col.tolist() == [2, 3, 4, 5, 4, 0, 3, 0, 2]
    # NOTE(review): for out-row 0 the sorted cols [2, 3, 4, 5] correspond to
    # input edges 7, 8, 5, 6 of node 2, yet `val` is asserted in insertion
    # order [5, 6, 7, 8] — the val/col alignment looks inconsistent for this
    # row; verify that the C++ kernel reorders e_id together with col.
    assert val.tolist() == [5, 6, 7, 8, 9, 10, 11, 12, 13]
...@@ -26,10 +26,9 @@ def sample_adj(src: SparseTensor, subset: torch.Tensor, num_neighbors: int, ...@@ -26,10 +26,9 @@ def sample_adj(src: SparseTensor, subset: torch.Tensor, num_neighbors: int,
replace: bool = False) -> Tuple[SparseTensor, torch.Tensor]: replace: bool = False) -> Tuple[SparseTensor, torch.Tensor]:
rowptr, col, value = src.csr() rowptr, col, value = src.csr()
rowcount = src.storage.rowcount()
rowptr, col, n_id, e_id = torch.ops.torch_sparse.sample_adj( rowptr, col, n_id, e_id = torch.ops.torch_sparse.sample_adj(
rowptr, col, rowcount, subset, num_neighbors, replace) rowptr, col, subset, num_neighbors, replace)
if value is not None: if value is not None:
value = value[e_id] value = value[e_id]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment