"vscode:/vscode.git/clone" did not exist on "be99201a567c1ccd841dc16fb24e88f7f239c187"
labor_pick.h 11.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
/*!
 *   Copyright (c) 2022, NVIDIA Corporation
 *   Copyright (c) 2022, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
 *   All rights reserved.
 *
 *   Licensed under the Apache License, Version 2.0 (the "License");
 *   you may not use this file except in compliance with the License.
 *   You may obtain a copy of the License at
 *
 *       http://www.apache.org/licenses/LICENSE-2.0
 *
 *   Unless required by applicable law or agreed to in writing, software
 *   distributed under the License is distributed on an "AS IS" BASIS,
 *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *   See the License for the specific language governing permissions and
 *   limitations under the License.
 *
 * \file array/cpu/labor_pick.h
 * \brief Template implementation for layerwise pick operators.
 */

#ifndef DGL_ARRAY_CPU_LABOR_PICK_H_
#define DGL_ARRAY_CPU_LABOR_PICK_H_

#include <dgl/array.h>
#include <dgl/random.h>
#include <dgl/runtime/parallel_for.h>
#include <dmlc/omp.h>
#include <parallel_hashmap/phmap.h>

#include <algorithm>
#include <cmath>
#include <functional>
#include <iostream>
#include <memory>
#include <numeric>
#include <string>
#include <utility>
#include <vector>
#include <pcg_random.hpp>

namespace dgl {
namespace aten {
namespace impl {

constexpr double eps = 0.0001;

template <typename IdxType, typename FloatType>
auto compute_importance_sampling_probabilities(
    DGLContext ctx, DGLDataType dtype, const IdxType max_degree,
    const IdxType num_rows, const int importance_sampling, const bool weighted,
    const IdxType* rows_data, const IdxType* indptr, const FloatType* A,
    const IdxType* indices, const IdxType num_picks, const FloatType* ds,
    FloatType* cs) {
  constexpr FloatType ONE = 1;
  // ps stands for \pi in arXiv:2210.13339
  FloatArray ps_array = NDArray::Empty({max_degree + 1}, dtype, ctx);
  FloatType* ps = ps_array.Ptr<FloatType>();

  double prev_ex_nodes = max_degree * num_rows;

  phmap::flat_hash_map<IdxType, FloatType> hop_map, hop_map2;
  for (int iters = 0; iters < importance_sampling || importance_sampling < 0;
       iters++) {
    // NOTE(mfbalin) When the graph is unweighted, the first c values in
    // the first iteration can be computed in O(1) as k / d where k is fanout
    // and d is the degree.

    // If the graph is weighted, the first c values are computed in the inner
    // for loop instead. Therefore the importance_sampling argument should be
    // increased by one in the caller.

    // The later iterations will have correct c values so the if block will be
    // executed.

    if (!weighted || iters) {
      hop_map2.clear();
      for (int64_t i = 0; i < num_rows; ++i) {
        const FloatType c = cs[i];
        const IdxType rid = rows_data[i];
        for (auto j = indptr[rid]; j < indptr[rid + 1]; j++) {
          const auto ct = c * (weighted && iters == 1 ? A[j] : 1);
          auto itb = hop_map2.emplace(indices[j], ct);
          if (!itb.second) itb.first->second = std::max(ct, itb.first->second);
        }
      }
      if (hop_map.empty())
        hop_map = std::move(hop_map2);
      else
        // Update the pi array according to Eq 18.
        for (auto it : hop_map2) hop_map[it.first] *= it.second;
    }

    // Compute c_s according to Equation (15), (17) is slower because sorting is
    // required.
    for (int64_t i = 0; i < num_rows; ++i) {
      const IdxType rid = rows_data[i];
      const auto d = indptr[rid + 1] - indptr[rid];
      if (d == 0) continue;

      const auto k = std::min(num_picks, d);

      if (hop_map.empty()) {  // weighted first iter, pi = A
        for (auto j = indptr[rid]; j < indptr[rid + 1]; j++)
          ps[j - indptr[rid]] = A[j];
      } else {
        for (auto j = indptr[rid]; j < indptr[rid + 1]; j++)
          ps[j - indptr[rid]] = hop_map[indices[j]];
      }

      // stands for RHS of Equation (22) in arXiv:2210.13339 after moving the
      // other terms without c_s to RHS.
      double var_target = ds[i] * ds[i] / k;
      if (weighted) {
        var_target -= ds[i] * ds[i] / d;
        for (auto j = indptr[rid]; j < indptr[rid + 1]; j++)
          var_target += A[j] * A[j];
      }
      FloatType c = cs[i];
      // stands for left handside of Equation (22) in arXiv:2210.13339 after
      // moving the other terms without c_s to RHS.
      double var_1;
      // Compute c_s in Equation (22) via fixed-point iteration.
      do {
        var_1 = 0;
        if (weighted) {
          for (auto j = indptr[rid]; j < indptr[rid + 1]; j++)
            var_1 += A[j] * A[j] / std::min(ONE, c * ps[j - indptr[rid]]);
        } else {
          for (auto j = indptr[rid]; j < indptr[rid + 1]; j++)
            var_1 += ONE / std::min(ONE, c * ps[j - indptr[rid]]);
        }

        c *= var_1 / var_target;
      } while (std::min(var_1, var_target) / std::max(var_1, var_target) <
               1 - eps);

      cs[i] = c;
    }

    // Check convergence
    if (!weighted || iters) {
      double cur_ex_nodes = 0;
      for (auto it : hop_map) cur_ex_nodes += std::min((FloatType)1, it.second);
      if (cur_ex_nodes / prev_ex_nodes >= 1 - eps) break;
      prev_ex_nodes = cur_ex_nodes;
    }
  }

  return hop_map;
}

// Template for picking non-zero values row-wise.
template <typename IdxType, typename FloatType>
std::pair<COOMatrix, FloatArray> CSRLaborPick(
    CSRMatrix mat, IdArray rows, int64_t num_picks, FloatArray prob,
    int importance_sampling, IdArray random_seed_arr, IdArray NIDs) {
  using namespace aten;
  const IdxType* indptr = mat.indptr.Ptr<IdxType>();
  const IdxType* indices = mat.indices.Ptr<IdxType>();
  const IdxType* data = CSRHasData(mat) ? mat.data.Ptr<IdxType>() : nullptr;
  const IdxType* rows_data = rows.Ptr<IdxType>();
  const IdxType* nids = IsNullArray(NIDs) ? nullptr : NIDs.Ptr<IdxType>();
  const auto num_rows = rows->shape[0];
  const auto& ctx = mat.indptr->ctx;

  const bool weighted = !IsNullArray(prob);
  // O(1) c computation not possible, so one more iteration is needed.
  if (importance_sampling >= 0) importance_sampling += weighted;
  // A stands for the same notation in arXiv:2210.13339, i.e. the edge weights.
  auto A_arr = prob;
  FloatType* A = A_arr.Ptr<FloatType>();
  constexpr FloatType ONE = 1;

  constexpr auto dtype = DGLDataTypeTraits<FloatType>::dtype;

  // cs stands for c_s in arXiv:2210.13339
  FloatArray cs_array = NDArray::Empty({num_rows}, dtype, ctx);
  FloatType* cs = cs_array.Ptr<FloatType>();
  // ds stands for A_{*s} in arXiv:2210.13339
  FloatArray ds_array = NDArray::Empty({num_rows}, dtype, ctx);
  FloatType* ds = ds_array.Ptr<FloatType>();

  IdxType max_degree = 1;
  IdxType hop_size = 0;
  for (int64_t i = 0; i < num_rows; ++i) {
    const IdxType rid = rows_data[i];
    const auto act_degree = indptr[rid + 1] - indptr[rid];
    max_degree = std::max(act_degree, max_degree);
    double d = weighted
                   ? std::accumulate(A + indptr[rid], A + indptr[rid + 1], 0.0)
                   : act_degree;
    // O(1) c computation, samples more than needed for weighted case, mentioned
    // in the sentence between (10) and (11) in arXiv:2210.13339
    cs[i] = num_picks / d;
    ds[i] = d;
    hop_size += act_degree;
  }

  phmap::flat_hash_map<IdxType, FloatType> hop_map;

  if (importance_sampling)
    hop_map = compute_importance_sampling_probabilities<IdxType, FloatType>(
        ctx, dtype, max_degree, num_rows, importance_sampling, weighted,
        rows_data, indptr, A, indices, (IdxType)num_picks, ds, cs);

  constexpr auto vidtype = DGLDataTypeTraits<IdxType>::dtype;

  IdArray picked_row = NDArray::Empty({hop_size}, vidtype, ctx);
  IdArray picked_col = NDArray::Empty({hop_size}, vidtype, ctx);
  IdArray picked_idx = NDArray::Empty({hop_size}, vidtype, ctx);
  FloatArray picked_imp =
      NDArray::Empty({importance_sampling ? hop_size : 0}, dtype, ctx);
  IdxType* picked_rdata = picked_row.Ptr<IdxType>();
  IdxType* picked_cdata = picked_col.Ptr<IdxType>();
  IdxType* picked_idata = picked_idx.Ptr<IdxType>();
  FloatType* picked_imp_data = picked_imp.Ptr<FloatType>();

  const uint64_t random_seed =
      IsNullArray(random_seed_arr)
          ? RandomEngine::ThreadLocal()->RandInt(1000000000)
          : random_seed_arr.Ptr<int64_t>()[0];

  // compute number of edges first and do sampling
  IdxType num_edges = 0;
  for (int64_t i = 0; i < num_rows; i++) {
    const IdxType rid = rows_data[i];
    const auto c = cs[i];

    FloatType norm_inv_p = 0;
    const auto off = num_edges;
    for (auto j = indptr[rid]; j < indptr[rid + 1]; j++) {
      const auto v = indices[j];
      const auto t = nids ? nids[v] : v;  // t in the paper
      pcg32 ng(random_seed, t);
      std::uniform_real_distribution<FloatType> uni;
      // rolled random number r_t is a function of the random_seed and t
      const auto rnd = uni(ng);
      const auto w = (weighted ? A[j] : 1);
      // if hop_map is initialized, get ps from there, otherwise get it from the
      // alternative.
      const auto ps = std::min(
          ONE, importance_sampling - weighted ? c * hop_map[v] : c * w);
      if (rnd <= ps) {
        picked_rdata[num_edges] = rid;
        picked_cdata[num_edges] = v;
        picked_idata[num_edges] = data ? data[j] : j;
        if (importance_sampling) {
          const auto edge_weight = w / ps;
          norm_inv_p += edge_weight;
          picked_imp_data[num_edges] = edge_weight;
        }
        num_edges++;
      }
    }

    if (importance_sampling) {
      const auto norm_factor = (num_edges - off) / norm_inv_p;
      for (auto i = off; i < num_edges; i++)
        // so that fn.mean can be used
        picked_imp_data[i] *= norm_factor;
    }
  }

  picked_row = picked_row.CreateView({num_edges}, picked_row->dtype);
  picked_col = picked_col.CreateView({num_edges}, picked_col->dtype);
  picked_idx = picked_idx.CreateView({num_edges}, picked_idx->dtype);
  if (importance_sampling)
    picked_imp = picked_imp.CreateView({num_edges}, picked_imp->dtype);

  return std::make_pair(
      COOMatrix(mat.num_rows, mat.num_cols, picked_row, picked_col, picked_idx),
      picked_imp);
}

// Template for picking non-zero values row-wise. The implementation first
// slices out the corresponding rows and then converts it to CSR format. It then
// performs row-wise pick on the CSR matrix and rectifies the returned results.
template <typename IdxType, typename FloatType>
std::pair<COOMatrix, FloatArray> COOLaborPick(
    COOMatrix mat, IdArray rows, int64_t num_picks, FloatArray prob,
    int importance_sampling, IdArray random_seed, IdArray NIDs) {
  using namespace aten;
  const auto& csr = COOToCSR(COOSliceRows(mat, rows));
  const IdArray new_rows =
      Range(0, rows->shape[0], rows->dtype.bits, rows->ctx);
  const auto&& picked_importances = CSRLaborPick<IdxType, FloatType>(
      csr, new_rows, num_picks, prob, importance_sampling, random_seed, NIDs);
  const auto& picked = picked_importances.first;
  const auto& importances = picked_importances.second;
  return std::make_pair(
      COOMatrix(
          mat.num_rows, mat.num_cols,
          IndexSelect(
              rows, picked.row),  // map the row index to the correct one
          picked.col, picked.data),
      importances);
}

}  // namespace impl
}  // namespace aten
}  // namespace dgl

#endif  // DGL_ARRAY_CPU_LABOR_PICK_H_