"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "5e84353ebab5e0ce4fc762f64fabbdd9ac0c282a"
Commit 54b0a095 authored by rusty1s's avatar rusty1s
Browse files

update

parent 341f959a
...@@ -72,8 +72,8 @@ ego_k_hop_sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col, ...@@ -72,8 +72,8 @@ ego_k_hop_sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col,
} }
n_ids.clear(); n_ids.clear();
std::unordered_map<int64_t, int64_t> n_id_map; std::map<int64_t, int64_t> n_id_map;
std::unordered_map<int64_t, int64_t>::iterator iter; std::map<int64_t, int64_t>::iterator iter;
int64_t i = 0; int64_t i = 0;
for (int64_t v : n_id_set) { for (int64_t v : n_id_set) {
...@@ -108,12 +108,12 @@ ego_k_hop_sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col, ...@@ -108,12 +108,12 @@ ego_k_hop_sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col,
out_ptr_data[0] = 0; out_ptr_data[0] = 0;
int64_t node_cumsum = 0, edge_cumsum = 0; int64_t node_cumsum = 0, edge_cumsum = 0;
for (int64_t g = 0; g < idx.numel() - 1; g++) { for (int64_t g = 1; g < idx.numel(); g++) {
node_cumsum += out_n_ids[g].numel(); node_cumsum += out_n_ids[g - 1].numel();
edge_cumsum += out_cols[g].numel(); edge_cumsum += out_cols[g - 1].numel();
out_rowptrs[g + 2].add_(edge_cumsum); out_rowptrs[g + 1].add_(edge_cumsum);
out_cols[g + 1].add_(node_cumsum); out_cols[g].add_(node_cumsum);
out_ptr_data[g + 1] = node_cumsum; out_ptr_data[g] = node_cumsum;
} }
node_cumsum += out_n_ids[idx.numel() - 1].numel(); node_cumsum += out_n_ids[idx.numel() - 1].numel();
out_ptr_data[idx.numel()] = node_cumsum; out_ptr_data[idx.numel()] = node_cumsum;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment