Commit 5931f37b authored by rusty1s's avatar rusty1s
Browse files

slightly faster sampling

parent fe35bc61
...@@ -19,7 +19,7 @@ sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor idx, ...@@ -19,7 +19,7 @@ sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor idx,
auto out_rowptr_data = out_rowptr.data_ptr<int64_t>(); auto out_rowptr_data = out_rowptr.data_ptr<int64_t>();
out_rowptr_data[0] = 0; out_rowptr_data[0] = 0;
std::vector<std::multiset<int64_t>> cols; std::vector<std::vector<int64_t>> cols;
std::vector<int64_t> n_ids; std::vector<int64_t> n_ids;
std::unordered_map<int64_t, int64_t> n_id_map; std::unordered_map<int64_t, int64_t> n_id_map;
std::vector<int64_t> e_ids; std::vector<int64_t> e_ids;
...@@ -27,7 +27,7 @@ sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor idx, ...@@ -27,7 +27,7 @@ sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor idx,
int64_t i; int64_t i;
for (int64_t n = 0; n < idx.numel(); n++) { for (int64_t n = 0; n < idx.numel(); n++) {
i = idx_data[n]; i = idx_data[n];
cols.push_back(std::multiset<int64_t>()); cols.push_back(std::vector<int64_t>());
n_id_map[i] = n; n_id_map[i] = n;
n_ids.push_back(i); n_ids.push_back(i);
} }
...@@ -50,7 +50,7 @@ sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor idx, ...@@ -50,7 +50,7 @@ sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor idx,
n_ids.push_back(c); n_ids.push_back(c);
} }
cols[i].insert(n_id_map[c]); cols[i].push_back(n_id_map[c]);
e_ids.push_back(e); e_ids.push_back(e);
} }
out_rowptr_data[i + 1] = out_rowptr_data[i] + cols[i].size(); out_rowptr_data[i + 1] = out_rowptr_data[i] + cols[i].size();
...@@ -73,7 +73,7 @@ sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor idx, ...@@ -73,7 +73,7 @@ sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor idx,
n_ids.push_back(c); n_ids.push_back(c);
} }
cols[i].insert(n_id_map[c]); cols[i].push_back(n_id_map[c]);
e_ids.push_back(c); e_ids.push_back(c);
} }
out_rowptr_data[i + 1] = out_rowptr_data[i] + cols[i].size(); out_rowptr_data[i + 1] = out_rowptr_data[i] + cols[i].size();
...@@ -107,7 +107,7 @@ sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor idx, ...@@ -107,7 +107,7 @@ sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor idx,
n_ids.push_back(c); n_ids.push_back(c);
} }
cols[i].insert(n_id_map[c]); cols[i].push_back(n_id_map[c]);
e_ids.push_back(c); e_ids.push_back(c);
} }
out_rowptr_data[i + 1] = out_rowptr_data[i] + cols[i].size(); out_rowptr_data[i + 1] = out_rowptr_data[i] + cols[i].size();
...@@ -122,8 +122,9 @@ sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor idx, ...@@ -122,8 +122,9 @@ sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor idx,
auto out_col_data = out_col.data_ptr<int64_t>(); auto out_col_data = out_col.data_ptr<int64_t>();
i = 0; i = 0;
for (const std::multiset<int64_t> &col_set : cols) { for (std::vector<int64_t> &col_vec : cols) {
for (const int64_t &c : col_set) { std::sort(col_vec.begin(), col_vec.end());
for (const int64_t &c : col_vec) {
out_col_data[i] = c; out_col_data[i] = c;
i += 1; i += 1;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment