libra_partition.cc 24.2 KB
Newer Older
1
/**
2
3
4
5
6
7
8
9
10
11
12
 *  Copyright (c) 2021 Intel Corporation
 *
 *  @file libra_partition.cc
 *  @brief Libra - Vertex-cut based graph partitioner for distributed training
 *  @author Vasimuddin Md <vasimuddin.md@intel.com>,
 *          Guixiang Ma <guixiang.ma@intel.com>
 *          Sanchit Misra <sanchit.misra@intel.com>,
 *          Ramanarayan Mohanty <ramanarayan.mohanty@intel.com>,
 *          Sasikanth Avancha <sasikanth.avancha@intel.com>
 *          Nesreen K. Ahmed <nesreen.k.ahmed@intel.com>
 */
13

14
15
#include <dgl/base_heterograph.h>
#include <dgl/packed_func_ext.h>
16
#include <dgl/random.h>
17
#include <dgl/runtime/parallel_for.h>
18
#include <dmlc/omp.h>
19
20
#include <stdint.h>

21
22
23
24
#include <vector>

#include "../c_api_common.h"
#include "./check.h"
25
#include "kernel_decl.h"
26
27
28
29
30
31

using namespace dgl::runtime;

namespace dgl {
namespace aten {

32
33
/**
 * @brief Returns the partition whose node-ID range contains @p in_val.
 *
 * Scans the first @p num_parts entries of @p node_map and returns the first
 * index whose bound is strictly greater than @p in_val. This presumes
 * node_map holds ascending, exclusive upper bounds of each partition's
 * consecutive ID range (as filled by Libra2dglBuildDict) -- confirm against
 * callers.
 *
 * @param[in] in_val node ID to locate.
 * @param[in] node_map per-partition upper bounds (num_parts entries).
 * @param[in] num_parts number of partitions.
 * @return owning partition index; logs a fatal error if none matches.
 */
template <typename IdType>
int32_t Ver2partition(IdType in_val, int64_t *node_map, int32_t num_parts) {
  for (int32_t part = 0; part < num_parts; ++part) {
    if (in_val < node_map[part]) {
      return part;
    }
  }
  // No range contains the value: the dictionaries are inconsistent.
  LOG(FATAL) << "Error: Unexpected output in Ver2partition!";
  return -1;
}

43
/**
44
 * @brief Identifies the lead loaded partition/community for a given edge
45
46
47
 * assignment.
 */
int32_t LeastLoad(int64_t *community_edges, int32_t nc) {
48
49
  std::vector<int> loc;
  int32_t min = 1e9;
50
  for (int32_t i = 0; i < nc; i++) {
51
52
53
54
    if (community_edges[i] < min) {
      min = community_edges[i];
    }
  }
55
  for (int32_t i = 0; i < nc; i++) {
56
57
58
59
60
61
62
63
64
65
    if (community_edges[i] == min) {
      loc.push_back(i);
    }
  }

  int32_t r = RandomEngine::ThreadLocal()->RandInt(loc.size());
  CHECK(loc[r] < nc);
  return loc[r];
}

66
/**
67
 * @brief Libra - vertexcut based graph partitioning.
68
69
70
71
72
 * It takes list of edges from input DGL graph and distributed them among nc
 * partitions During edge distribution, Libra assign a given edge to a partition
 * based on the end vertices, in doing so, it tries to minimized the splitting
 * of the graph vertices. In case of conflict Libra assigns an edge to the least
 * loaded partition/community.
73
74
75
76
77
78
79
80
81
82
83
 * @param[in] nc Number of partitions/communities
 * @param[in] node_degree per node degree
 * @param[in] edgenum_unassigned node degree
 * @param[out] community_weights weight of the created partitions
 * @param[in] u src nodes
 * @param[in] v dst nodes
 * @param[out] w weight per edge
 * @param[out] out partition assignment of the edges
 * @param[in] N_n number of nodes in the input graph
 * @param[in] N_e number of edges in the input graph
 * @param[in] prefix output/partition storage location
84
85
 */
template <typename IdType, typename IdType2>
86
void LibraVertexCut(
87
88
89
90
91
    int32_t nc, NDArray node_degree, NDArray edgenum_unassigned,
    NDArray community_weights, NDArray u, NDArray v, NDArray w, NDArray out,
    int64_t N_n, int64_t N_e, const std::string &prefix) {
  int32_t *out_ptr = out.Ptr<int32_t>();
  IdType2 *node_degree_ptr = node_degree.Ptr<IdType2>();
92
  IdType2 *edgenum_unassigned_ptr = edgenum_unassigned.Ptr<IdType2>();
93
94
95
96
  IdType *u_ptr = u.Ptr<IdType>();
  IdType *v_ptr = v.Ptr<IdType>();
  int64_t *w_ptr = w.Ptr<int64_t>();
  int64_t *community_weights_ptr = community_weights.Ptr<int64_t>();
97
98
99
100
101
102
103

  std::vector<std::vector<int32_t> > node_assignments(N_n);
  std::vector<IdType2> replication_list;
  // local allocations
  int64_t *community_edges = new int64_t[nc]();
  int64_t *cache = new int64_t[nc]();

104
105
106
107
108
  int64_t meter = static_cast<int>(N_e / 100);
  for (int64_t i = 0; i < N_e; i++) {
    IdType u = u_ptr[i];   // edge end vertex 1
    IdType v = v_ptr[i];   // edge end vertex 2
    int64_t w = w_ptr[i];  // edge weight
109
110
111
112
113

    CHECK(u < N_n);
    CHECK(v < N_n);

    if (i % meter == 0) {
114
115
      fprintf(stderr, ".");
      fflush(0);
116
117
118
119
120
121
122
123
124
125
    }

    if (node_assignments[u].size() == 0 && node_assignments[v].size() == 0) {
      int32_t c = LeastLoad(community_edges, nc);
      out_ptr[i] = c;
      CHECK_LT(c, nc);

      community_edges[c]++;
      community_weights_ptr[c] = community_weights_ptr[c] + w;
      node_assignments[u].push_back(c);
126
      if (u != v) node_assignments[v].push_back(c);
127

128
129
130
131
      CHECK(node_assignments[u].size() <= static_cast<size_t>(nc))
          << "[bug] 1. generated splits (u) are greater than nc!";
      CHECK(node_assignments[v].size() <= static_cast<size_t>(nc))
          << "[bug] 1. generated splits (v) are greater than nc!";
132
133
      edgenum_unassigned_ptr[u]--;
      edgenum_unassigned_ptr[v]--;
134
135
136
    } else if (
        node_assignments[u].size() != 0 && node_assignments[v].size() == 0) {
      for (uint32_t j = 0; j < node_assignments[u].size(); j++) {
137
138
139
140
141
142
143
144
145
146
        int32_t cind = node_assignments[u][j];
        cache[j] = community_edges[cind];
      }
      int32_t cindex = LeastLoad(cache, node_assignments[u].size());
      int32_t c = node_assignments[u][cindex];
      out_ptr[i] = c;
      community_edges[c]++;
      community_weights_ptr[c] = community_weights_ptr[c] + w;

      node_assignments[v].push_back(c);
147
148
      CHECK(node_assignments[v].size() <= static_cast<size_t>(nc))
          << "[bug] 2. generated splits (v) are greater than nc!";
149
150
      edgenum_unassigned_ptr[u]--;
      edgenum_unassigned_ptr[v]--;
151
152
153
    } else if (
        node_assignments[v].size() != 0 && node_assignments[u].size() == 0) {
      for (uint32_t j = 0; j < node_assignments[v].size(); j++) {
154
155
156
157
158
159
160
161
162
163
164
165
        int32_t cind = node_assignments[v][j];
        cache[j] = community_edges[cind];
      }
      int32_t cindex = LeastLoad(cache, node_assignments[v].size());
      int32_t c = node_assignments[v][cindex];
      CHECK(c < nc) << "[bug] 2. partition greater than nc !!";
      out_ptr[i] = c;

      community_edges[c]++;
      community_weights_ptr[c] = community_weights_ptr[c] + w;

      node_assignments[u].push_back(c);
166
167
      CHECK(node_assignments[u].size() <= static_cast<size_t>(nc))
          << "[bug] 3. generated splits (u) are greater than nc!";
168
169
170
171
      edgenum_unassigned_ptr[u]--;
      edgenum_unassigned_ptr[v]--;
    } else {
      std::vector<int> setv(nc), intersetv;
172
      for (int32_t j = 0; j < nc; j++) setv[j] = 0;
173
174
      int32_t interset = 0;

175
176
177
178
179
180
181
      CHECK(node_assignments[u].size() <= static_cast<size_t>(nc))
          << "[bug] 4. generated splits (u) are greater than nc!";
      CHECK(node_assignments[v].size() <= static_cast<size_t>(nc))
          << "[bug] 4. generated splits (v) are greater than nc!";
      for (size_t j = 0; j < node_assignments[v].size(); j++) {
        CHECK(node_assignments[v][j] < nc)
            << "[bug] 4. Part assigned (v) greater than nc!";
182
183
184
        setv[node_assignments[v][j]]++;
      }

185
186
187
      for (size_t j = 0; j < node_assignments[u].size(); j++) {
        CHECK(node_assignments[u][j] < nc)
            << "[bug] 4. Part assigned (u) greater than nc!";
188
189
190
        setv[node_assignments[u][j]]++;
      }

191
      for (int32_t j = 0; j < nc; j++) {
192
193
194
195
196
197
198
        CHECK(setv[j] <= 2) << "[bug] 4. unexpected computed value !!!";
        if (setv[j] == 2) {
          interset++;
          intersetv.push_back(j);
        }
      }
      if (interset) {
199
        for (size_t j = 0; j < intersetv.size(); j++) {
200
201
202
203
204
205
206
207
208
209
210
211
212
          int32_t cind = intersetv[j];
          cache[j] = community_edges[cind];
        }
        int32_t cindex = LeastLoad(cache, intersetv.size());
        int32_t c = intersetv[cindex];
        CHECK(c < nc) << "[bug] 4. partition greater than nc !!";
        out_ptr[i] = c;
        community_edges[c]++;
        community_weights_ptr[c] = community_weights_ptr[c] + w;
        edgenum_unassigned_ptr[u]--;
        edgenum_unassigned_ptr[v]--;
      } else {
        if (node_degree_ptr[u] < node_degree_ptr[v]) {
213
          for (uint32_t j = 0; j < node_assignments[u].size(); j++) {
214
215
216
217
218
219
220
221
222
223
            int32_t cind = node_assignments[u][j];
            cache[j] = community_edges[cind];
          }
          int32_t cindex = LeastLoad(cache, node_assignments[u].size());
          int32_t c = node_assignments[u][cindex];
          CHECK(c < nc) << "[bug] 5. partition greater than nc !!";
          out_ptr[i] = c;
          community_edges[c]++;
          community_weights_ptr[c] = community_weights_ptr[c] + w;

224
225
226
          for (uint32_t j = 0; j < node_assignments[v].size(); j++) {
            CHECK(node_assignments[v][j] != c)
                << "[bug] 5. duplicate partition (v) assignment !!";
227
228
229
          }

          node_assignments[v].push_back(c);
230
231
          CHECK(node_assignments[v].size() <= static_cast<size_t>(nc))
              << "[bug] 5. generated splits (v) greater than nc!!";
232
233
234
235
          replication_list.push_back(v);
          edgenum_unassigned_ptr[u]--;
          edgenum_unassigned_ptr[v]--;
        } else {
236
          for (uint32_t j = 0; j < node_assignments[v].size(); j++) {
237
238
239
240
241
            int32_t cind = node_assignments[v][j];
            cache[j] = community_edges[cind];
          }
          int32_t cindex = LeastLoad(cache, node_assignments[v].size());
          int32_t c = node_assignments[v][cindex];
242
          CHECK(c < nc) << "[bug] 6. partition greater than nc !!";
243
244
245
          out_ptr[i] = c;
          community_edges[c]++;
          community_weights_ptr[c] = community_weights_ptr[c] + w;
246
247
248
          for (uint32_t j = 0; j < node_assignments[u].size(); j++) {
            CHECK(node_assignments[u][j] != c)
                << "[bug] 6. duplicate partition (u) assignment !!";
249
          }
250
          if (u != v) node_assignments[u].push_back(c);
251

252
253
          CHECK(node_assignments[u].size() <= static_cast<size_t>(nc))
              << "[bug] 6. generated splits (u) greater than nc!!";
254
255
256
257
258
259
260
261
262
          replication_list.push_back(u);
          edgenum_unassigned_ptr[u]--;
          edgenum_unassigned_ptr[v]--;
        }
      }
    }
  }
  delete cache;

263
264
  for (int64_t c = 0; c < nc; c++) {
    std::string path = prefix + "/community" + std::to_string(c) + ".txt";
265
266

    FILE *fp = fopen(path.c_str(), "w");
267
268
    CHECK_NE(fp, static_cast<FILE *>(NULL))
        << "Error: can not open file: " << path.c_str();
269

270
    for (int64_t i = 0; i < N_e; i++) {
271
      if (out_ptr[i] == c)
272
273
        fprintf(
            fp, "%ld,%ld,%ld\n", static_cast<int64_t>(u_ptr[i]),
274
            static_cast<int64_t>(v_ptr[i]), w_ptr[i]);
275
276
277
278
279
280
    }
    fclose(fp);
  }

  std::string path = prefix + "/replicationlist.csv";
  FILE *fp = fopen(path.c_str(), "w");
281
282
  CHECK_NE(fp, static_cast<FILE *>(NULL))
      << "Error: can not open file: " << path.c_str();
283
284
285
286

  fprintf(fp, "## The Indices of Nodes that are replicated :: Header");
  printf("\nTotal replication: %ld\n", replication_list.size());

287
  for (uint64_t i = 0; i < replication_list.size(); i++)
288
    fprintf(fp, "%ld\n", static_cast<int64_t>(replication_list[i]));
289
290

  printf("Community weights:\n");
291
  for (int64_t c = 0; c < nc; c++) printf("%ld ", community_weights_ptr[c]);
292
293
294
  printf("\n");

  printf("Community edges:\n");
295
  for (int64_t c = 0; c < nc; c++) printf("%ld ", community_edges[c]);
296
297
298
299
300
301
302
  printf("\n");

  delete community_edges;
  fclose(fp);
}

DGL_REGISTER_GLOBAL("sparse._CAPI_DGLLibraVertexCut")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      // Positional arguments, in the order supplied by the Python binding.
      const int32_t num_parts = args[0];
      NDArray degrees = args[1];
      NDArray unassigned_edges = args[2];
      NDArray part_weights = args[3];
      NDArray src = args[4];
      NDArray dst = args[5];
      NDArray weights = args[6];
      NDArray assignment = args[7];
      const int64_t num_nodes = args[8];
      const int64_t num_edges = args[9];
      const std::string out_dir = args[10];

      // Dispatch on the degree dtype first, then on the edge-endpoint dtype.
      ATEN_ID_TYPE_SWITCH(degrees->dtype, IdType2, {
        ATEN_ID_TYPE_SWITCH(src->dtype, IdType, {
          LibraVertexCut<IdType, IdType2>(
              num_parts, degrees, unassigned_edges, part_weights, src, dst,
              weights, assignment, num_nodes, num_edges, out_dir);
        });
      });
    });
324

325
/**
 * @brief Builds the node-ID dictionaries for one Libra partition.
 *
 * 1. Builds dictionary (ldt) for assigning local node IDs to nodes in the
 *    partitions.
 * 2. Builds dictionary (gdt) for storing copies (local ID) of split nodes.
 *    These dictionaries are used in the subsequent stages to set up tracking
 *    of split-node copies across the partitions and to set up the partition
 *    `ndata` dictionaries.
 *
 * Reads `<prefix>/community<c>.txt` (one "src,dst,weight" line per edge).
 * An empty file is a valid, edge-less partition.
 *
 * @param[out] a local src node ID of an edge in a partition
 * @param[out] b local dst node ID of an edge in a partition
 * @param[-] indices temporary memory, keeps track of global node ID to local
 *           node ID in a partition
 * @param[out] ldt_key per partition dict for storing global and local node IDs
 *             (consecutive)
 * @param[in,out] gdt_key global dict for storing number of local nodes (or
 *                split nodes) for a given global node ID
 * @param[in,out] gdt_value global dict (nc entries per global node ID), stores
 *                local node IDs (due to split) across partitions
 * @param[out] node_map keeps track of range of local node IDs (consecutive)
 *             given to the nodes in the partitions
 * @param[in, out] offset start of the range of local node IDs for this
 *                 partition; advanced by the number of nodes found
 * @param[in] nc number of partitions/communities
 * @param[in] c current partition number
 * @param[in] fsize capacity (number of edges) of the pre-allocated a/b tensors
 * @param[in] prefix input Libra partition file location
 * @return list of two values: number of nodes and number of edges in
 *         partition c
 */
List<Value> Libra2dglBuildDict(
    NDArray a, NDArray b, NDArray indices, NDArray ldt_key, NDArray gdt_key,
    NDArray gdt_value, NDArray node_map, NDArray offset, int32_t nc, int32_t c,
    int64_t fsize, const std::string &prefix) {
  int64_t *indices_ptr = indices.Ptr<int64_t>();      // 1D temp array
  int64_t *ldt_key_ptr = ldt_key.Ptr<int64_t>();      // 1D local <-> global
  int64_t *gdt_key_ptr = gdt_key.Ptr<int64_t>();      // 1D #copies per node
  int64_t *gdt_value_ptr = gdt_value.Ptr<int64_t>();  // 2D tensor
  int64_t *node_map_ptr = node_map.Ptr<int64_t>();    // 1D tensor
  int64_t *offset_ptr = offset.Ptr<int64_t>();        // 1D tensor
  int64_t *a_ptr = a.Ptr<int64_t>();  // local src node ID per edge
  int64_t *b_ptr = b.Ptr<int64_t>();  // local dst node ID per edge

  const int32_t width = nc;  // row width of gdt_value
  const int64_t N_n = indices->shape[0];
  const int64_t num_nodes = ldt_key->shape[0];
  const int64_t kUnassigned = -100;  // marker: no local ID yet

  for (int64_t i = 0; i < N_n; i++) indices_ptr[i] = kUnassigned;

  int64_t pos = 0;   // number of local node IDs handed out
  int64_t edge = 0;  // number of edges read
  const std::string path =
      prefix + "/community" + std::to_string(c) + ".txt";
  FILE *fp = fopen(path.c_str(), "r");
  CHECK_NE(fp, static_cast<FILE *>(NULL))
      << "Error: can not open file: " << path.c_str();
  // Close the file on every exit path, including a failing CHECK.
  struct FileCloser {
    FILE *f;
    ~FileCloser() { fclose(f); }
  } closer{fp};

  // Returns the local ID of global node gid, assigning the next consecutive
  // local ID the first time gid is seen.
  auto local_id = [&](int64_t gid) -> int64_t {
    CHECK(gid >= 0 && gid < N_n) << "Error: node ID out of range in " << path;
    if (indices_ptr[gid] == kUnassigned) {
      CHECK(pos < num_nodes);  // check before writing ldt_key_ptr[pos]
      ldt_key_ptr[pos] = gid;
      indices_ptr[gid] = pos++;
    }
    return indices_ptr[gid];
  };

  for (;;) {
    long long src = 0, dst = 0;
    float weight = 0;  // parsed only to consume the field
    const int nread = fscanf(fp, "%lld,%lld,%f\n", &src, &dst, &weight);
    if (nread == EOF) break;  // end of input (also covers an empty file)
    CHECK_EQ(nread, 3) << "Error: malformed edge line in " << path;
    CHECK(edge < fsize)
        << "[Bug] memory allocated for #edges per partition is not enough.";
    const int64_t lsrc = local_id(src);
    const int64_t ldst = local_id(dst);
    a_ptr[edge] = lsrc;
    b_ptr[edge] = ldst;
    edge++;
  }

  // Record this partition's copy of every node in the global dictionary.
  for (int64_t lid = 0; lid < pos; lid++) {
    const int64_t gid = ldt_key_ptr[lid];
    int64_t &ncopies = gdt_key_ptr[gid];
    CHECK(ncopies < nc) << "[Bug] more than nc copies of node " << gid;
    gdt_value_ptr[gid * width + ncopies] = offset_ptr[0] + lid;
    ncopies++;
  }

  // Local node IDs of a partition are consecutive, so the range end suffices.
  node_map_ptr[c] = offset_ptr[0] + pos;
  offset_ptr[0] += pos;

  List<Value> ret;
  ret.push_back(Value(MakeValue(pos)));   // number of nodes in partition c
  ret.push_back(Value(MakeValue(edge)));  // number of edges in partition c
  return ret;
}

DGL_REGISTER_GLOBAL("sparse._CAPI_DGLLibra2dglBuildDict")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      // Positional arguments, in the order supplied by the Python binding.
      NDArray local_src = args[0];
      NDArray local_dst = args[1];
      NDArray scratch = args[2];
      NDArray local_dict = args[3];
      NDArray copy_counts = args[4];
      NDArray copy_ids = args[5];
      NDArray part_ranges = args[6];
      NDArray id_offset = args[7];
      const int32_t num_parts = args[8];
      const int32_t part_id = args[9];
      const int64_t capacity = args[10];
      const std::string in_dir = args[11];

      List<Value> counts = Libra2dglBuildDict(
          local_src, local_dst, scratch, local_dict, copy_counts, copy_ids,
          part_ranges, id_offset, num_parts, part_id, capacity, in_dir);
      *rv = counts;
    });

458
/**
 * @brief sets up the 1-level tree among the clones of the split-nodes.
 *
 * For every global node with at least one recorded copy, one of its copies
 * is chosen at random (thread-local DGL random engine) as the root.
 * Nodes with no recorded copy leave lrtensor untouched.
 *
 * @param[in] gdt_key number of copies per global node ID
 * @param[in] gdt_value per global node ID, the IDs of its copies across all
 *            partitions (nc entries per row)
 * @param[out] lrtensor keeps the root node ID of 1-level tree
 * @param[in] nc number of partitions/communities
 * @param[in] Nn number of nodes in the input graph
 */
void Libra2dglSetLR(
    NDArray gdt_key, NDArray gdt_value, NDArray lrtensor, int32_t nc,
    int64_t Nn) {
  int64_t *gdt_key_ptr = gdt_key.Ptr<int64_t>();      // 1D tensor
  int64_t *gdt_value_ptr = gdt_value.Ptr<int64_t>();  // 2D tensor
  int64_t *lrtensor_ptr = lrtensor.Ptr<int64_t>();    // 1D tensor

  const int32_t width = nc;

  for (int64_t i = 0; i < Nn; i++) {
    const int64_t ncopies = gdt_key_ptr[i];
    if (ncopies <= 0) continue;  // no recorded copy: nothing to choose

    const int32_t val = RandomEngine::ThreadLocal()->RandInt(ncopies);
    CHECK(val >= 0 && val < ncopies);
    CHECK(ncopies <= nc);

    lrtensor_ptr[i] = gdt_value_ptr[i * width + val];
  }
}

DGL_REGISTER_GLOBAL("sparse._CAPI_DGLLibra2dglSetLR")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      // Positional arguments, in the order supplied by the Python binding.
      NDArray copy_counts = args[0];
      NDArray copy_ids = args[1];
      NDArray roots = args[2];
      const int32_t num_parts = args[3];
      const int64_t num_nodes = args[4];

      Libra2dglSetLR(copy_counts, copy_ids, roots, num_parts, num_nodes);
    });
507

508
/**
509
 * @brief For each node in a partition, it creates a list of remote clone IDs;
510
511
 *        also, for each node in a partition, it gathers the data (feats, label,
 *        trian, test) from input graph.
512
513
514
515
516
517
 * @param[out] feat node features in current partition c.
 * @param[in] gfeat input graph node features.
 * @param[out] adj list of node IDs of remote clones.
 * @param[out] inner_nodes marks whether a node is split or not.
 * @param[in] ldt_key per partition dict for tracking global to local node IDs
 * @param[out] gdt_key global dict for storing number of local nodes (or split
518
519
 *             nodes) for a given global node ID
 * @param[out] gdt_value global
520
521
 *             dict, stores local node IDs (due to split) across partitions for
 *             a given global node ID.
522
 * @param[in] node_map keeps track of range of local node IDs (consecutive)
523
 *            given to the nodes in the partitions.
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
 * @param[out] lr 1-level tree marking for local split nodes.
 * @param[in] lrtensor global (all the partitions) 1-level tree.
 * @param[in] num_nodes number of nodes in current partition.
 * @param[in] nc number of partitions/communities.
 * @param[in] c current partition/community.
 * @param[in] feat_size node feature vector size.
 * @param[out] labels local (for this partition) labels.
 * @param[out] trainm local (for this partition) training nodes.
 * @param[out] testm local (for this partition) testing nodes.
 * @param[out] valm local (for this partition) validation nodes.
 * @param[in] glabels global (input graph) labels.
 * @param[in] gtrainm glabal (input graph) training nodes.
 * @param[in] gtestm glabal (input graph) testing nodes.
 * @param[in] gvalm glabal (input graph) validation nodes.
 * @param[out] Nn number of nodes in the input graph.
539
 */
540
template <typename IdType, typename IdType2, typename DType>
541
void Libra2dglBuildAdjlist(
542
543
544
545
546
547
548
549
550
    NDArray feat, NDArray gfeat, NDArray adj, NDArray inner_node,
    NDArray ldt_key, NDArray gdt_key, NDArray gdt_value, NDArray node_map,
    NDArray lr, NDArray lrtensor, int64_t num_nodes, int32_t nc, int32_t c,
    int32_t feat_size, NDArray labels, NDArray trainm, NDArray testm,
    NDArray valm, NDArray glabels, NDArray gtrainm, NDArray gtestm,
    NDArray gvalm, int64_t Nn) {
  DType *feat_ptr = feat.Ptr<DType>();    // 2D tensor
  DType *gfeat_ptr = gfeat.Ptr<DType>();  // 2D tensor
  int64_t *adj_ptr = adj.Ptr<int64_t>();  // 2D tensor
551
  int32_t *inner_node_ptr = inner_node.Ptr<int32_t>();
552
553
554
555
556
557
  int64_t *ldt_key_ptr = ldt_key.Ptr<int64_t>();
  int64_t *gdt_key_ptr = gdt_key.Ptr<int64_t>();
  int64_t *gdt_value_ptr = gdt_value.Ptr<int64_t>();  // 2D tensor
  int64_t *node_map_ptr = node_map.Ptr<int64_t>();
  int64_t *lr_ptr = lr.Ptr<int64_t>();
  int64_t *lrtensor_ptr = lrtensor.Ptr<int64_t>();
558
559
  int32_t width = nc - 1;

560
561
  runtime::parallel_for(0, num_nodes, [&](int64_t s, int64_t e) {
    for (int64_t i = s; i < e; i++) {
562
563
564
565
      int64_t k = ldt_key_ptr[i];
      int64_t v = i;
      int64_t ind = gdt_key_ptr[k];

566
      int64_t *adj_ptr_ptr = adj_ptr + v * width;
567
      if (ind == 1) {
568
        for (int32_t j = 0; j < width; j++) adj_ptr_ptr[j] = -1;
569
570
571
572
        inner_node_ptr[i] = 1;
        lr_ptr[i] = -200;
      } else {
        lr_ptr[i] = lrtensor_ptr[k];
573
        int64_t *ptr = gdt_value_ptr + k * nc;
574
575
576
        int64_t pos = 0;
        CHECK(ind <= nc);
        int32_t flg = 0;
577
        for (int64_t j = 0; j < ind; j++) {
578
          if (ptr[j] == lr_ptr[i]) flg = 1;
579
          if (c != Ver2partition<int64_t>(ptr[j], node_map_ptr, nc))
580
581
582
583
584
585
586
587
588
589
590
            adj_ptr_ptr[pos++] = ptr[j];
        }
        CHECK_EQ(flg, 1);
        CHECK(pos == ind - 1);
        for (; pos < width; pos++) adj_ptr_ptr[pos] = -1;
        inner_node_ptr[i] = 0;
      }
    }
  });

  // gather
591
592
  runtime::parallel_for(0, num_nodes, [&](int64_t s, int64_t e) {
    for (int64_t i = s; i < e; i++) {
593
      int64_t k = ldt_key_ptr[i];
594
      int64_t ind = i * feat_size;
595
      DType *optr = gfeat_ptr + ind;
596
      DType *iptr = feat_ptr + k * feat_size;
597

598
      for (int32_t j = 0; j < feat_size; j++) optr[j] = iptr[j];
599
600
601
602
603
604
605
606
607
608
609
    }

    IdType *labels_ptr = labels.Ptr<IdType>();
    IdType *glabels_ptr = glabels.Ptr<IdType>();
    IdType2 *trainm_ptr = trainm.Ptr<IdType2>();
    IdType2 *gtrainm_ptr = gtrainm.Ptr<IdType2>();
    IdType2 *testm_ptr = testm.Ptr<IdType2>();
    IdType2 *gtestm_ptr = gtestm.Ptr<IdType2>();
    IdType2 *valm_ptr = valm.Ptr<IdType2>();
    IdType2 *gvalm_ptr = gvalm.Ptr<IdType2>();

610
    for (int64_t i = 0; i < num_nodes; i++) {
611
      int64_t k = ldt_key_ptr[i];
612
      CHECK(k >= 0 && k < Nn);
613
614
615
616
617
618
619
620
621
      glabels_ptr[i] = labels_ptr[k];
      gtrainm_ptr[i] = trainm_ptr[k];
      gtestm_ptr[i] = testm_ptr[k];
      gvalm_ptr[i] = valm_ptr[k];
    }
  });
}

DGL_REGISTER_GLOBAL("sparse._CAPI_DGLLibra2dglBuildAdjlist")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      // Positional arguments, in the order supplied by the Python binding.
      NDArray features = args[0];
      NDArray part_features = args[1];
      NDArray clone_lists = args[2];
      NDArray inner_flags = args[3];
      NDArray local_dict = args[4];
      NDArray copy_counts = args[5];
      NDArray copy_ids = args[6];
      NDArray part_ranges = args[7];
      NDArray local_roots = args[8];
      NDArray global_roots = args[9];
      const int64_t n_local = args[10];
      const int32_t n_parts = args[11];
      const int32_t part_id = args[12];
      const int32_t feat_dim = args[13];
      NDArray labels_a = args[14];
      NDArray train_a = args[15];
      NDArray test_a = args[16];
      NDArray val_a = args[17];
      NDArray labels_b = args[18];
      NDArray train_b = args[19];
      NDArray test_b = args[20];
      NDArray val_b = args[21];
      const int64_t n_global = args[22];

      // Dispatch on feature dtype, mask dtype, then label bit width.
      ATEN_FLOAT_TYPE_SWITCH(features->dtype, DType, "Features", {
        ATEN_ID_TYPE_SWITCH(train_a->dtype, IdType2, {
          ATEN_ID_BITS_SWITCH((labels_b->dtype).bits, IdType, {
            Libra2dglBuildAdjlist<IdType, IdType2, DType>(
                features, part_features, clone_lists, inner_flags, local_dict,
                copy_counts, copy_ids, part_ranges, local_roots, global_roots,
                n_local, n_parts, part_id, feat_dim, labels_a, train_a, test_a,
                val_a, labels_b, train_b, test_b, val_b, n_global);
          });
        });
      });
    });

}  // namespace aten
}  // namespace dgl