libra_partition.cc 24.2 KB
Newer Older
1
/**
Copyright (c) 2021 Intel Corporation
 \file distgnn/partition/main_Libra.py
 \brief Libra - Vertex-cut based graph partitioner for distributed training
 \author Vasimuddin Md <vasimuddin.md@intel.com>,
         Guixiang Ma <guixiang.ma@intel.com>
         Sanchit Misra <sanchit.misra@intel.com>,
         Ramanarayan Mohanty <ramanarayan.mohanty@intel.com>,
         Sasikanth Avancha <sasikanth.avancha@intel.com>
         Nesreen K. Ahmed <nesreen.k.ahmed@intel.com>
*/

13
14
#include <dgl/base_heterograph.h>
#include <dgl/packed_func_ext.h>
15
#include <dgl/random.h>
16
#include <dgl/runtime/parallel_for.h>
17
#include <dmlc/omp.h>
18
19
#include <stdint.h>

20
21
22
23
24
25
26
27
#include <vector>

#ifdef USE_TVM
#include <featgraph.h>
#endif  // USE_TVM

#include "../c_api_common.h"
#include "./check.h"
28
#include "kernel_decl.h"
29
30
31
32
33
34

using namespace dgl::runtime;

namespace dgl {
namespace aten {

35
36
/**
 * @brief Finds the partition whose ID range contains a given node ID.
 *
 * node_map is scanned from the front and the index of the first entry that is
 * strictly greater than in_val is returned, i.e. node_map[p] acts as the
 * exclusive upper bound of partition p.
 *
 * @param[in] in_val node ID to look up
 * @param[in] node_map per-partition exclusive upper bounds; at least
 *            num_parts entries are read
 * @param[in] num_parts number of entries of node_map to examine
 * @return index of the matching partition; if no entry is greater than
 *         in_val a fatal error is logged (-1 is returned only if the logger
 *         returns control).
 */
template <typename IdType>
int32_t Ver2partition(IdType in_val, int64_t *node_map, int32_t num_parts) {
  for (int32_t p = 0; p < num_parts; p++) {
    if (in_val < node_map[p]) return p;
  }
  LOG(FATAL) << "Error: Unexpected output in Ver2partition!";
  // Keeps every control path returning a value even if the fatal log call
  // comes back to the caller.
  return -1;
}

45
/**
46
 * @brief Identifies the lead loaded partition/community for a given edge
47
48
49
 * assignment.
 */
int32_t LeastLoad(int64_t *community_edges, int32_t nc) {
50
51
  std::vector<int> loc;
  int32_t min = 1e9;
52
  for (int32_t i = 0; i < nc; i++) {
53
54
55
56
    if (community_edges[i] < min) {
      min = community_edges[i];
    }
  }
57
  for (int32_t i = 0; i < nc; i++) {
58
59
60
61
62
63
64
65
66
67
    if (community_edges[i] == min) {
      loc.push_back(i);
    }
  }

  int32_t r = RandomEngine::ThreadLocal()->RandInt(loc.size());
  CHECK(loc[r] < nc);
  return loc[r];
}

68
/**
69
 * @brief Libra - vertexcut based graph partitioning.
70
71
72
73
74
 * It takes list of edges from input DGL graph and distributed them among nc
 * partitions During edge distribution, Libra assign a given edge to a partition
 * based on the end vertices, in doing so, it tries to minimized the splitting
 * of the graph vertices. In case of conflict Libra assigns an edge to the least
 * loaded partition/community.
75
76
77
78
79
80
81
82
83
84
85
 * @param[in] nc Number of partitions/communities
 * @param[in] node_degree per node degree
 * @param[in] edgenum_unassigned node degree
 * @param[out] community_weights weight of the created partitions
 * @param[in] u src nodes
 * @param[in] v dst nodes
 * @param[out] w weight per edge
 * @param[out] out partition assignment of the edges
 * @param[in] N_n number of nodes in the input graph
 * @param[in] N_e number of edges in the input graph
 * @param[in] prefix output/partition storage location
86
87
 */
template <typename IdType, typename IdType2>
88
void LibraVertexCut(
89
90
91
92
93
    int32_t nc, NDArray node_degree, NDArray edgenum_unassigned,
    NDArray community_weights, NDArray u, NDArray v, NDArray w, NDArray out,
    int64_t N_n, int64_t N_e, const std::string &prefix) {
  int32_t *out_ptr = out.Ptr<int32_t>();
  IdType2 *node_degree_ptr = node_degree.Ptr<IdType2>();
94
  IdType2 *edgenum_unassigned_ptr = edgenum_unassigned.Ptr<IdType2>();
95
96
97
98
  IdType *u_ptr = u.Ptr<IdType>();
  IdType *v_ptr = v.Ptr<IdType>();
  int64_t *w_ptr = w.Ptr<int64_t>();
  int64_t *community_weights_ptr = community_weights.Ptr<int64_t>();
99
100
101
102
103
104
105

  std::vector<std::vector<int32_t> > node_assignments(N_n);
  std::vector<IdType2> replication_list;
  // local allocations
  int64_t *community_edges = new int64_t[nc]();
  int64_t *cache = new int64_t[nc]();

106
107
108
109
110
  int64_t meter = static_cast<int>(N_e / 100);
  for (int64_t i = 0; i < N_e; i++) {
    IdType u = u_ptr[i];   // edge end vertex 1
    IdType v = v_ptr[i];   // edge end vertex 2
    int64_t w = w_ptr[i];  // edge weight
111
112
113
114
115

    CHECK(u < N_n);
    CHECK(v < N_n);

    if (i % meter == 0) {
116
117
      fprintf(stderr, ".");
      fflush(0);
118
119
120
121
122
123
124
125
126
127
    }

    if (node_assignments[u].size() == 0 && node_assignments[v].size() == 0) {
      int32_t c = LeastLoad(community_edges, nc);
      out_ptr[i] = c;
      CHECK_LT(c, nc);

      community_edges[c]++;
      community_weights_ptr[c] = community_weights_ptr[c] + w;
      node_assignments[u].push_back(c);
128
      if (u != v) node_assignments[v].push_back(c);
129

130
131
132
133
      CHECK(node_assignments[u].size() <= static_cast<size_t>(nc))
          << "[bug] 1. generated splits (u) are greater than nc!";
      CHECK(node_assignments[v].size() <= static_cast<size_t>(nc))
          << "[bug] 1. generated splits (v) are greater than nc!";
134
135
      edgenum_unassigned_ptr[u]--;
      edgenum_unassigned_ptr[v]--;
136
137
138
    } else if (
        node_assignments[u].size() != 0 && node_assignments[v].size() == 0) {
      for (uint32_t j = 0; j < node_assignments[u].size(); j++) {
139
140
141
142
143
144
145
146
147
148
        int32_t cind = node_assignments[u][j];
        cache[j] = community_edges[cind];
      }
      int32_t cindex = LeastLoad(cache, node_assignments[u].size());
      int32_t c = node_assignments[u][cindex];
      out_ptr[i] = c;
      community_edges[c]++;
      community_weights_ptr[c] = community_weights_ptr[c] + w;

      node_assignments[v].push_back(c);
149
150
      CHECK(node_assignments[v].size() <= static_cast<size_t>(nc))
          << "[bug] 2. generated splits (v) are greater than nc!";
151
152
      edgenum_unassigned_ptr[u]--;
      edgenum_unassigned_ptr[v]--;
153
154
155
    } else if (
        node_assignments[v].size() != 0 && node_assignments[u].size() == 0) {
      for (uint32_t j = 0; j < node_assignments[v].size(); j++) {
156
157
158
159
160
161
162
163
164
165
166
167
        int32_t cind = node_assignments[v][j];
        cache[j] = community_edges[cind];
      }
      int32_t cindex = LeastLoad(cache, node_assignments[v].size());
      int32_t c = node_assignments[v][cindex];
      CHECK(c < nc) << "[bug] 2. partition greater than nc !!";
      out_ptr[i] = c;

      community_edges[c]++;
      community_weights_ptr[c] = community_weights_ptr[c] + w;

      node_assignments[u].push_back(c);
168
169
      CHECK(node_assignments[u].size() <= static_cast<size_t>(nc))
          << "[bug] 3. generated splits (u) are greater than nc!";
170
171
172
173
      edgenum_unassigned_ptr[u]--;
      edgenum_unassigned_ptr[v]--;
    } else {
      std::vector<int> setv(nc), intersetv;
174
      for (int32_t j = 0; j < nc; j++) setv[j] = 0;
175
176
      int32_t interset = 0;

177
178
179
180
181
182
183
      CHECK(node_assignments[u].size() <= static_cast<size_t>(nc))
          << "[bug] 4. generated splits (u) are greater than nc!";
      CHECK(node_assignments[v].size() <= static_cast<size_t>(nc))
          << "[bug] 4. generated splits (v) are greater than nc!";
      for (size_t j = 0; j < node_assignments[v].size(); j++) {
        CHECK(node_assignments[v][j] < nc)
            << "[bug] 4. Part assigned (v) greater than nc!";
184
185
186
        setv[node_assignments[v][j]]++;
      }

187
188
189
      for (size_t j = 0; j < node_assignments[u].size(); j++) {
        CHECK(node_assignments[u][j] < nc)
            << "[bug] 4. Part assigned (u) greater than nc!";
190
191
192
        setv[node_assignments[u][j]]++;
      }

193
      for (int32_t j = 0; j < nc; j++) {
194
195
196
197
198
199
200
        CHECK(setv[j] <= 2) << "[bug] 4. unexpected computed value !!!";
        if (setv[j] == 2) {
          interset++;
          intersetv.push_back(j);
        }
      }
      if (interset) {
201
        for (size_t j = 0; j < intersetv.size(); j++) {
202
203
204
205
206
207
208
209
210
211
212
213
214
          int32_t cind = intersetv[j];
          cache[j] = community_edges[cind];
        }
        int32_t cindex = LeastLoad(cache, intersetv.size());
        int32_t c = intersetv[cindex];
        CHECK(c < nc) << "[bug] 4. partition greater than nc !!";
        out_ptr[i] = c;
        community_edges[c]++;
        community_weights_ptr[c] = community_weights_ptr[c] + w;
        edgenum_unassigned_ptr[u]--;
        edgenum_unassigned_ptr[v]--;
      } else {
        if (node_degree_ptr[u] < node_degree_ptr[v]) {
215
          for (uint32_t j = 0; j < node_assignments[u].size(); j++) {
216
217
218
219
220
221
222
223
224
225
            int32_t cind = node_assignments[u][j];
            cache[j] = community_edges[cind];
          }
          int32_t cindex = LeastLoad(cache, node_assignments[u].size());
          int32_t c = node_assignments[u][cindex];
          CHECK(c < nc) << "[bug] 5. partition greater than nc !!";
          out_ptr[i] = c;
          community_edges[c]++;
          community_weights_ptr[c] = community_weights_ptr[c] + w;

226
227
228
          for (uint32_t j = 0; j < node_assignments[v].size(); j++) {
            CHECK(node_assignments[v][j] != c)
                << "[bug] 5. duplicate partition (v) assignment !!";
229
230
231
          }

          node_assignments[v].push_back(c);
232
233
          CHECK(node_assignments[v].size() <= static_cast<size_t>(nc))
              << "[bug] 5. generated splits (v) greater than nc!!";
234
235
236
237
          replication_list.push_back(v);
          edgenum_unassigned_ptr[u]--;
          edgenum_unassigned_ptr[v]--;
        } else {
238
          for (uint32_t j = 0; j < node_assignments[v].size(); j++) {
239
240
241
242
243
            int32_t cind = node_assignments[v][j];
            cache[j] = community_edges[cind];
          }
          int32_t cindex = LeastLoad(cache, node_assignments[v].size());
          int32_t c = node_assignments[v][cindex];
244
          CHECK(c < nc) << "[bug] 6. partition greater than nc !!";
245
246
247
          out_ptr[i] = c;
          community_edges[c]++;
          community_weights_ptr[c] = community_weights_ptr[c] + w;
248
249
250
          for (uint32_t j = 0; j < node_assignments[u].size(); j++) {
            CHECK(node_assignments[u][j] != c)
                << "[bug] 6. duplicate partition (u) assignment !!";
251
          }
252
          if (u != v) node_assignments[u].push_back(c);
253

254
255
          CHECK(node_assignments[u].size() <= static_cast<size_t>(nc))
              << "[bug] 6. generated splits (u) greater than nc!!";
256
257
258
259
260
261
262
263
264
          replication_list.push_back(u);
          edgenum_unassigned_ptr[u]--;
          edgenum_unassigned_ptr[v]--;
        }
      }
    }
  }
  delete cache;

265
266
  for (int64_t c = 0; c < nc; c++) {
    std::string path = prefix + "/community" + std::to_string(c) + ".txt";
267
268

    FILE *fp = fopen(path.c_str(), "w");
269
270
    CHECK_NE(fp, static_cast<FILE *>(NULL))
        << "Error: can not open file: " << path.c_str();
271

272
    for (int64_t i = 0; i < N_e; i++) {
273
      if (out_ptr[i] == c)
274
275
        fprintf(
            fp, "%ld,%ld,%ld\n", static_cast<int64_t>(u_ptr[i]),
276
            static_cast<int64_t>(v_ptr[i]), w_ptr[i]);
277
278
279
280
281
282
    }
    fclose(fp);
  }

  std::string path = prefix + "/replicationlist.csv";
  FILE *fp = fopen(path.c_str(), "w");
283
284
  CHECK_NE(fp, static_cast<FILE *>(NULL))
      << "Error: can not open file: " << path.c_str();
285
286
287
288

  fprintf(fp, "## The Indices of Nodes that are replicated :: Header");
  printf("\nTotal replication: %ld\n", replication_list.size());

289
  for (uint64_t i = 0; i < replication_list.size(); i++)
290
    fprintf(fp, "%ld\n", static_cast<int64_t>(replication_list[i]));
291
292

  printf("Community weights:\n");
293
  for (int64_t c = 0; c < nc; c++) printf("%ld ", community_weights_ptr[c]);
294
295
296
  printf("\n");

  printf("Community edges:\n");
297
  for (int64_t c = 0; c < nc; c++) printf("%ld ", community_edges[c]);
298
299
300
301
302
303
304
  printf("\n");

  delete community_edges;
  fclose(fp);
}

// Registers the packed function "sparse._CAPI_DGLLibraVertexCut": unpacks the
// eleven positional arguments and forwards them to LibraVertexCut, with the
// template types selected from the dtypes of the node-degree array (IdType2)
// and the source-node array (IdType).
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLLibraVertexCut")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      // Arguments in call order.
      int32_t num_parts = args[0];
      NDArray degrees = args[1];
      NDArray unassigned = args[2];
      NDArray part_weights = args[3];
      NDArray src = args[4];
      NDArray dst = args[5];
      NDArray edge_weights = args[6];
      NDArray assignment = args[7];
      int64_t num_nodes = args[8];
      int64_t num_edges = args[9];
      std::string out_dir = args[10];

      ATEN_ID_TYPE_SWITCH(degrees->dtype, IdType2, {
        ATEN_ID_TYPE_SWITCH(src->dtype, IdType, {
          LibraVertexCut<IdType, IdType2>(
              num_parts, degrees, unassigned, part_weights, src, dst,
              edge_weights, assignment, num_nodes, num_edges, out_dir);
        });
      });
    });
326

327
/**
 * @brief
 * 1. Builds dictionary (ldt) for assigning local node IDs to nodes in the
 *    partitions
 * 2. Builds dictionary (gdt) for storing copies (local ID) of split nodes
 *    These dictionaries will be used in the subsequent stages to setup
 *    tracking of split nodes copies across the partition, setting up partition
 *    `ndata` dictionaries.
 *
 * Edges are read from <prefix>/community<c>.txt, one "u,v,w" record per line;
 * only the two node IDs are used. Reading stops at end of file or after fsize
 * edges, whichever comes first. A record that cannot be parsed, or a node ID
 * outside [0, indices->shape[0]), is a fatal error.
 *
 * @param[out] a local src node ID of an edge in a partition
 * @param[out] b local dst node ID of an edge in a partition
 * @param[-] indices temporary memory, keeps track of global node ID to local
 *           node ID in a partition; fully overwritten here
 * @param[out] ldt_key per partition dict for storing global and local node IDs
 *             (consecutive): ldt_key[local ID] = global ID
 * @param[in, out] gdt_key global dict for storing number of local nodes (or
 *                 split nodes) for a given global node ID; incremented for
 *                 every node found in this partition
 * @param[in, out] gdt_value global dict, stores local node IDs (due to split)
 *                 across partitions for a given global node ID; nc slots per
 *                 global node
 * @param[out] node_map keeps track of range of local node IDs (consecutive)
 *             given to the nodes in the partitions; entry c receives the end
 *             of this partition's range
 * @param[in, out] offset start of the range of local node IDs for this
 *                 partition; advanced by the number of nodes found
 * @param[in] nc number of partitions/communities
 * @param[in] c current partition number
 * @param[in] fsize size of pre-allocated memory tensor (capacity of a and b,
 *            in edges)
 * @param[in] prefix input Libra partition file location
 * @return list of two values: number of nodes and number of edges in this
 *         partition
 */
List<Value> Libra2dglBuildDict(
    NDArray a, NDArray b, NDArray indices, NDArray ldt_key, NDArray gdt_key,
    NDArray gdt_value, NDArray node_map, NDArray offset, int32_t nc, int32_t c,
    int64_t fsize, const std::string &prefix) {
  int64_t *indices_ptr = indices.Ptr<int64_t>();      // 1D temp array
  int64_t *ldt_key_ptr = ldt_key.Ptr<int64_t>();      // 1D local -> global
  int64_t *gdt_key_ptr = gdt_key.Ptr<int64_t>();      // 1D #copies per node
  int64_t *gdt_value_ptr = gdt_value.Ptr<int64_t>();  // 2D tensor
  int64_t *node_map_ptr = node_map.Ptr<int64_t>();    // 1D tensor
  int64_t *offset_ptr = offset.Ptr<int64_t>();        // 1D tensor
  int64_t *a_ptr = a.Ptr<int64_t>();  // stores local src and dst node ID,
  int64_t *b_ptr = b.Ptr<int64_t>();  // to create the partition graph

  const int32_t width = nc;
  const int64_t N_n = indices->shape[0];
  const int64_t num_nodes = ldt_key->shape[0];

  // Marker for "this global node has no local ID in this partition yet".
  constexpr int64_t kUnassigned = -100;
  for (int64_t i = 0; i < N_n; i++) {
    indices_ptr[i] = kUnassigned;
  }

  int64_t pos = 0;   // next free local node ID == #nodes seen so far
  int64_t edge = 0;  // #edges read so far
  std::string path = prefix + "/community" + std::to_string(c) + ".txt";
  FILE *fp = fopen(path.c_str(), "r");
  CHECK_NE(fp, static_cast<FILE *>(NULL))
      << "Error: can not open file: " << path.c_str();

  // Returns the local ID of global node `gid`, handing out the next
  // consecutive local ID the first time the node is seen.
  auto local_id = [&](int64_t gid) -> int64_t {
    CHECK(gid >= 0 && gid < N_n)
        << "Error: node ID out of range in file: " << path.c_str();
    if (indices_ptr[gid] == kUnassigned) {
      CHECK(pos < num_nodes);  // validate before writing, not after
      ldt_key_ptr[pos] = gid;
      indices_ptr[gid] = pos++;
    }
    return indices_ptr[gid];
  };

  while (edge < fsize) {
    int64_t u, v;
    float w;
    // reading an edge - the src and dst global node IDs. The conversion
    // count is checked so that u and v are never used uninitialised (e.g.
    // for an empty partition file).
    const int nread = fscanf(fp, "%ld,%ld,%f\n", &u, &v, &w);
    if (nread == EOF) break;
    CHECK_EQ(nread, 3) << "Error: malformed edge record in file: "
                       << path.c_str();

    a_ptr[edge] = local_id(u);  // new local ID for an edge
    b_ptr[edge] = local_id(v);  // new local ID for an edge
    edge++;
  }
  CHECK(edge <= fsize)
      << "[Bug] memory allocated for #edges per partition is not enough.";
  fclose(fp);

  List<Value> ret;
  // returns total number of nodes in this partition
  ret.push_back(Value(MakeValue(pos)));
  // returns total number of edges in this partition
  ret.push_back(Value(MakeValue(edge)));

  for (int64_t i = 0; i < pos; i++) {
    const int64_t gid = ldt_key_ptr[i];  // global node ID; i is the local ID
    // global dict, total number of local node IDs (an offset) as of now for a
    // given global node ID
    int64_t &ncopies = gdt_key_ptr[gid];
    // Each global node owns `width` slots; check before writing the slot.
    CHECK(ncopies >= 0 && ncopies < nc);
    // stores a local node ID for the global node ID
    gdt_value_ptr[gid * width + ncopies] = offset_ptr[0] + i;
    ncopies++;
  }

  // since local node IDs for a partition are consecutive,
  // we maintain the range of local node IDs like this
  node_map_ptr[c] = offset_ptr[0] + pos;
  offset_ptr[0] += pos;

  return ret;
}

// Registers the packed function "sparse._CAPI_DGLLibra2dglBuildDict": unpacks
// the twelve positional arguments, calls Libra2dglBuildDict and hands its
// result list back through the return value.
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLLibra2dglBuildDict")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      // Arguments in call order.
      NDArray local_src = args[0];
      NDArray local_dst = args[1];
      NDArray scratch = args[2];
      NDArray local_dict = args[3];
      NDArray global_counts = args[4];
      NDArray global_ids = args[5];
      NDArray ranges = args[6];
      NDArray range_start = args[7];
      int32_t num_parts = args[8];
      int32_t part = args[9];
      int64_t capacity = args[10];
      std::string in_dir = args[11];

      List<Value> counts = Libra2dglBuildDict(
          local_src, local_dst, scratch, local_dict, global_counts,
          global_ids, ranges, range_start, num_parts, part, capacity, in_dir);
      *rv = counts;
    });

459
/**
 * @brief sets up the 1-level tree among the clones of the split-nodes.
 *
 * For every global node that has at least one copy (gdt_key[i] > 0), one of
 * its copies is selected with RandomEngine::RandInt and stored as the root in
 * lrtensor[i]. Entries of nodes without copies are left untouched.
 *
 * @param[in] gdt_key global dict: number of copies recorded per global node
 *            ID
 * @param[in] gdt_value global dict: the IDs of those copies, nc slots per
 *            global node ID
 * @param[out] lrtensor keeps the root node ID of 1-level tree
 * @param[in] nc number of partitions/communities
 * @param[in] Nn number of nodes in the input graph
 */
void Libra2dglSetLR(
    NDArray gdt_key, NDArray gdt_value, NDArray lrtensor, int32_t nc,
    int64_t Nn) {
  int64_t *gdt_key_ptr = gdt_key.Ptr<int64_t>();      // 1D tensor
  int64_t *gdt_value_ptr = gdt_value.Ptr<int64_t>();  // 2D tensor
  int64_t *lrtensor_ptr = lrtensor.Ptr<int64_t>();    // 1D tensor

  const int32_t width = nc;

  for (int64_t i = 0; i < Nn; i++) {
    const int64_t ncopies = gdt_key_ptr[i];
    if (ncopies <= 0) continue;  // node does not appear in any partition

    // Validate the copy count before it bounds the random index.
    CHECK(ncopies <= nc);
    int32_t val = RandomEngine::ThreadLocal()->RandInt(ncopies);
    CHECK(val >= 0 && val < ncopies);

    lrtensor_ptr[i] = gdt_value_ptr[i * width + val];
  }
}

// Registers the packed function "sparse._CAPI_DGLLibra2dglSetLR": unpacks the
// five positional arguments and forwards them to Libra2dglSetLR. Nothing is
// written to the return value.
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLLibra2dglSetLR")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      // Arguments in call order.
      NDArray copy_counts = args[0];
      NDArray copy_ids = args[1];
      NDArray roots = args[2];
      int32_t num_parts = args[3];
      int64_t num_nodes = args[4];

      Libra2dglSetLR(copy_counts, copy_ids, roots, num_parts, num_nodes);
    });
508

509
/**
510
 * @brief For each node in a partition, it creates a list of remote clone IDs;
511
512
 *        also, for each node in a partition, it gathers the data (feats, label,
 *        trian, test) from input graph.
513
514
515
516
517
518
 * @param[out] feat node features in current partition c.
 * @param[in] gfeat input graph node features.
 * @param[out] adj list of node IDs of remote clones.
 * @param[out] inner_nodes marks whether a node is split or not.
 * @param[in] ldt_key per partition dict for tracking global to local node IDs
 * @param[out] gdt_key global dict for storing number of local nodes (or split
519
520
521
 *             nodes) for a given global node ID \param[out] gdt_value global
 *             dict, stores local node IDs (due to split) across partitions for
 *             a given global node ID.
522
 * @param[in] node_map keeps track of range of local node IDs (consecutive)
523
 *            given to the nodes in the partitions.
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
 * @param[out] lr 1-level tree marking for local split nodes.
 * @param[in] lrtensor global (all the partitions) 1-level tree.
 * @param[in] num_nodes number of nodes in current partition.
 * @param[in] nc number of partitions/communities.
 * @param[in] c current partition/community.
 * @param[in] feat_size node feature vector size.
 * @param[out] labels local (for this partition) labels.
 * @param[out] trainm local (for this partition) training nodes.
 * @param[out] testm local (for this partition) testing nodes.
 * @param[out] valm local (for this partition) validation nodes.
 * @param[in] glabels global (input graph) labels.
 * @param[in] gtrainm glabal (input graph) training nodes.
 * @param[in] gtestm glabal (input graph) testing nodes.
 * @param[in] gvalm glabal (input graph) validation nodes.
 * @param[out] Nn number of nodes in the input graph.
539
 */
540
template <typename IdType, typename IdType2, typename DType>
541
void Libra2dglBuildAdjlist(
542
543
544
545
546
547
548
549
550
    NDArray feat, NDArray gfeat, NDArray adj, NDArray inner_node,
    NDArray ldt_key, NDArray gdt_key, NDArray gdt_value, NDArray node_map,
    NDArray lr, NDArray lrtensor, int64_t num_nodes, int32_t nc, int32_t c,
    int32_t feat_size, NDArray labels, NDArray trainm, NDArray testm,
    NDArray valm, NDArray glabels, NDArray gtrainm, NDArray gtestm,
    NDArray gvalm, int64_t Nn) {
  DType *feat_ptr = feat.Ptr<DType>();    // 2D tensor
  DType *gfeat_ptr = gfeat.Ptr<DType>();  // 2D tensor
  int64_t *adj_ptr = adj.Ptr<int64_t>();  // 2D tensor
551
  int32_t *inner_node_ptr = inner_node.Ptr<int32_t>();
552
553
554
555
556
557
  int64_t *ldt_key_ptr = ldt_key.Ptr<int64_t>();
  int64_t *gdt_key_ptr = gdt_key.Ptr<int64_t>();
  int64_t *gdt_value_ptr = gdt_value.Ptr<int64_t>();  // 2D tensor
  int64_t *node_map_ptr = node_map.Ptr<int64_t>();
  int64_t *lr_ptr = lr.Ptr<int64_t>();
  int64_t *lrtensor_ptr = lrtensor.Ptr<int64_t>();
558
559
  int32_t width = nc - 1;

560
561
  runtime::parallel_for(0, num_nodes, [&](int64_t s, int64_t e) {
    for (int64_t i = s; i < e; i++) {
562
563
564
565
      int64_t k = ldt_key_ptr[i];
      int64_t v = i;
      int64_t ind = gdt_key_ptr[k];

566
      int64_t *adj_ptr_ptr = adj_ptr + v * width;
567
      if (ind == 1) {
568
        for (int32_t j = 0; j < width; j++) adj_ptr_ptr[j] = -1;
569
570
571
572
        inner_node_ptr[i] = 1;
        lr_ptr[i] = -200;
      } else {
        lr_ptr[i] = lrtensor_ptr[k];
573
        int64_t *ptr = gdt_value_ptr + k * nc;
574
575
576
        int64_t pos = 0;
        CHECK(ind <= nc);
        int32_t flg = 0;
577
        for (int64_t j = 0; j < ind; j++) {
578
          if (ptr[j] == lr_ptr[i]) flg = 1;
579
          if (c != Ver2partition<int64_t>(ptr[j], node_map_ptr, nc))
580
581
582
583
584
585
586
587
588
589
590
            adj_ptr_ptr[pos++] = ptr[j];
        }
        CHECK_EQ(flg, 1);
        CHECK(pos == ind - 1);
        for (; pos < width; pos++) adj_ptr_ptr[pos] = -1;
        inner_node_ptr[i] = 0;
      }
    }
  });

  // gather
591
592
  runtime::parallel_for(0, num_nodes, [&](int64_t s, int64_t e) {
    for (int64_t i = s; i < e; i++) {
593
      int64_t k = ldt_key_ptr[i];
594
      int64_t ind = i * feat_size;
595
      DType *optr = gfeat_ptr + ind;
596
      DType *iptr = feat_ptr + k * feat_size;
597

598
      for (int32_t j = 0; j < feat_size; j++) optr[j] = iptr[j];
599
600
601
602
603
604
605
606
607
608
609
    }

    IdType *labels_ptr = labels.Ptr<IdType>();
    IdType *glabels_ptr = glabels.Ptr<IdType>();
    IdType2 *trainm_ptr = trainm.Ptr<IdType2>();
    IdType2 *gtrainm_ptr = gtrainm.Ptr<IdType2>();
    IdType2 *testm_ptr = testm.Ptr<IdType2>();
    IdType2 *gtestm_ptr = gtestm.Ptr<IdType2>();
    IdType2 *valm_ptr = valm.Ptr<IdType2>();
    IdType2 *gvalm_ptr = gvalm.Ptr<IdType2>();

610
    for (int64_t i = 0; i < num_nodes; i++) {
611
      int64_t k = ldt_key_ptr[i];
612
      CHECK(k >= 0 && k < Nn);
613
614
615
616
617
618
619
620
621
      glabels_ptr[i] = labels_ptr[k];
      gtrainm_ptr[i] = trainm_ptr[k];
      gtestm_ptr[i] = testm_ptr[k];
      gvalm_ptr[i] = valm_ptr[k];
    }
  });
}

// Registers the packed function "sparse._CAPI_DGLLibra2dglBuildAdjlist":
// unpacks the 23 positional arguments and forwards them to
// Libra2dglBuildAdjlist. The template types are selected from the dtype of
// the first feature array (DType), the dtype of the first mask array
// (IdType2) and the bit width of the label array passed at position 18
// (IdType).
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLLibra2dglBuildAdjlist")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      // Arguments in call order.
      NDArray feat_arr = args[0];
      NDArray gfeat_arr = args[1];
      NDArray adj_arr = args[2];
      NDArray inner_arr = args[3];
      NDArray local_dict = args[4];
      NDArray copy_counts = args[5];
      NDArray copy_ids = args[6];
      NDArray ranges = args[7];
      NDArray lr_arr = args[8];
      NDArray roots = args[9];
      int64_t part_nodes = args[10];
      int32_t num_parts = args[11];
      int32_t part = args[12];
      int32_t feat_dim = args[13];
      NDArray labels_arr = args[14];
      NDArray train_arr = args[15];
      NDArray test_arr = args[16];
      NDArray val_arr = args[17];
      NDArray glabels_arr = args[18];
      NDArray gtrain_arr = args[19];
      NDArray gtest_arr = args[20];
      NDArray gval_arr = args[21];
      int64_t graph_nodes = args[22];

      ATEN_FLOAT_TYPE_SWITCH(feat_arr->dtype, DType, "Features", {
        ATEN_ID_TYPE_SWITCH(train_arr->dtype, IdType2, {
          ATEN_ID_BITS_SWITCH((glabels_arr->dtype).bits, IdType, {
            Libra2dglBuildAdjlist<IdType, IdType2, DType>(
                feat_arr, gfeat_arr, adj_arr, inner_arr, local_dict,
                copy_counts, copy_ids, ranges, lr_arr, roots, part_nodes,
                num_parts, part, feat_dim, labels_arr, train_arr, test_arr,
                val_arr, glabels_arr, gtrain_arr, gtest_arr, gval_arr,
                graph_nodes);
          });
        });
      });
    });

}  // namespace aten
}  // namespace dgl