ego_sample_cpu.cpp 4.34 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
#include "ego_sample_cpu.h"

#include <ATen/Parallel.h>

#include "utils.h"

rusty1s's avatar
rusty1s committed
7
8
9
10
#ifdef _WIN32
#include <process.h>
#endif

rusty1s's avatar
rusty1s committed
11
12
13
14
// Copies the contents of `vec` into a freshly allocated 1-D int64 tensor.
// `from_blob` merely wraps the vector's buffer, so the wrapped view is cloned
// before the by-value `vec` is destroyed at the end of this function.
inline torch::Tensor vec2tensor(std::vector<int64_t> vec) {
  const auto length = static_cast<int64_t>(vec.size());
  auto borrowed = torch::from_blob(vec.data(), {length}, at::kLong);
  return borrowed.clone();
}

rusty1s's avatar
rusty1s committed
15
// Returns `rowptr`, `col`, `n_id`, `e_id`, `ptr`, `root_n_id`
rusty1s's avatar
rusty1s committed
16
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
rusty1s's avatar
rusty1s committed
17
           torch::Tensor, torch::Tensor>
rusty1s's avatar
rusty1s committed
18
19
20
21
22
23
24
25
ego_k_hop_sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col,
                         torch::Tensor idx, int64_t depth,
                         int64_t num_neighbors, bool replace) {

  std::vector<torch::Tensor> out_rowptrs(idx.numel() + 1);
  std::vector<torch::Tensor> out_cols(idx.numel());
  std::vector<torch::Tensor> out_n_ids(idx.numel());
  std::vector<torch::Tensor> out_e_ids(idx.numel());
rusty1s's avatar
rusty1s committed
26
  auto out_root_n_id = torch::empty({idx.numel()}, at::kLong);
rusty1s's avatar
rusty1s committed
27
28
29
30
31
  out_rowptrs[0] = torch::zeros({1}, at::kLong);

  auto rowptr_data = rowptr.data_ptr<int64_t>();
  auto col_data = col.data_ptr<int64_t>();
  auto idx_data = idx.data_ptr<int64_t>();
rusty1s's avatar
rusty1s committed
32
  auto out_root_n_id_data = out_root_n_id.data_ptr<int64_t>();
rusty1s's avatar
rusty1s committed
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56

  at::parallel_for(0, idx.numel(), 1, [&](int64_t begin, int64_t end) {
    int64_t row_start, row_end, row_count, vec_start, vec_end, v, w;
    for (int64_t g = begin; g < end; g++) {
      std::set<int64_t> n_id_set;
      n_id_set.insert(idx_data[g]);
      std::vector<int64_t> n_ids;
      n_ids.push_back(idx_data[g]);

      vec_start = 0, vec_end = n_ids.size();
      for (int64_t d = 0; d < depth; d++) {
        for (int64_t i = vec_start; i < vec_end; i++) {
          v = n_ids[i];
          row_start = rowptr_data[v], row_end = rowptr_data[v + 1];
          row_count = row_end - row_start;

          if (row_count <= num_neighbors) {
            for (int64_t e = row_start; e < row_end; e++) {
              w = col_data[e];
              n_id_set.insert(w);
              n_ids.push_back(w);
            }
          } else if (replace) {
            for (int64_t j = 0; j < num_neighbors; j++) {
57
              w = col_data[row_start + uniform_randint(row_count)];
rusty1s's avatar
rusty1s committed
58
59
60
61
62
63
              n_id_set.insert(w);
              n_ids.push_back(w);
            }
          } else {
            std::unordered_set<int64_t> perm;
            for (int64_t j = row_count - num_neighbors; j < row_count; j++) {
64
              if (!perm.insert(uniform_randint(j)).second) {
rusty1s's avatar
rusty1s committed
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
                perm.insert(j);
              }
            }
            for (int64_t j : perm) {
              w = col_data[row_start + j];
              n_id_set.insert(w);
              n_ids.push_back(w);
            }
          }
        }
        vec_start = vec_end;
        vec_end = n_ids.size();
      }

      n_ids.clear();
rusty1s's avatar
update  
rusty1s committed
80
81
      std::map<int64_t, int64_t> n_id_map;
      std::map<int64_t, int64_t>::iterator iter;
rusty1s's avatar
rusty1s committed
82
83
84
85
86
87
88
89

      int64_t i = 0;
      for (int64_t v : n_id_set) {
        n_ids.push_back(v);
        n_id_map[v] = i;
        i++;
      }

rusty1s's avatar
rusty1s committed
90
91
      out_root_n_id_data[g] = n_id_map[idx_data[g]];

rusty1s's avatar
rusty1s committed
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
      std::vector<int64_t> rowptrs, cols, e_ids;
      for (int64_t v : n_ids) {
        row_start = rowptr_data[v], row_end = rowptr_data[v + 1];
        for (int64_t e = row_start; e < row_end; e++) {
          w = col_data[e];
          iter = n_id_map.find(w);
          if (iter != n_id_map.end()) {
            cols.push_back(iter->second);
            e_ids.push_back(e);
          }
        }
        rowptrs.push_back(cols.size());
      }

      out_rowptrs[g + 1] = vec2tensor(rowptrs);
      out_cols[g] = vec2tensor(cols);
      out_n_ids[g] = vec2tensor(n_ids);
      out_e_ids[g] = vec2tensor(e_ids);
    }
  });

  auto out_ptr = torch::empty({idx.numel() + 1}, at::kLong);
  auto out_ptr_data = out_ptr.data_ptr<int64_t>();
  out_ptr_data[0] = 0;

  int64_t node_cumsum = 0, edge_cumsum = 0;
rusty1s's avatar
update  
rusty1s committed
118
119
120
121
122
123
  for (int64_t g = 1; g < idx.numel(); g++) {
    node_cumsum += out_n_ids[g - 1].numel();
    edge_cumsum += out_cols[g - 1].numel();
    out_rowptrs[g + 1].add_(edge_cumsum);
    out_cols[g].add_(node_cumsum);
    out_ptr_data[g] = node_cumsum;
rusty1s's avatar
rusty1s committed
124
    out_root_n_id_data[g] += node_cumsum;
rusty1s's avatar
rusty1s committed
125
126
127
128
129
130
  }
  node_cumsum += out_n_ids[idx.numel() - 1].numel();
  out_ptr_data[idx.numel()] = node_cumsum;

  return std::make_tuple(torch::cat(out_rowptrs, 0), torch::cat(out_cols, 0),
                         torch::cat(out_n_ids, 0), torch::cat(out_e_ids, 0),
rusty1s's avatar
rusty1s committed
131
                         out_ptr, out_root_n_id);
rusty1s's avatar
rusty1s committed
132
}