Commit 1b956856 authored by rusty1s's avatar rusty1s
Browse files

fix e_id in sampling

parent 5931f37b
...@@ -19,15 +19,14 @@ sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor idx, ...@@ -19,15 +19,14 @@ sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor idx,
auto out_rowptr_data = out_rowptr.data_ptr<int64_t>(); auto out_rowptr_data = out_rowptr.data_ptr<int64_t>();
out_rowptr_data[0] = 0; out_rowptr_data[0] = 0;
std::vector<std::vector<int64_t>> cols; std::vector<std::vector<std::tuple<int64_t, int64_t>>> cols; // col, e_id
std::vector<int64_t> n_ids; std::vector<int64_t> n_ids;
std::unordered_map<int64_t, int64_t> n_id_map; std::unordered_map<int64_t, int64_t> n_id_map;
std::vector<int64_t> e_ids;
int64_t i; int64_t i;
for (int64_t n = 0; n < idx.numel(); n++) { for (int64_t n = 0; n < idx.numel(); n++) {
i = idx_data[n]; i = idx_data[n];
cols.push_back(std::vector<int64_t>()); cols.push_back(std::vector<std::tuple<int64_t, int64_t>>());
n_id_map[i] = n; n_id_map[i] = n;
n_ids.push_back(i); n_ids.push_back(i);
} }
...@@ -49,9 +48,7 @@ sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor idx, ...@@ -49,9 +48,7 @@ sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor idx,
n_id_map[c] = n_ids.size(); n_id_map[c] = n_ids.size();
n_ids.push_back(c); n_ids.push_back(c);
} }
cols[i].push_back(std::make_tuple(n_id_map[c], e));
cols[i].push_back(n_id_map[c]);
e_ids.push_back(e);
} }
out_rowptr_data[i + 1] = out_rowptr_data[i] + cols[i].size(); out_rowptr_data[i + 1] = out_rowptr_data[i] + cols[i].size();
} }
...@@ -72,9 +69,7 @@ sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor idx, ...@@ -72,9 +69,7 @@ sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor idx,
n_id_map[c] = n_ids.size(); n_id_map[c] = n_ids.size();
n_ids.push_back(c); n_ids.push_back(c);
} }
cols[i].push_back(std::make_tuple(n_id_map[c], e));
cols[i].push_back(n_id_map[c]);
e_ids.push_back(c);
} }
out_rowptr_data[i + 1] = out_rowptr_data[i] + cols[i].size(); out_rowptr_data[i + 1] = out_rowptr_data[i] + cols[i].size();
} }
...@@ -106,29 +101,34 @@ sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor idx, ...@@ -106,29 +101,34 @@ sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor idx,
n_id_map[c] = n_ids.size(); n_id_map[c] = n_ids.size();
n_ids.push_back(c); n_ids.push_back(c);
} }
cols[i].push_back(std::make_tuple(n_id_map[c], e));
cols[i].push_back(n_id_map[c]);
e_ids.push_back(c);
} }
out_rowptr_data[i + 1] = out_rowptr_data[i] + cols[i].size(); out_rowptr_data[i + 1] = out_rowptr_data[i] + cols[i].size();
} }
} }
int64_t n_len = n_ids.size(), e_len = e_ids.size(); int64_t N = n_ids.size();
auto n_id = torch::from_blob(n_ids.data(), {n_len}, col.options()).clone(); auto out_n_id = torch::from_blob(n_ids.data(), {N}, col.options()).clone();
auto e_id = torch::from_blob(e_ids.data(), {e_len}, col.options()).clone();
auto out_col = torch::empty(e_len, col.options()); int64_t E = out_rowptr_data[idx.numel()];
auto out_col = torch::empty(E, col.options());
auto out_col_data = out_col.data_ptr<int64_t>(); auto out_col_data = out_col.data_ptr<int64_t>();
auto out_e_id = torch::empty(E, col.options());
auto out_e_id_data = out_e_id.data_ptr<int64_t>();
i = 0; i = 0;
for (std::vector<int64_t> &col_vec : cols) { for (std::vector<std::tuple<int64_t, int64_t>> &col_vec : cols) {
std::sort(col_vec.begin(), col_vec.end()); std::sort(col_vec.begin(), col_vec.end(),
for (const int64_t &c : col_vec) { [](const std::tuple<int64_t, int64_t> &a,
out_col_data[i] = c; const std::tuple<int64_t, int64_t> &b) -> bool {
return std::get<0>(a) < std::get<0>(b);
});
for (const std::tuple<int64_t, int64_t> &value : col_vec) {
out_col_data[i] = std::get<0>(value);
out_e_id_data[i] = std::get<1>(value);
i += 1; i += 1;
} }
} }
return std::make_tuple(out_rowptr, out_col, n_id, e_id); return std::make_tuple(out_rowptr, out_col, out_n_id, out_e_id);
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment