"vscode:/vscode.git/clone" did not exist on "0c0bb085e1cfbd8bc2c349a0975a777e4eaa8c36"
libra_partition.cc 24.2 KB
Newer Older
/**
 *  Copyright (c) 2021 Intel Corporation
 *
 *  @file distgnn/partition/main_Libra.py
 *  @brief Libra - Vertex-cut based graph partitioner for distributed training
 *  @author Vasimuddin Md <vasimuddin.md@intel.com>,
 *          Guixiang Ma <guixiang.ma@intel.com>
 *          Sanchit Misra <sanchit.misra@intel.com>,
 *          Ramanarayan Mohanty <ramanarayan.mohanty@intel.com>,
 *          Sasikanth Avancha <sasikanth.avancha@intel.com>
 *          Nesreen K. Ahmed <nesreen.k.ahmed@intel.com>
 */

#include <dgl/base_heterograph.h>
#include <dgl/packed_func_ext.h>
#include <dgl/random.h>
#include <dgl/runtime/parallel_for.h>
#include <dmlc/omp.h>
#include <stdint.h>
#include <stdio.h>

#include <string>
#include <vector>

#ifdef USE_TVM
#include <featgraph.h>
#endif  // USE_TVM

#include "../c_api_common.h"
#include "./check.h"
#include "kernel_decl.h"

using namespace dgl::runtime;

namespace dgl {
namespace aten {

36
37
template <typename IdType>
int32_t Ver2partition(IdType in_val, int64_t *node_map, int32_t num_parts) {
38
  int32_t pos = 0;
39
40
  for (int32_t p = 0; p < num_parts; p++) {
    if (in_val < node_map[p]) return pos;
41
42
43
    pos = pos + 1;
  }
  LOG(FATAL) << "Error: Unexpected output in Ver2partition!";
44
  return -1;
45
46
}

47
/**
48
 * @brief Identifies the lead loaded partition/community for a given edge
49
50
51
 * assignment.
 */
int32_t LeastLoad(int64_t *community_edges, int32_t nc) {
52
53
  std::vector<int> loc;
  int32_t min = 1e9;
54
  for (int32_t i = 0; i < nc; i++) {
55
56
57
58
    if (community_edges[i] < min) {
      min = community_edges[i];
    }
  }
59
  for (int32_t i = 0; i < nc; i++) {
60
61
62
63
64
65
66
67
68
69
    if (community_edges[i] == min) {
      loc.push_back(i);
    }
  }

  int32_t r = RandomEngine::ThreadLocal()->RandInt(loc.size());
  CHECK(loc[r] < nc);
  return loc[r];
}

70
/**
71
 * @brief Libra - vertexcut based graph partitioning.
72
73
74
75
76
 * It takes list of edges from input DGL graph and distributed them among nc
 * partitions During edge distribution, Libra assign a given edge to a partition
 * based on the end vertices, in doing so, it tries to minimized the splitting
 * of the graph vertices. In case of conflict Libra assigns an edge to the least
 * loaded partition/community.
77
78
79
80
81
82
83
84
85
86
87
 * @param[in] nc Number of partitions/communities
 * @param[in] node_degree per node degree
 * @param[in] edgenum_unassigned node degree
 * @param[out] community_weights weight of the created partitions
 * @param[in] u src nodes
 * @param[in] v dst nodes
 * @param[out] w weight per edge
 * @param[out] out partition assignment of the edges
 * @param[in] N_n number of nodes in the input graph
 * @param[in] N_e number of edges in the input graph
 * @param[in] prefix output/partition storage location
88
89
 */
template <typename IdType, typename IdType2>
90
void LibraVertexCut(
91
92
93
94
95
    int32_t nc, NDArray node_degree, NDArray edgenum_unassigned,
    NDArray community_weights, NDArray u, NDArray v, NDArray w, NDArray out,
    int64_t N_n, int64_t N_e, const std::string &prefix) {
  int32_t *out_ptr = out.Ptr<int32_t>();
  IdType2 *node_degree_ptr = node_degree.Ptr<IdType2>();
96
  IdType2 *edgenum_unassigned_ptr = edgenum_unassigned.Ptr<IdType2>();
97
98
99
100
  IdType *u_ptr = u.Ptr<IdType>();
  IdType *v_ptr = v.Ptr<IdType>();
  int64_t *w_ptr = w.Ptr<int64_t>();
  int64_t *community_weights_ptr = community_weights.Ptr<int64_t>();
101
102
103
104
105
106
107

  std::vector<std::vector<int32_t> > node_assignments(N_n);
  std::vector<IdType2> replication_list;
  // local allocations
  int64_t *community_edges = new int64_t[nc]();
  int64_t *cache = new int64_t[nc]();

108
109
110
111
112
  int64_t meter = static_cast<int>(N_e / 100);
  for (int64_t i = 0; i < N_e; i++) {
    IdType u = u_ptr[i];   // edge end vertex 1
    IdType v = v_ptr[i];   // edge end vertex 2
    int64_t w = w_ptr[i];  // edge weight
113
114
115
116
117

    CHECK(u < N_n);
    CHECK(v < N_n);

    if (i % meter == 0) {
118
119
      fprintf(stderr, ".");
      fflush(0);
120
121
122
123
124
125
126
127
128
129
    }

    if (node_assignments[u].size() == 0 && node_assignments[v].size() == 0) {
      int32_t c = LeastLoad(community_edges, nc);
      out_ptr[i] = c;
      CHECK_LT(c, nc);

      community_edges[c]++;
      community_weights_ptr[c] = community_weights_ptr[c] + w;
      node_assignments[u].push_back(c);
130
      if (u != v) node_assignments[v].push_back(c);
131

132
133
134
135
      CHECK(node_assignments[u].size() <= static_cast<size_t>(nc))
          << "[bug] 1. generated splits (u) are greater than nc!";
      CHECK(node_assignments[v].size() <= static_cast<size_t>(nc))
          << "[bug] 1. generated splits (v) are greater than nc!";
136
137
      edgenum_unassigned_ptr[u]--;
      edgenum_unassigned_ptr[v]--;
138
139
140
    } else if (
        node_assignments[u].size() != 0 && node_assignments[v].size() == 0) {
      for (uint32_t j = 0; j < node_assignments[u].size(); j++) {
141
142
143
144
145
146
147
148
149
150
        int32_t cind = node_assignments[u][j];
        cache[j] = community_edges[cind];
      }
      int32_t cindex = LeastLoad(cache, node_assignments[u].size());
      int32_t c = node_assignments[u][cindex];
      out_ptr[i] = c;
      community_edges[c]++;
      community_weights_ptr[c] = community_weights_ptr[c] + w;

      node_assignments[v].push_back(c);
151
152
      CHECK(node_assignments[v].size() <= static_cast<size_t>(nc))
          << "[bug] 2. generated splits (v) are greater than nc!";
153
154
      edgenum_unassigned_ptr[u]--;
      edgenum_unassigned_ptr[v]--;
155
156
157
    } else if (
        node_assignments[v].size() != 0 && node_assignments[u].size() == 0) {
      for (uint32_t j = 0; j < node_assignments[v].size(); j++) {
158
159
160
161
162
163
164
165
166
167
168
169
        int32_t cind = node_assignments[v][j];
        cache[j] = community_edges[cind];
      }
      int32_t cindex = LeastLoad(cache, node_assignments[v].size());
      int32_t c = node_assignments[v][cindex];
      CHECK(c < nc) << "[bug] 2. partition greater than nc !!";
      out_ptr[i] = c;

      community_edges[c]++;
      community_weights_ptr[c] = community_weights_ptr[c] + w;

      node_assignments[u].push_back(c);
170
171
      CHECK(node_assignments[u].size() <= static_cast<size_t>(nc))
          << "[bug] 3. generated splits (u) are greater than nc!";
172
173
174
175
      edgenum_unassigned_ptr[u]--;
      edgenum_unassigned_ptr[v]--;
    } else {
      std::vector<int> setv(nc), intersetv;
176
      for (int32_t j = 0; j < nc; j++) setv[j] = 0;
177
178
      int32_t interset = 0;

179
180
181
182
183
184
185
      CHECK(node_assignments[u].size() <= static_cast<size_t>(nc))
          << "[bug] 4. generated splits (u) are greater than nc!";
      CHECK(node_assignments[v].size() <= static_cast<size_t>(nc))
          << "[bug] 4. generated splits (v) are greater than nc!";
      for (size_t j = 0; j < node_assignments[v].size(); j++) {
        CHECK(node_assignments[v][j] < nc)
            << "[bug] 4. Part assigned (v) greater than nc!";
186
187
188
        setv[node_assignments[v][j]]++;
      }

189
190
191
      for (size_t j = 0; j < node_assignments[u].size(); j++) {
        CHECK(node_assignments[u][j] < nc)
            << "[bug] 4. Part assigned (u) greater than nc!";
192
193
194
        setv[node_assignments[u][j]]++;
      }

195
      for (int32_t j = 0; j < nc; j++) {
196
197
198
199
200
201
202
        CHECK(setv[j] <= 2) << "[bug] 4. unexpected computed value !!!";
        if (setv[j] == 2) {
          interset++;
          intersetv.push_back(j);
        }
      }
      if (interset) {
203
        for (size_t j = 0; j < intersetv.size(); j++) {
204
205
206
207
208
209
210
211
212
213
214
215
216
          int32_t cind = intersetv[j];
          cache[j] = community_edges[cind];
        }
        int32_t cindex = LeastLoad(cache, intersetv.size());
        int32_t c = intersetv[cindex];
        CHECK(c < nc) << "[bug] 4. partition greater than nc !!";
        out_ptr[i] = c;
        community_edges[c]++;
        community_weights_ptr[c] = community_weights_ptr[c] + w;
        edgenum_unassigned_ptr[u]--;
        edgenum_unassigned_ptr[v]--;
      } else {
        if (node_degree_ptr[u] < node_degree_ptr[v]) {
217
          for (uint32_t j = 0; j < node_assignments[u].size(); j++) {
218
219
220
221
222
223
224
225
226
227
            int32_t cind = node_assignments[u][j];
            cache[j] = community_edges[cind];
          }
          int32_t cindex = LeastLoad(cache, node_assignments[u].size());
          int32_t c = node_assignments[u][cindex];
          CHECK(c < nc) << "[bug] 5. partition greater than nc !!";
          out_ptr[i] = c;
          community_edges[c]++;
          community_weights_ptr[c] = community_weights_ptr[c] + w;

228
229
230
          for (uint32_t j = 0; j < node_assignments[v].size(); j++) {
            CHECK(node_assignments[v][j] != c)
                << "[bug] 5. duplicate partition (v) assignment !!";
231
232
233
          }

          node_assignments[v].push_back(c);
234
235
          CHECK(node_assignments[v].size() <= static_cast<size_t>(nc))
              << "[bug] 5. generated splits (v) greater than nc!!";
236
237
238
239
          replication_list.push_back(v);
          edgenum_unassigned_ptr[u]--;
          edgenum_unassigned_ptr[v]--;
        } else {
240
          for (uint32_t j = 0; j < node_assignments[v].size(); j++) {
241
242
243
244
245
            int32_t cind = node_assignments[v][j];
            cache[j] = community_edges[cind];
          }
          int32_t cindex = LeastLoad(cache, node_assignments[v].size());
          int32_t c = node_assignments[v][cindex];
246
          CHECK(c < nc) << "[bug] 6. partition greater than nc !!";
247
248
249
          out_ptr[i] = c;
          community_edges[c]++;
          community_weights_ptr[c] = community_weights_ptr[c] + w;
250
251
252
          for (uint32_t j = 0; j < node_assignments[u].size(); j++) {
            CHECK(node_assignments[u][j] != c)
                << "[bug] 6. duplicate partition (u) assignment !!";
253
          }
254
          if (u != v) node_assignments[u].push_back(c);
255

256
257
          CHECK(node_assignments[u].size() <= static_cast<size_t>(nc))
              << "[bug] 6. generated splits (u) greater than nc!!";
258
259
260
261
262
263
264
265
266
          replication_list.push_back(u);
          edgenum_unassigned_ptr[u]--;
          edgenum_unassigned_ptr[v]--;
        }
      }
    }
  }
  delete cache;

267
268
  for (int64_t c = 0; c < nc; c++) {
    std::string path = prefix + "/community" + std::to_string(c) + ".txt";
269
270

    FILE *fp = fopen(path.c_str(), "w");
271
272
    CHECK_NE(fp, static_cast<FILE *>(NULL))
        << "Error: can not open file: " << path.c_str();
273

274
    for (int64_t i = 0; i < N_e; i++) {
275
      if (out_ptr[i] == c)
276
277
        fprintf(
            fp, "%ld,%ld,%ld\n", static_cast<int64_t>(u_ptr[i]),
278
            static_cast<int64_t>(v_ptr[i]), w_ptr[i]);
279
280
281
282
283
284
    }
    fclose(fp);
  }

  std::string path = prefix + "/replicationlist.csv";
  FILE *fp = fopen(path.c_str(), "w");
285
286
  CHECK_NE(fp, static_cast<FILE *>(NULL))
      << "Error: can not open file: " << path.c_str();
287
288
289
290

  fprintf(fp, "## The Indices of Nodes that are replicated :: Header");
  printf("\nTotal replication: %ld\n", replication_list.size());

291
  for (uint64_t i = 0; i < replication_list.size(); i++)
292
    fprintf(fp, "%ld\n", static_cast<int64_t>(replication_list[i]));
293
294

  printf("Community weights:\n");
295
  for (int64_t c = 0; c < nc; c++) printf("%ld ", community_weights_ptr[c]);
296
297
298
  printf("\n");

  printf("Community edges:\n");
299
  for (int64_t c = 0; c < nc; c++) printf("%ld ", community_edges[c]);
300
301
302
303
304
305
306
  printf("\n");

  delete community_edges;
  fclose(fp);
}

// Registers LibraVertexCut under the name "sparse._CAPI_DGLLibraVertexCut".
// Arguments are unpacked positionally; the template parameters are selected
// from the dtypes of `node_degree` (IdType2) and `u` (IdType).
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLLibraVertexCut")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      int32_t nc = args[0];
      NDArray node_degree = args[1];
      NDArray edgenum_unassigned = args[2];
      NDArray community_weights = args[3];
      NDArray u = args[4];
      NDArray v = args[5];
      NDArray w = args[6];
      NDArray out = args[7];
      int64_t N = args[8];
      int64_t N_e = args[9];
      std::string prefix = args[10];

      ATEN_ID_TYPE_SWITCH(node_degree->dtype, IdType2, {
        ATEN_ID_TYPE_SWITCH(u->dtype, IdType, {
          LibraVertexCut<IdType, IdType2>(
              nc, node_degree, edgenum_unassigned, community_weights, u, v, w,
              out, N, N_e, prefix);
        });
      });
    });

329
/**
330
 * @brief
331
332
333
334
335
336
 * 1. Builds dictionary (ldt) for assigning local node IDs to nodes in the
 *    partitions
 * 2. Builds dictionary (gdt) for storing copies (local ID) of split nodes
 *    These dictionaries will be used in the subsequesnt stages to setup
 *    tracking of split nodes copies across the partition, setting up partition
 *    `ndata` dictionaries.
337
338
339
 * @param[out] a local src node ID of an edge in a partition
 * @param[out] b local dst node ID of an edge in a partition
 * @param[-] indices temporary memory, keeps track of global node ID to local
340
 *           node ID in a partition
341
 * @param[out] ldt_key per partition dict for storing global and local node IDs
342
 *             (consecutive)
343
 * @param[out] gdt_key global dict for storing number of local nodes (or split
344
 *             nodes) for a given global node ID
345
 * @param[out] gdt_value global dict, stores local node IDs (due to split)
346
 *             across partitions for a given global node ID
347
 * @param[out] node_map keeps track of range of local node IDs (consecutive)
348
 *             given to the nodes in the partitions
349
 * @param[in, out] offset start of the range of local node IDs for this
350
 *                 partition
351
 * @param[in] nc number of partitions/communities
352
353
 * @param[in] c current partition number
 * @param[in] fsize size of pre-allocated
354
 *            memory tensor
355
 * @param[in] prefix input Libra partition file location
356
357
 */
List<Value> Libra2dglBuildDict(
358
359
360
361
362
363
364
365
366
367
    NDArray a, NDArray b, NDArray indices, NDArray ldt_key, NDArray gdt_key,
    NDArray gdt_value, NDArray node_map, NDArray offset, int32_t nc, int32_t c,
    int64_t fsize, const std::string &prefix) {
  int64_t *indices_ptr = indices.Ptr<int64_t>();  // 1D temp array
  int64_t *ldt_key_ptr =
      ldt_key.Ptr<int64_t>();  // 1D local nodes <-> global nodes
  int64_t *gdt_key_ptr = gdt_key.Ptr<int64_t>();  // 1D #split copies per node
  int64_t *gdt_value_ptr = gdt_value.Ptr<int64_t>();  // 2D tensor
  int64_t *node_map_ptr = node_map.Ptr<int64_t>();    // 1D tensor
  int64_t *offset_ptr = offset.Ptr<int64_t>();        // 1D tensor
368
369
  int32_t width = nc;

370
371
  int64_t *a_ptr = a.Ptr<int64_t>();  // stores local src and dst node ID,
  int64_t *b_ptr = b.Ptr<int64_t>();  // to create the partition graph
372
373
374
375

  int64_t N_n = indices->shape[0];
  int64_t num_nodes = ldt_key->shape[0];

376
  for (int64_t i = 0; i < N_n; i++) {
377
378
379
380
381
382
383
    indices_ptr[i] = -100;
  }

  int64_t pos = 0;
  int64_t edge = 0;
  std::string path = prefix + "/community" + std::to_string(c) + ".txt";
  FILE *fp = fopen(path.c_str(), "r");
384
385
  CHECK_NE(fp, static_cast<FILE *>(NULL))
      << "Error: can not open file: " << path.c_str();
386
387
388
389

  while (!feof(fp) && edge < fsize) {
    int64_t u, v;
    float w;
390
391
392
393
394
395
396
397
398
399
    fscanf(
        fp, "%ld,%ld,%f\n", &u, &v,
        &w);  // reading an edge - the src and dst global node IDs

    if (indices_ptr[u] ==
        -100) {  // if already not assigned a local node ID, local node ID is
      ldt_key_ptr[pos] = u;    // already assigned for this global node ID
      CHECK(pos < num_nodes);  // Sanity check
      indices_ptr[u] =
          pos++;  // consecutive local node ID for a given global node ID
400
    }
401
    if (indices_ptr[v] == -100) {  // if already not assigned a local node ID
402
      ldt_key_ptr[pos] = v;
403
      CHECK(pos < num_nodes);  // Sanity check
404
405
      indices_ptr[v] = pos++;
    }
406
407
    a_ptr[edge] = indices_ptr[u];    // new local ID for an edge
    b_ptr[edge++] = indices_ptr[v];  // new local ID for an edge
408
  }
409
410
  CHECK(edge <= fsize)
      << "[Bug] memory allocated for #edges per partition is not enough.";
411
412
413
  fclose(fp);

  List<Value> ret;
414
415
416
417
  ret.push_back(Value(
      MakeValue(pos)));  // returns total number of nodes in this partition
  ret.push_back(Value(
      MakeValue(edge)));  // returns total number of edges in this partition
418

419
420
  for (int64_t i = 0; i < pos; i++) {
    int64_t u = ldt_key_ptr[i];  // global node ID
421
    // int64_t  v   = indices_ptr[u];
422
423
424
425
426
427
428
    int64_t v = i;  // local node ID
    int64_t *ind =
        &gdt_key_ptr[u];  // global dict, total number of local node IDs (an
                          // offset) as of now for a given global node ID
    int64_t *ptr = gdt_value_ptr + u * width;
    ptr[*ind] =
        offset_ptr[0] + v;  // stores a local node ID for the global node ID
429
430
431
432
    (*ind)++;
    CHECK_NE(v, -100);
    CHECK(*ind <= nc);
  }
433
434
435
436
  node_map_ptr[c] =
      offset_ptr[0] +
      pos;  // since local node IDs for a partition are consecutive,
            // we maintain the range of local node IDs like this
437
438
439
440
441
442
  offset_ptr[0] += pos;

  return ret;
}

// Registers Libra2dglBuildDict under the name
// "sparse._CAPI_DGLLibra2dglBuildDict". Arguments are unpacked positionally
// and the resulting list is handed back through `rv`.
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLLibra2dglBuildDict")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      NDArray a = args[0];
      NDArray b = args[1];
      NDArray indices = args[2];
      NDArray ldt_key = args[3];
      NDArray gdt_key = args[4];
      NDArray gdt_value = args[5];
      NDArray node_map = args[6];
      NDArray offset = args[7];
      int32_t nc = args[8];
      int32_t c = args[9];
      int64_t fsize = args[10];
      std::string prefix = args[11];
      List<Value> ret = Libra2dglBuildDict(
          a, b, indices, ldt_key, gdt_key, gdt_value, node_map, offset, nc, c,
          fsize, prefix);
      *rv = ret;
    });

462
/**
463
464
 * @brief sets up the 1-level tree among the clones of the split-nodes.
 * @param[in] gdt_key global dict for assigning consecutive node IDs to nodes
465
 *            across all the partitions
466
 * @param[in] gdt_value global dict for assigning consecutive node IDs to nodes
467
 *            across all the partition
468
469
470
 * @param[out] lrtensor keeps the root node ID of 1-level tree
 * @param[in] nc number of partitions/communities
 * @param[in] Nn number of nodes in the input graph
471
472
 */
void Libra2dglSetLR(
473
474
475
476
477
    NDArray gdt_key, NDArray gdt_value, NDArray lrtensor, int32_t nc,
    int64_t Nn) {
  int64_t *gdt_key_ptr = gdt_key.Ptr<int64_t>();      // 1D tensor
  int64_t *gdt_value_ptr = gdt_value.Ptr<int64_t>();  // 2D tensor
  int64_t *lrtensor_ptr = lrtensor.Ptr<int64_t>();    // 1D tensor
478
479
480
481
482

  int32_t width = nc;
  int64_t cnt = 0;
  int64_t avg_split_copy = 0, scnt = 0;

483
  for (int64_t i = 0; i < Nn; i++) {
484
485
486
487
488
489
490
    if (gdt_key_ptr[i] <= 0) {
      cnt++;
    } else {
      int32_t val = RandomEngine::ThreadLocal()->RandInt(gdt_key_ptr[i]);
      CHECK(val >= 0 && val < gdt_key_ptr[i]);
      CHECK(gdt_key_ptr[i] <= nc);

491
      int64_t *ptr = gdt_value_ptr + i * width;
492
493
494
495
496
497
498
499
500
501
      lrtensor_ptr[i] = ptr[val];
    }
    if (gdt_key_ptr[i] > 1) {
      avg_split_copy += gdt_key_ptr[i];
      scnt++;
    }
  }
}

// Registers Libra2dglSetLR under the name "sparse._CAPI_DGLLibra2dglSetLR".
// Arguments are unpacked positionally; nothing is returned through `rv`.
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLLibra2dglSetLR")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      NDArray gdt_key = args[0];
      NDArray gdt_value = args[1];
      NDArray lrtensor = args[2];
      int32_t nc = args[3];
      int64_t Nn = args[4];

      Libra2dglSetLR(gdt_key, gdt_value, lrtensor, nc, Nn);
    });

512
/**
513
 * @brief For each node in a partition, it creates a list of remote clone IDs;
514
515
 *        also, for each node in a partition, it gathers the data (feats, label,
 *        trian, test) from input graph.
516
517
518
519
520
521
 * @param[out] feat node features in current partition c.
 * @param[in] gfeat input graph node features.
 * @param[out] adj list of node IDs of remote clones.
 * @param[out] inner_nodes marks whether a node is split or not.
 * @param[in] ldt_key per partition dict for tracking global to local node IDs
 * @param[out] gdt_key global dict for storing number of local nodes (or split
522
523
 *             nodes) for a given global node ID
 * @param[out] gdt_value global
524
525
 *             dict, stores local node IDs (due to split) across partitions for
 *             a given global node ID.
526
 * @param[in] node_map keeps track of range of local node IDs (consecutive)
527
 *            given to the nodes in the partitions.
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
 * @param[out] lr 1-level tree marking for local split nodes.
 * @param[in] lrtensor global (all the partitions) 1-level tree.
 * @param[in] num_nodes number of nodes in current partition.
 * @param[in] nc number of partitions/communities.
 * @param[in] c current partition/community.
 * @param[in] feat_size node feature vector size.
 * @param[out] labels local (for this partition) labels.
 * @param[out] trainm local (for this partition) training nodes.
 * @param[out] testm local (for this partition) testing nodes.
 * @param[out] valm local (for this partition) validation nodes.
 * @param[in] glabels global (input graph) labels.
 * @param[in] gtrainm glabal (input graph) training nodes.
 * @param[in] gtestm glabal (input graph) testing nodes.
 * @param[in] gvalm glabal (input graph) validation nodes.
 * @param[out] Nn number of nodes in the input graph.
543
 */
544
template <typename IdType, typename IdType2, typename DType>
545
void Libra2dglBuildAdjlist(
546
547
548
549
550
551
552
553
554
    NDArray feat, NDArray gfeat, NDArray adj, NDArray inner_node,
    NDArray ldt_key, NDArray gdt_key, NDArray gdt_value, NDArray node_map,
    NDArray lr, NDArray lrtensor, int64_t num_nodes, int32_t nc, int32_t c,
    int32_t feat_size, NDArray labels, NDArray trainm, NDArray testm,
    NDArray valm, NDArray glabels, NDArray gtrainm, NDArray gtestm,
    NDArray gvalm, int64_t Nn) {
  DType *feat_ptr = feat.Ptr<DType>();    // 2D tensor
  DType *gfeat_ptr = gfeat.Ptr<DType>();  // 2D tensor
  int64_t *adj_ptr = adj.Ptr<int64_t>();  // 2D tensor
555
  int32_t *inner_node_ptr = inner_node.Ptr<int32_t>();
556
557
558
559
560
561
  int64_t *ldt_key_ptr = ldt_key.Ptr<int64_t>();
  int64_t *gdt_key_ptr = gdt_key.Ptr<int64_t>();
  int64_t *gdt_value_ptr = gdt_value.Ptr<int64_t>();  // 2D tensor
  int64_t *node_map_ptr = node_map.Ptr<int64_t>();
  int64_t *lr_ptr = lr.Ptr<int64_t>();
  int64_t *lrtensor_ptr = lrtensor.Ptr<int64_t>();
562
563
  int32_t width = nc - 1;

564
565
  runtime::parallel_for(0, num_nodes, [&](int64_t s, int64_t e) {
    for (int64_t i = s; i < e; i++) {
566
567
568
569
      int64_t k = ldt_key_ptr[i];
      int64_t v = i;
      int64_t ind = gdt_key_ptr[k];

570
      int64_t *adj_ptr_ptr = adj_ptr + v * width;
571
      if (ind == 1) {
572
        for (int32_t j = 0; j < width; j++) adj_ptr_ptr[j] = -1;
573
574
575
576
        inner_node_ptr[i] = 1;
        lr_ptr[i] = -200;
      } else {
        lr_ptr[i] = lrtensor_ptr[k];
577
        int64_t *ptr = gdt_value_ptr + k * nc;
578
579
580
        int64_t pos = 0;
        CHECK(ind <= nc);
        int32_t flg = 0;
581
        for (int64_t j = 0; j < ind; j++) {
582
          if (ptr[j] == lr_ptr[i]) flg = 1;
583
          if (c != Ver2partition<int64_t>(ptr[j], node_map_ptr, nc))
584
585
586
587
588
589
590
591
592
593
594
            adj_ptr_ptr[pos++] = ptr[j];
        }
        CHECK_EQ(flg, 1);
        CHECK(pos == ind - 1);
        for (; pos < width; pos++) adj_ptr_ptr[pos] = -1;
        inner_node_ptr[i] = 0;
      }
    }
  });

  // gather
595
596
  runtime::parallel_for(0, num_nodes, [&](int64_t s, int64_t e) {
    for (int64_t i = s; i < e; i++) {
597
      int64_t k = ldt_key_ptr[i];
598
      int64_t ind = i * feat_size;
599
      DType *optr = gfeat_ptr + ind;
600
      DType *iptr = feat_ptr + k * feat_size;
601

602
      for (int32_t j = 0; j < feat_size; j++) optr[j] = iptr[j];
603
604
605
606
607
608
609
610
611
612
613
    }

    IdType *labels_ptr = labels.Ptr<IdType>();
    IdType *glabels_ptr = glabels.Ptr<IdType>();
    IdType2 *trainm_ptr = trainm.Ptr<IdType2>();
    IdType2 *gtrainm_ptr = gtrainm.Ptr<IdType2>();
    IdType2 *testm_ptr = testm.Ptr<IdType2>();
    IdType2 *gtestm_ptr = gtestm.Ptr<IdType2>();
    IdType2 *valm_ptr = valm.Ptr<IdType2>();
    IdType2 *gvalm_ptr = gvalm.Ptr<IdType2>();

614
    for (int64_t i = 0; i < num_nodes; i++) {
615
      int64_t k = ldt_key_ptr[i];
616
      CHECK(k >= 0 && k < Nn);
617
618
619
620
621
622
623
624
625
      glabels_ptr[i] = labels_ptr[k];
      gtrainm_ptr[i] = trainm_ptr[k];
      gtestm_ptr[i] = testm_ptr[k];
      gvalm_ptr[i] = valm_ptr[k];
    }
  });
}

// Registers Libra2dglBuildAdjlist under the name
// "sparse._CAPI_DGLLibra2dglBuildAdjlist". Arguments are unpacked
// positionally; the template parameters are selected from the dtype of `feat`
// (DType), the dtype of `trainm` (IdType2) and the bit width of `glabels`
// (IdType).
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLLibra2dglBuildAdjlist")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      NDArray feat = args[0];
      NDArray gfeat = args[1];
      NDArray adj = args[2];
      NDArray inner_node = args[3];
      NDArray ldt_key = args[4];
      NDArray gdt_key = args[5];
      NDArray gdt_value = args[6];
      NDArray node_map = args[7];
      NDArray lr = args[8];
      NDArray lrtensor = args[9];
      int64_t num_nodes = args[10];
      int32_t nc = args[11];
      int32_t c = args[12];
      int32_t feat_size = args[13];
      NDArray labels = args[14];
      NDArray trainm = args[15];
      NDArray testm = args[16];
      NDArray valm = args[17];
      NDArray glabels = args[18];
      NDArray gtrainm = args[19];
      NDArray gtestm = args[20];
      NDArray gvalm = args[21];
      int64_t Nn = args[22];

      ATEN_FLOAT_TYPE_SWITCH(feat->dtype, DType, "Features", {
        ATEN_ID_TYPE_SWITCH(trainm->dtype, IdType2, {
          ATEN_ID_BITS_SWITCH((glabels->dtype).bits, IdType, {
            Libra2dglBuildAdjlist<IdType, IdType2, DType>(
                feat, gfeat, adj, inner_node, ldt_key, gdt_key, gdt_value,
                node_map, lr, lrtensor, num_nodes, nc, c, feat_size, labels,
                trainm, testm, valm, glabels, gtrainm, gtestm, gvalm, Nn);
          });
        });
      });
    });

}  // namespace aten
}  // namespace dgl