Commit b9f0417d authored by rusty1s's avatar rusty1s
Browse files

remove rowcount from sample + correct col ordering

parent d1d4ec3c
......@@ -4,8 +4,8 @@
// Samples (at most) `num_neighbors` neighbors per node in `idx` from the CSR
// graph (`rowptr`, `col`).
// Returns `out_rowptr`, `out_col`, `n_id`, `e_id` where:
//  * `out_rowptr`/`out_col` form the sampled adjacency in CSR with columns
//    relabeled to local ids (positions in `n_id`), sorted per row,
//  * `n_id` maps local ids back to global node ids (the `idx` nodes first),
//  * `e_id` holds, aligned with `out_col`, the index of each sampled edge in
//    the input `col` (callers use it to gather edge values).
// `num_neighbors < 0` keeps all neighbors; `replace` toggles sampling with or
// without replacement (the latter via Robert Floyd's algorithm).
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor idx,
               int64_t num_neighbors, bool replace) {
  CHECK_CPU(rowptr);
  CHECK_CPU(col);
  CHECK_CPU(idx);

  auto rowptr_data = rowptr.data_ptr<int64_t>();
  auto col_data = col.data_ptr<int64_t>();
  auto idx_data = idx.data_ptr<int64_t>();

  auto out_rowptr = torch::empty(idx.numel() + 1, rowptr.options());
  auto out_rowptr_data = out_rowptr.data_ptr<int64_t>();
  out_rowptr_data[0] = 0;

  // Per output row: (local col id, input edge id) pairs. A multiset keyed on
  // the pair keeps each row's columns sorted while `e_id` stays aligned with
  // the reordered columns (keeping them in separate containers would break
  // the col <-> e_id correspondence once columns are sorted).
  std::vector<std::multiset<std::pair<int64_t, int64_t>>> cols;
  std::vector<int64_t> n_ids;
  std::unordered_map<int64_t, int64_t> n_id_map;

  int64_t i;
  // Seed the mapping so that the sampled nodes themselves occupy the first
  // `idx.numel()` slots of `n_id`.
  for (int64_t n = 0; n < idx.numel(); n++) {
    i = idx_data[n];
    cols.push_back(std::multiset<std::pair<int64_t, int64_t>>());
    n_id_map[i] = n;
    n_ids.push_back(i);
  }

  int64_t n, c, e, row_start, row_end, row_count;

  if (num_neighbors < 0) { // No sampling ======================================
    for (int64_t i = 0; i < idx.numel(); i++) {
      n = idx_data[i];
      row_start = rowptr_data[n], row_end = rowptr_data[n + 1];
      row_count = row_end - row_start;
      for (int64_t j = 0; j < row_count; j++) {
        e = row_start + j;
        c = col_data[e];
        if (n_id_map.count(c) == 0) {
          n_id_map[c] = n_ids.size();
          n_ids.push_back(c);
        }
        cols[i].insert({n_id_map[c], e});
      }
      out_rowptr_data[i + 1] = out_rowptr_data[i] + (int64_t)cols[i].size();
    }
  }

  else if (replace) { // Sample with replacement ===============================
    for (int64_t i = 0; i < idx.numel(); i++) {
      n = idx_data[i];
      row_start = rowptr_data[n], row_end = rowptr_data[n + 1];
      row_count = row_end - row_start;
      if (row_count > 0) { // `rand() % 0` is undefined for isolated nodes.
        for (int64_t j = 0; j < num_neighbors; j++) {
          e = row_start + rand() % row_count;
          c = col_data[e];
          if (n_id_map.count(c) == 0) {
            n_id_map[c] = n_ids.size();
            n_ids.push_back(c);
          }
          cols[i].insert({n_id_map[c], e});
        }
      }
      out_rowptr_data[i + 1] = out_rowptr_data[i] + (int64_t)cols[i].size();
    }
  } else { // Sample without replacement via Robert Floyd algorithm ============
    for (int64_t i = 0; i < idx.numel(); i++) {
      n = idx_data[i];
      row_start = rowptr_data[n], row_end = rowptr_data[n + 1];
      row_count = row_end - row_start;

      // `perm` collects `min(row_count, num_neighbors)` distinct offsets into
      // the row.
      std::unordered_set<int64_t> perm;
      if (row_count <= num_neighbors) {
        for (int64_t j = 0; j < row_count; j++)
          perm.insert(j);
      } else { // See: https://www.nowherenearithaca.com/2013/05/
               //      robert-floyds-tiny-and-beautiful.html
        for (int64_t j = row_count - num_neighbors; j < row_count; j++) {
          if (!perm.insert(rand() % j).second)
            perm.insert(j);
        }
      }

      for (const int64_t &p : perm) {
        e = row_start + p;
        c = col_data[e];
        if (n_id_map.count(c) == 0) {
          n_id_map[c] = n_ids.size();
          n_ids.push_back(c);
        }
        cols[i].insert({n_id_map[c], e});
      }
      out_rowptr_data[i + 1] = out_rowptr_data[i] + (int64_t)cols[i].size();
    }
  }

  int64_t n_len = (int64_t)n_ids.size();
  int64_t e_len = out_rowptr_data[idx.numel()];
  auto n_id = torch::from_blob(n_ids.data(), {n_len}, col.options()).clone();

  // Flatten the per-row sorted (col, e_id) pairs into the two aligned output
  // tensors.
  auto out_col = torch::empty(e_len, col.options());
  auto out_col_data = out_col.data_ptr<int64_t>();
  auto out_e_id = torch::empty(e_len, col.options());
  auto out_e_id_data = out_e_id.data_ptr<int64_t>();
  i = 0;
  for (const auto &col_set : cols) {
    for (const auto &pair : col_set) {
      out_col_data[i] = pair.first;
      out_e_id_data[i] = pair.second;
      i += 1;
    }
  }

  return std::make_tuple(out_rowptr, out_col, n_id, out_e_id);
}
......@@ -3,5 +3,5 @@
#include <torch/extension.h>
// CPU neighbor sampling; returns (rowptr, col, n_id, e_id) of the sampled
// adjacency (see sample_cpu.cpp for the full contract).
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor idx,
               int64_t num_neighbors, bool replace);
......@@ -8,8 +8,8 @@ PyMODINIT_FUNC PyInit__sample(void) { return NULL; }
#endif
// Device dispatcher for neighbor sampling. Only the CPU implementation
// exists; CUDA tensors raise an explicit error.
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
sample_adj(torch::Tensor rowptr, torch::Tensor col, torch::Tensor idx,
           int64_t num_neighbors, bool replace) {
  if (rowptr.device().is_cuda()) {
#ifdef WITH_CUDA
    AT_ERROR("No CUDA version supported");
#else
    AT_ERROR("Not compiled with CUDA support");
#endif
  } else {
    return sample_adj_cpu(rowptr, col, idx, num_neighbors, replace);
  }
}
......
import torch
from torch_sparse import SparseTensor, sample_adj
def test_sample_adj():
    # Small 6-node graph in COO form. Edge values equal the edge index, so
    # the returned values reveal exactly which input edges were selected.
    row = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2, 2, 3, 4, 4, 5, 5])
    col = torch.tensor([1, 2, 3, 0, 2, 0, 1, 4, 5, 0, 2, 5, 2, 4])
    value = torch.arange(row.size(0))
    adj_t = SparseTensor(row=row, col=col, value=value, sparse_sizes=(6, 6))

    # Full (non-sampled) two-hop extraction for seed nodes 2..5.
    out, n_id = sample_adj(adj_t, torch.arange(2, 6), num_neighbors=-1)

    # Seed nodes come first in `n_id`, newly discovered neighbors afterwards.
    assert n_id.tolist() == [2, 3, 4, 5, 0, 1]

    row, col, val = out.coo()
    assert row.tolist() == [0, 0, 0, 0, 1, 2, 2, 3, 3]
    assert col.tolist() == [2, 3, 4, 5, 4, 0, 3, 0, 2]
    assert val.tolist() == [5, 6, 7, 8, 9, 10, 11, 12, 13]
......@@ -26,10 +26,9 @@ def sample_adj(src: SparseTensor, subset: torch.Tensor, num_neighbors: int,
replace: bool = False) -> Tuple[SparseTensor, torch.Tensor]:
rowptr, col, value = src.csr()
rowcount = src.storage.rowcount()
rowptr, col, n_id, e_id = torch.ops.torch_sparse.sample_adj(
rowptr, col, rowcount, subset, num_neighbors, replace)
rowptr, col, subset, num_neighbors, replace)
if value is not None:
value = value[e_id]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment