"cmake/vscode:/vscode.git/clone" did not exist on "69a532c1aba4bac714cce6746b610ba2b20b835c"
union_partition.cc 17.9 KB
Newer Older
/**
 *  Copyright (c) 2020 by Contributors
 * @file graph/transform/union_partition.cc
 * @brief Functions for partitioning graphs and taking unions of multiple
 *        graphs.
 */
#include "../heterograph.h"
using namespace dgl::runtime;

namespace dgl {

11
/**
 * @brief Combine several graphs into one by taking, for every edge type of
 *        the metagraph, the union of the components' adjacency matrices.
 *
 * For each edge type every component must report the same number of source
 * and destination vertices and the same allowed sparse formats as
 * component_graphs[0]; a mismatch fails through CHECK_EQ / LOG(FATAL).
 * The union is computed on the first format present in the shared format
 * code, tried in the order COO, CSR, CSC.
 *
 * @param meta_graph Metagraph whose edges enumerate the edge types.
 * @param component_graphs Graphs to combine. Must not be empty.
 * @return Heterograph created from the per-edge-type unions and the node
 *         counts read from component_graphs[0].
 */
HeteroGraphPtr JointUnionHeteroGraph(
    GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& component_graphs) {
  // The condition only rejects an empty list (a single graph is accepted),
  // so the message says exactly that.
  CHECK_GT(component_graphs.size(), 0) << "Input graph list is empty";
  std::vector<HeteroGraphPtr> rel_graphs(meta_graph->NumEdges());
  std::vector<int64_t> num_nodes_per_type(meta_graph->NumVertices(), 0);

  // Loop over all canonical etypes
  for (dgl_type_t etype = 0; etype < meta_graph->NumEdges(); ++etype) {
    auto pair = meta_graph->FindEdge(etype);
    const dgl_type_t src_vtype = pair.first;
    const dgl_type_t dst_vtype = pair.second;
    // The first component is the reference every other component is
    // compared against.
    const uint64_t num_src_v = component_graphs[0]->NumVertices(src_vtype);
    const uint64_t num_dst_v = component_graphs[0]->NumVertices(dst_vtype);
    HeteroGraphPtr rgptr = nullptr;

    // ALL = CSC | CSR | COO
    const dgl_format_code_t code =
        component_graphs[0]->GetRelationGraph(etype)->GetAllowedFormats();

    // Validate that every component agrees with the reference on vertex
    // counts and allowed formats.
    for (size_t i = 0; i < component_graphs.size(); ++i) {
      const auto& cg = component_graphs[i];
      CHECK_EQ(num_src_v, cg->NumVertices(src_vtype))
          << "Input graph[" << i
          << "] should have same number of src vertices as input graph[0]";
      CHECK_EQ(num_dst_v, cg->NumVertices(dst_vtype))
          << "Input graph[" << i
          << "] should have same number of dst vertices as input graph[0]";

      const dgl_format_code_t curr_code =
          cg->GetRelationGraph(etype)->GetAllowedFormats();
      if (curr_code != code)
        LOG(FATAL) << "All components should have the same formats";
    }

    // prefer COO
    if (FORMAT_HAS_COO(code)) {
      std::vector<aten::COOMatrix> coos;
      coos.reserve(component_graphs.size());
      for (const auto& cg : component_graphs) {
        coos.push_back(cg->GetCOOMatrix(etype));
      }

      aten::COOMatrix res = aten::UnionCoo(coos);
      rgptr =
          UnitGraph::CreateFromCOO((src_vtype == dst_vtype) ? 1 : 2, res, code);
    } else if (FORMAT_HAS_CSR(code)) {
      std::vector<aten::CSRMatrix> csrs;
      csrs.reserve(component_graphs.size());
      for (const auto& cg : component_graphs) {
        csrs.push_back(cg->GetCSRMatrix(etype));
      }

      aten::CSRMatrix res = aten::UnionCsr(csrs);
      rgptr =
          UnitGraph::CreateFromCSR((src_vtype == dst_vtype) ? 1 : 2, res, code);
    } else if (FORMAT_HAS_CSC(code)) {
      // CSR and CSC have the same storage format, i.e. CSRMatrix
      std::vector<aten::CSRMatrix> cscs;
      cscs.reserve(component_graphs.size());
      for (const auto& cg : component_graphs) {
        cscs.push_back(cg->GetCSCMatrix(etype));
      }

      aten::CSRMatrix res = aten::UnionCsr(cscs);
      rgptr =
          UnitGraph::CreateFromCSC((src_vtype == dst_vtype) ? 1 : 2, res, code);
    }

    rel_graphs[etype] = rgptr;
    num_nodes_per_type[src_vtype] = num_src_v;
    num_nodes_per_type[dst_vtype] = num_dst_v;
  }

  return CreateHeteroGraph(
      meta_graph, rel_graphs, std::move(num_nodes_per_type));
}

93
94
95
96
97
98
/**
 * @brief Batch several graphs into one heterograph by taking the disjoint
 *        union of their adjacency matrices for every edge type.
 *
 * The node count of each type in the result is the sum of that type's node
 * counts over all components. All components must report the same allowed
 * sparse formats per edge type as component_graphs[0]; the union is computed
 * on the first format present in that code, tried in the order COO, CSR, CSC.
 *
 * @param meta_graph Metagraph whose edges enumerate the edge types.
 * @param component_graphs Graphs to batch. Must not be empty.
 * @return Heterograph created from the per-edge-type disjoint unions.
 */
HeteroGraphPtr DisjointUnionHeteroGraph2(
    GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& component_graphs) {
  CHECK_GT(component_graphs.size(), 0) << "Input graph list is empty";
  const size_t num_components = component_graphs.size();
  std::vector<HeteroGraphPtr> rel_graphs(meta_graph->NumEdges());
  std::vector<int64_t> num_nodes_per_type(meta_graph->NumVertices(), 0);

  // Total node count per node type across all components.
  for (dgl_type_t vtype = 0; vtype < meta_graph->NumVertices(); ++vtype) {
    uint64_t total = 0;
    for (size_t i = 0; i < num_components; ++i) {
      total += component_graphs[i]->NumVertices(vtype);
    }
    num_nodes_per_type[vtype] = total;
  }

  // One relation graph per canonical edge type.
  for (dgl_type_t etype = 0; etype < meta_graph->NumEdges(); ++etype) {
    const auto endpoints = meta_graph->FindEdge(etype);
    const dgl_type_t src_vtype = endpoints.first;
    const dgl_type_t dst_vtype = endpoints.second;
    // A relation between a node type and itself uses one vertex type,
    // otherwise two.
    const int num_vtypes = (src_vtype == dst_vtype) ? 1 : 2;

    const dgl_format_code_t code =
        component_graphs[0]->GetRelationGraph(etype)->GetAllowedFormats();

    // Every component has to agree with the first one on allowed formats.
    for (size_t i = 0; i < num_components; ++i) {
      const dgl_format_code_t cur_code =
          component_graphs[i]->GetRelationGraph(etype)->GetAllowedFormats();
      if (cur_code != code) {
        LOG(FATAL) << "All components should have the same formats";
      }
    }

    HeteroGraphPtr relation = nullptr;
    if (FORMAT_HAS_COO(code)) {
      // COO is preferred whenever it is allowed.
      std::vector<aten::COOMatrix> parts;
      parts.reserve(num_components);
      for (size_t i = 0; i < num_components; ++i) {
        parts.push_back(component_graphs[i]->GetCOOMatrix(etype));
      }
      aten::COOMatrix merged = aten::DisjointUnionCoo(parts);
      relation = UnitGraph::CreateFromCOO(num_vtypes, merged, code);
    } else if (FORMAT_HAS_CSR(code)) {
      std::vector<aten::CSRMatrix> parts;
      parts.reserve(num_components);
      for (size_t i = 0; i < num_components; ++i) {
        parts.push_back(component_graphs[i]->GetCSRMatrix(etype));
      }
      aten::CSRMatrix merged = aten::DisjointUnionCsr(parts);
      relation = UnitGraph::CreateFromCSR(num_vtypes, merged, code);
    } else if (FORMAT_HAS_CSC(code)) {
      // CSC matrices are stored in the CSRMatrix type as well.
      std::vector<aten::CSRMatrix> parts;
      parts.reserve(num_components);
      for (size_t i = 0; i < num_components; ++i) {
        parts.push_back(component_graphs[i]->GetCSCMatrix(etype));
      }
      aten::CSRMatrix merged = aten::DisjointUnionCsr(parts);
      relation = UnitGraph::CreateFromCSC(num_vtypes, merged, code);
    }
    rel_graphs[etype] = relation;
  }

  return CreateHeteroGraph(
      meta_graph, rel_graphs, std::move(num_nodes_per_type));
}

std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes2(
166
167
    GraphPtr meta_graph, HeteroGraphPtr batched_graph, IdArray vertex_sizes,
    IdArray edge_sizes) {
168
  // Sanity check for vertex sizes
169
170
  CHECK_EQ(vertex_sizes->dtype.bits, 64)
      << "dtype of vertex_sizes should be int64";
171
172
  CHECK_EQ(edge_sizes->dtype.bits, 64) << "dtype of edge_sizes should be int64";
  const uint64_t len_vertex_sizes = vertex_sizes->shape[0];
173
174
  const uint64_t* vertex_sizes_data =
      static_cast<uint64_t*>(vertex_sizes->data);
175
176
177
178
179
180
181
182
183
184
185
186
  const uint64_t num_vertex_types = meta_graph->NumVertices();
  const uint64_t batch_size = len_vertex_sizes / num_vertex_types;

  // Map vertex type to the corresponding node cum sum
  std::vector<std::vector<uint64_t>> vertex_cumsum;
  vertex_cumsum.resize(num_vertex_types);
  // Loop over all vertex types
  for (uint64_t vtype = 0; vtype < num_vertex_types; ++vtype) {
    vertex_cumsum[vtype].push_back(0);
    for (uint64_t g = 0; g < batch_size; ++g) {
      // We've flattened the number of vertices in the batch for all types
      vertex_cumsum[vtype].push_back(
187
          vertex_cumsum[vtype][g] + vertex_sizes_data[vtype * batch_size + g]);
188
    }
189
190
191
192
    CHECK_EQ(
        vertex_cumsum[vtype][batch_size], batched_graph->NumVertices(vtype))
        << "Sum of the given sizes must equal to the number of nodes for type "
        << vtype;
193
194
195
196
197
198
199
200
201
202
203
204
205
206
  }

  // Sanity check for edge sizes
  const uint64_t* edge_sizes_data = static_cast<uint64_t*>(edge_sizes->data);
  const uint64_t num_edge_types = meta_graph->NumEdges();
  // Map edge type to the corresponding edge cum sum
  std::vector<std::vector<uint64_t>> edge_cumsum;
  edge_cumsum.resize(num_edge_types);
  // Loop over all edge types
  for (uint64_t etype = 0; etype < num_edge_types; ++etype) {
    edge_cumsum[etype].push_back(0);
    for (uint64_t g = 0; g < batch_size; ++g) {
      // We've flattened the number of edges in the batch for all types
      edge_cumsum[etype].push_back(
207
          edge_cumsum[etype][g] + edge_sizes_data[etype * batch_size + g]);
208
209
    }
    CHECK_EQ(edge_cumsum[etype][batch_size], batched_graph->NumEdges(etype))
210
211
        << "Sum of the given sizes must equal to the number of edges for type "
        << etype;
212
213
214
215
216
217
  }

  // Construct relation graphs for unbatched graphs
  std::vector<std::vector<HeteroGraphPtr>> rel_graphs;
  rel_graphs.resize(batch_size);
  // Loop over all edge types
218
  auto code = batched_graph->GetRelationGraph(0)->GetAllowedFormats();
219

220
  if (FORMAT_HAS_COO(code)) {
221
222
223
224
225
    for (uint64_t etype = 0; etype < num_edge_types; ++etype) {
      auto pair = meta_graph->FindEdge(etype);
      const dgl_type_t src_vtype = pair.first;
      const dgl_type_t dst_vtype = pair.second;
      aten::COOMatrix coo = batched_graph->GetCOOMatrix(etype);
226
227
228
      auto res = aten::DisjointPartitionCooBySizes(
          coo, batch_size, edge_cumsum[etype], vertex_cumsum[src_vtype],
          vertex_cumsum[dst_vtype]);
229
230
      for (uint64_t g = 0; g < batch_size; ++g) {
        HeteroGraphPtr rgptr = UnitGraph::CreateFromCOO(
231
            (src_vtype == dst_vtype) ? 1 : 2, res[g], code);
232
233
234
        rel_graphs[g].push_back(rgptr);
      }
    }
235
  } else if (FORMAT_HAS_CSR(code)) {
236
237
238
239
240
    for (uint64_t etype = 0; etype < num_edge_types; ++etype) {
      auto pair = meta_graph->FindEdge(etype);
      const dgl_type_t src_vtype = pair.first;
      const dgl_type_t dst_vtype = pair.second;
      aten::CSRMatrix csr = batched_graph->GetCSRMatrix(etype);
241
242
243
      auto res = aten::DisjointPartitionCsrBySizes(
          csr, batch_size, edge_cumsum[etype], vertex_cumsum[src_vtype],
          vertex_cumsum[dst_vtype]);
244
245
      for (uint64_t g = 0; g < batch_size; ++g) {
        HeteroGraphPtr rgptr = UnitGraph::CreateFromCSR(
246
            (src_vtype == dst_vtype) ? 1 : 2, res[g], code);
247
248
249
        rel_graphs[g].push_back(rgptr);
      }
    }
250
  } else if (FORMAT_HAS_CSC(code)) {
251
252
253
254
    for (uint64_t etype = 0; etype < num_edge_types; ++etype) {
      auto pair = meta_graph->FindEdge(etype);
      const dgl_type_t src_vtype = pair.first;
      const dgl_type_t dst_vtype = pair.second;
255
      // CSR and CSC have the same storage format, i.e. CSRMatrix
256
      aten::CSRMatrix csc = batched_graph->GetCSCMatrix(etype);
257
258
259
      auto res = aten::DisjointPartitionCsrBySizes(
          csc, batch_size, edge_cumsum[etype], vertex_cumsum[dst_vtype],
          vertex_cumsum[src_vtype]);
260
261
      for (uint64_t g = 0; g < batch_size; ++g) {
        HeteroGraphPtr rgptr = UnitGraph::CreateFromCSC(
262
            (src_vtype == dst_vtype) ? 1 : 2, res[g], code);
263
264
265
266
267
268
269
270
271
272
        rel_graphs[g].push_back(rgptr);
      }
    }
  }

  std::vector<HeteroGraphPtr> rst;
  std::vector<int64_t> num_nodes_per_type(num_vertex_types);
  for (uint64_t g = 0; g < batch_size; ++g) {
    for (uint64_t i = 0; i < num_vertex_types; ++i)
      num_nodes_per_type[i] = vertex_sizes_data[i * batch_size + g];
273
274
    rst.push_back(
        CreateHeteroGraph(meta_graph, rel_graphs[g], num_nodes_per_type));
275
276
277
278
  }
  return rst;
}

279
/**
 * @brief Extract one contiguous chunk of nodes and edges per type from a
 *        batched heterograph.
 *
 * For node type t the chunk is [start_nid_per_type[t],
 * start_nid_per_type[t] + num_nodes_per_type[t]); edge types are handled
 * the same way with start_eid_per_type / num_edges_per_type. Per edge type
 * the slice is taken on the first format present in that relation's allowed
 * formats, tried in the order COO, CSR, CSC.
 *
 * @param meta_graph Metagraph of the batched graph and of the result.
 * @param batched_graph Graph to slice.
 * @param num_nodes_per_type Number of nodes to keep for each node type.
 * @param start_nid_per_type First node ID of the chunk for each node type.
 * @param num_edges_per_type Number of edges to keep for each edge type.
 * @param start_eid_per_type First edge ID of the chunk for each edge type.
 * @return Heterograph created from the sliced relation graphs.
 */
HeteroGraphPtr SliceHeteroGraph(
    GraphPtr meta_graph, HeteroGraphPtr batched_graph,
    IdArray num_nodes_per_type, IdArray start_nid_per_type,
    IdArray num_edges_per_type, IdArray start_eid_per_type) {
  std::vector<HeteroGraphPtr> rel_graphs(meta_graph->NumEdges());

  // All four arrays are read through uint64_t pointers.
  const uint64_t* const node_starts =
      static_cast<uint64_t*>(start_nid_per_type->data);
  const uint64_t* const node_counts =
      static_cast<uint64_t*>(num_nodes_per_type->data);
  const uint64_t* const edge_starts =
      static_cast<uint64_t*>(start_eid_per_type->data);
  const uint64_t* const edge_counts =
      static_cast<uint64_t*>(num_edges_per_type->data);

  // vertex_range[vtype] = {first node ID, one past the last node ID}.
  const uint64_t num_vertex_types = meta_graph->NumVertices();
  std::vector<std::vector<uint64_t>> vertex_range(num_vertex_types);
  for (uint64_t vtype = 0; vtype < num_vertex_types; ++vtype) {
    const uint64_t first = node_starts[vtype];
    vertex_range[vtype] = {first, first + node_counts[vtype]};
  }

  // One sliced relation graph per canonical edge type.
  for (dgl_type_t etype = 0; etype < meta_graph->NumEdges(); ++etype) {
    const auto endpoints = meta_graph->FindEdge(etype);
    const dgl_type_t src_vtype = endpoints.first;
    const dgl_type_t dst_vtype = endpoints.second;
    const int num_vtypes = (src_vtype == dst_vtype) ? 1 : 2;
    const dgl_format_code_t code =
        batched_graph->GetRelationGraph(etype)->GetAllowedFormats();

    // edge_range = {first edge ID, one past the last edge ID}.
    const uint64_t first_edge = edge_starts[etype];
    const std::vector<uint64_t> edge_range{
        first_edge, first_edge + edge_counts[etype]};

    HeteroGraphPtr relation = nullptr;
    if (FORMAT_HAS_COO(code)) {
      // COO is preferred whenever it is allowed.
      aten::COOMatrix whole = batched_graph->GetCOOMatrix(etype);
      aten::COOMatrix chunk = aten::COOSliceContiguousChunk(
          whole, edge_range, vertex_range[src_vtype], vertex_range[dst_vtype]);
      relation = UnitGraph::CreateFromCOO(num_vtypes, chunk, code);
    } else if (FORMAT_HAS_CSR(code)) {
      aten::CSRMatrix whole = batched_graph->GetCSRMatrix(etype);
      aten::CSRMatrix chunk = aten::CSRSliceContiguousChunk(
          whole, edge_range, vertex_range[src_vtype], vertex_range[dst_vtype]);
      relation = UnitGraph::CreateFromCSR(num_vtypes, chunk, code);
    } else if (FORMAT_HAS_CSC(code)) {
      // CSC is stored in the CSRMatrix type; note that the destination
      // range is passed before the source range here.
      aten::CSRMatrix whole = batched_graph->GetCSCMatrix(etype);
      aten::CSRMatrix chunk = aten::CSRSliceContiguousChunk(
          whole, edge_range, vertex_range[dst_vtype], vertex_range[src_vtype]);
      relation = UnitGraph::CreateFromCSC(num_vtypes, chunk, code);
    }

    rel_graphs[etype] = relation;
  }

  return CreateHeteroGraph(
      meta_graph, rel_graphs, num_nodes_per_type.ToVector<int64_t>());
}

349
template <class IdType>
350
std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes(
351
352
    GraphPtr meta_graph, HeteroGraphPtr batched_graph, IdArray vertex_sizes,
    IdArray edge_sizes) {
353
354
  // Sanity check for vertex sizes
  const uint64_t len_vertex_sizes = vertex_sizes->shape[0];
355
356
  const uint64_t* vertex_sizes_data =
      static_cast<uint64_t*>(vertex_sizes->data);
357
358
359
360
361
362
363
364
365
366
367
  const uint64_t num_vertex_types = meta_graph->NumVertices();
  const uint64_t batch_size = len_vertex_sizes / num_vertex_types;
  // Map vertex type to the corresponding node cum sum
  std::vector<std::vector<uint64_t>> vertex_cumsum;
  vertex_cumsum.resize(num_vertex_types);
  // Loop over all vertex types
  for (uint64_t vtype = 0; vtype < num_vertex_types; ++vtype) {
    vertex_cumsum[vtype].push_back(0);
    for (uint64_t g = 0; g < batch_size; ++g) {
      // We've flattened the number of vertices in the batch for all types
      vertex_cumsum[vtype].push_back(
368
          vertex_cumsum[vtype][g] + vertex_sizes_data[vtype * batch_size + g]);
369
    }
370
371
372
373
    CHECK_EQ(
        vertex_cumsum[vtype][batch_size], batched_graph->NumVertices(vtype))
        << "Sum of the given sizes must equal to the number of nodes for type "
        << vtype;
374
375
376
377
378
379
380
381
382
383
384
385
386
387
  }

  // Sanity check for edge sizes
  const uint64_t* edge_sizes_data = static_cast<uint64_t*>(edge_sizes->data);
  const uint64_t num_edge_types = meta_graph->NumEdges();
  // Map edge type to the corresponding edge cum sum
  std::vector<std::vector<uint64_t>> edge_cumsum;
  edge_cumsum.resize(num_edge_types);
  // Loop over all edge types
  for (uint64_t etype = 0; etype < num_edge_types; ++etype) {
    edge_cumsum[etype].push_back(0);
    for (uint64_t g = 0; g < batch_size; ++g) {
      // We've flattened the number of edges in the batch for all types
      edge_cumsum[etype].push_back(
388
          edge_cumsum[etype][g] + edge_sizes_data[etype * batch_size + g]);
389
390
    }
    CHECK_EQ(edge_cumsum[etype][batch_size], batched_graph->NumEdges(etype))
391
392
        << "Sum of the given sizes must equal to the number of edges for type "
        << etype;
393
394
395
396
397
398
399
400
401
402
403
  }

  // Construct relation graphs for unbatched graphs
  std::vector<std::vector<HeteroGraphPtr>> rel_graphs;
  rel_graphs.resize(batch_size);
  // Loop over all edge types
  for (uint64_t etype = 0; etype < num_edge_types; ++etype) {
    auto pair = meta_graph->FindEdge(etype);
    const dgl_type_t src_vtype = pair.first;
    const dgl_type_t dst_vtype = pair.second;
    EdgeArray edges = batched_graph->Edges(etype);
404
405
    const IdType* edges_src_data = static_cast<const IdType*>(edges.src->data);
    const IdType* edges_dst_data = static_cast<const IdType*>(edges.dst->data);
406
407
    // Loop over all graphs to be unbatched
    for (uint64_t g = 0; g < batch_size; ++g) {
408
      std::vector<IdType> result_src, result_dst;
409
      // Loop over the chunk of edges for the specified graph and edge type
410
411
      for (uint64_t e = edge_cumsum[etype][g]; e < edge_cumsum[etype][g + 1];
           ++e) {
412
413
414
415
416
        // TODO(mufei): Should use array operations to implement this.
        result_src.push_back(edges_src_data[e] - vertex_cumsum[src_vtype][g]);
        result_dst.push_back(edges_dst_data[e] - vertex_cumsum[dst_vtype][g]);
      }
      HeteroGraphPtr rgptr = UnitGraph::CreateFromCOO(
417
418
419
420
421
          (src_vtype == dst_vtype) ? 1 : 2,
          vertex_sizes_data[src_vtype * batch_size + g],
          vertex_sizes_data[dst_vtype * batch_size + g],
          aten::VecToIdArray(result_src, sizeof(IdType) * 8),
          aten::VecToIdArray(result_dst, sizeof(IdType) * 8));
422
423
424
425
426
      rel_graphs[g].push_back(rgptr);
    }
  }

  std::vector<HeteroGraphPtr> rst;
427
  std::vector<int64_t> num_nodes_per_type(num_vertex_types);
428
  for (uint64_t g = 0; g < batch_size; ++g) {
429
430
    for (uint64_t i = 0; i < num_vertex_types; ++i)
      num_nodes_per_type[i] = vertex_sizes_data[i * batch_size + g];
431
432
    rst.push_back(
        CreateHeteroGraph(meta_graph, rel_graphs[g], num_nodes_per_type));
433
434
435
436
  }
  return rst;
}

437
// Explicitly instantiate DisjointPartitionHeteroBySizes for int32_t and
// int64_t ids.
template std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes<int32_t>(
    GraphPtr, HeteroGraphPtr, IdArray, IdArray);

template std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes<int64_t>(
    GraphPtr, HeteroGraphPtr, IdArray, IdArray);
444

445
}  // namespace dgl