heterograph.cc 21.6 KB
Newer Older
1
2
3
4
5
6
/*!
 *  Copyright (c) 2019 by Contributors
 * \file graph/heterograph.cc
 * \brief Heterograph implementation
 */
#include "./heterograph.h"
Minjie Wang's avatar
Minjie Wang committed
7
#include <dgl/array.h>
8
#include <dgl/immutable_graph.h>
9
#include <dgl/graph_serializer.h>
10
11
#include <dmlc/memory_io.h>
#include <memory>
12
13
14
#include <vector>
#include <tuple>
#include <utility>
15
16
17
18
19
20

using namespace dgl::runtime;

namespace dgl {
namespace {

21
22
using dgl::ImmutableGraph;

23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
/*!
 * \brief Build an edge-induced subgraph without relabeling nodes.
 *
 * Each relation graph is sliced independently with EdgeSubgraph(..., true) and the
 * results are reassembled under the parent's meta graph and per-type node counts.
 *
 * \param hg The parent heterograph.
 * \param eids One edge ID array per edge type, indexed by edge type.
 * \return The subgraph; induced_edges is a copy of \a eids and induced_vertices holds
 *         what the per-relation EdgeSubgraph calls reported for each endpoint type.
 */
HeteroSubgraph EdgeSubgraphPreserveNodes(
    const HeteroGraph* hg, const std::vector<IdArray>& eids) {
  CHECK_EQ(eids.size(), hg->NumEdgeTypes())
    << "Invalid input: the input list size must be the same as the number of edge type.";

  const auto num_etypes = hg->NumEdgeTypes();
  HeteroSubgraph result;
  result.induced_edges = eids;
  result.induced_vertices.resize(hg->NumVertexTypes());

  // With preserved nodes every bipartite relation can be sliced on its own.
  std::vector<HeteroGraphPtr> rel_subgraphs(num_etypes);
  for (dgl_type_t etype = 0; etype < num_etypes; ++etype) {
    const auto endpoints = hg->meta_graph()->FindEdge(etype);
    const auto& rel_sg = hg->GetRelationGraph(etype)->EdgeSubgraph({eids[etype]}, true);
    rel_subgraphs[etype] = rel_sg.graph;
    // A vertex type touched by several relations keeps the entry written last.
    result.induced_vertices[endpoints.first] = rel_sg.induced_vertices[0];
    result.induced_vertices[endpoints.second] = rel_sg.induced_vertices[1];
  }

  result.graph = HeteroGraphPtr(new HeteroGraph(
      hg->meta_graph(), rel_subgraphs, hg->NumVerticesPerType()));
  return result;
}

/*!
 * \brief Build an edge-induced subgraph in which every vertex type is relabeled.
 *
 * A per-relation EdgeSubgraph cannot be used here: relations that are incident on the
 * same vertex type must agree on that type's ID space. With meta graph A -> B -> C and
 *   A -> B: (0, 0), (0, 1)
 *   B -> C: (1, 0), (1, 1)
 * keeping (0, 0) of A->B and (1, 0) of B->C must leave type B with two nodes, whereas
 * slicing B->C on its own would relabel B's node #1 to node #0.
 *
 * The procedure is therefore:
 *  (1) for each relation, look up the endpoints of the requested edges;
 *  (2) group those endpoint arrays by the vertex type they belong to;
 *  (3) for each vertex type, call aten::Relabel_ on its group. It computes the union of
 *      the vertex sets and relabels the unique elements from zero; the returned mapping
 *      array is the induced vertex set of that type;
 *  (4) build each relation's COO from the relabeled endpoint arrays.
 *
 * \param hg The parent heterograph.
 * \param eids One edge ID array per edge type, indexed by edge type.
 * \return The subgraph together with its induced vertices and edges.
 */
HeteroSubgraph EdgeSubgraphNoPreserveNodes(
    const HeteroGraph* hg, const std::vector<IdArray>& eids) {
  // TODO(minjie): In general, all relabeling should be separated with subgraph
  //   operations.
  CHECK_EQ(eids.size(), hg->NumEdgeTypes())
    << "Invalid input: the input list size must be the same as the number of edge type.";

  const auto num_etypes = hg->NumEdgeTypes();
  const auto num_vtypes = hg->NumVertexTypes();
  HeteroSubgraph result;
  result.induced_edges = eids;
  result.induced_vertices.resize(num_vtypes);

  // Steps (1) and (2): slice the edges and collect endpoints per vertex type.
  std::vector<EdgeArray> sliced_edges(num_etypes);
  std::vector<std::vector<IdArray>> nodes_by_vtype(num_vtypes);
  for (dgl_type_t etype = 0; etype < num_etypes; ++etype) {
    const auto endpoints = hg->meta_graph()->FindEdge(etype);
    sliced_edges[etype] = hg->GetRelationGraph(etype)->FindEdges(0, eids[etype]);
    nodes_by_vtype[endpoints.first].push_back(sliced_edges[etype].src);
    nodes_by_vtype[endpoints.second].push_back(sliced_edges[etype].dst);
  }

  // Step (3): one relabeling per vertex type across all incident relations.
  // NOTE(review): step (4) relies on Relabel_ having rewritten the arrays referenced by
  // sliced_edges in place -- confirm against aten::Relabel_.
  std::vector<int64_t> node_counts(num_vtypes);
  for (dgl_type_t vtype = 0; vtype < num_vtypes; ++vtype) {
    result.induced_vertices[vtype] = aten::Relabel_(nodes_by_vtype[vtype]);
    node_counts[vtype] = result.induced_vertices[vtype]->shape[0];
  }

  // Step (4): rebuild every relation graph over the compact ID spaces.
  std::vector<HeteroGraphPtr> rel_subgraphs(num_etypes);
  for (dgl_type_t etype = 0; etype < num_etypes; ++etype) {
    const auto endpoints = hg->meta_graph()->FindEdge(etype);
    const dgl_type_t src_vtype = endpoints.first;
    const dgl_type_t dst_vtype = endpoints.second;
    // A relation whose endpoints share a type is stored with a single vertex type.
    const int rel_num_vtypes = (src_vtype == dst_vtype) ? 1 : 2;
    rel_subgraphs[etype] = UnitGraph::CreateFromCOO(
      rel_num_vtypes,
      node_counts[src_vtype],
      node_counts[dst_vtype],
      sliced_edges[etype].src,
      sliced_edges[etype].dst);
  }

  result.graph = HeteroGraphPtr(new HeteroGraph(
      hg->meta_graph(), rel_subgraphs, node_counts));
  return result;
}

117
/*!
 * \brief Validate the inputs used to assemble a heterograph.
 *
 * Fails a CHECK when the number of relation graphs differs from the number of meta-graph
 * edges, when no relation graph is given, when a relation graph has more than one edge
 * type, or when the relation graphs do not all report the same context.
 *
 * \param meta_graph The meta graph.
 * \param rel_graphs The relation graphs.
 */
void HeteroGraphSanityCheck(GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs) {
  CHECK_EQ(meta_graph->NumEdges(), rel_graphs.size());
  CHECK(!rel_graphs.empty()) << "Empty heterograph is not allowed.";

  // First pass: each relation graph must describe exactly one edge type.
  for (size_t i = 0; i < rel_graphs.size(); ++i) {
    const auto& rg = rel_graphs[i];
    CHECK_EQ(rg->NumEdgeTypes(), 1) << "Each relation graph must have only one edge type.";
  }

  // Second pass: the context of the first relation graph is the reference for the rest.
  const auto ctx = rel_graphs.front()->Context();
  for (size_t i = 0; i < rel_graphs.size(); ++i) {
    const auto& rg = rel_graphs[i];
    CHECK_EQ(rg->Context(), ctx) << "Each relation graph must have the same context.";
  }
}

/*!
 * \brief Derive the number of vertices of every vertex type from the relation graphs.
 *
 * Every meta-graph edge contributes the source and destination sizes of its relation
 * graph. The first size seen for a vertex type is recorded; any later size for the same
 * type must match it or a CHECK fails.
 *
 * \param meta_graph The meta graph.
 * \param rel_graphs The relation graphs, indexed by edge type.
 * \return One count per vertex type; a type that no meta-graph edge touches keeps -1.
 */
std::vector<int64_t>
InferNumVerticesPerType(GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs) {
  // -1 marks a vertex type whose size has not been seen yet.
  std::vector<int64_t> num_verts_per_type(meta_graph->NumVertices(), -1);

  // Record the size of a vertex type on first sight, otherwise require agreement.
  // Counts stay signed throughout so they are never compared against an unsigned value.
  const auto record = [&num_verts_per_type](dgl_type_t vtype, int64_t nv) {
    if (num_verts_per_type[vtype] < 0) {
      num_verts_per_type[vtype] = nv;
    } else {
      CHECK_EQ(num_verts_per_type[vtype], nv)
        << "Mismatch number of vertices for vertex type " << vtype;
    }
  };

  // NOTE(review): the raw buffers are read as dgl_type_t, which assumes the meta graph
  // stores its edge arrays with an element width equal to dgl_type_t -- confirm.
  EdgeArray etype_array = meta_graph->Edges();
  const dgl_type_t* srctypes = static_cast<const dgl_type_t*>(etype_array.src->data);
  const dgl_type_t* dsttypes = static_cast<const dgl_type_t*>(etype_array.dst->data);
  const dgl_type_t* etypes = static_cast<const dgl_type_t*>(etype_array.id->data);

  const size_t num_meta_edges = meta_graph->NumEdges();
  for (size_t i = 0; i < num_meta_edges; ++i) {
    const dgl_type_t etype = etypes[i];
    // Refuse to index past the end when fewer relation graphs than edge types are given.
    CHECK_LT(etype, rel_graphs.size())
      << "No relation graph is given for edge type " << etype;
    const auto& rg = rel_graphs[etype];

    // Inside a relation graph the source type is 0; the destination type is 1 unless
    // the relation graph has a single vertex type, in which case it is also 0.
    const dgl_type_t sty = 0;
    const dgl_type_t dty = (rg->NumVertexTypes() == 1) ? 0 : 1;

    record(srctypes[i], static_cast<int64_t>(rg->NumVertices(sty)));
    record(dsttypes[i], static_cast<int64_t>(rg->NumVertices(dty)));
  }
  return num_verts_per_type;
}
166

167
168
std::vector<UnitGraphPtr> CastToUnitGraphs(const std::vector<HeteroGraphPtr>& rel_graphs) {
  std::vector<UnitGraphPtr> relation_graphs(rel_graphs.size());
169
170
171
  for (size_t i = 0; i < rel_graphs.size(); ++i) {
    HeteroGraphPtr relg = rel_graphs[i];
    if (std::dynamic_pointer_cast<UnitGraph>(relg)) {
172
      relation_graphs[i] = std::dynamic_pointer_cast<UnitGraph>(relg);
173
    } else {
174
      relation_graphs[i] = CHECK_NOTNULL(
175
176
177
          std::dynamic_pointer_cast<UnitGraph>(relg->GetRelationGraph(0)));
    }
  }
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
  return relation_graphs;
}

}  // namespace

/*!
 * \brief Assemble a heterograph from a meta graph and its relation graphs.
 *
 * \param meta_graph The meta graph.
 * \param rel_graphs The relation graphs, indexed by edge type.
 * \param num_nodes_per_type Number of nodes of each vertex type. When empty, the counts
 *        are inferred from the relation graphs.
 */
HeteroGraph::HeteroGraph(
    GraphPtr meta_graph,
    const std::vector<HeteroGraphPtr>& rel_graphs,
    const std::vector<int64_t>& num_nodes_per_type) : BaseHeteroGraph(meta_graph) {
  // Validate first: inferring the vertex counts indexes rel_graphs by edge type, so the
  // size agreement between meta graph and rel_graphs must be established beforehand.
  HeteroGraphSanityCheck(meta_graph, rel_graphs);
  if (num_nodes_per_type.empty())
    num_verts_per_type_ = InferNumVerticesPerType(meta_graph, rel_graphs);
  else
    num_verts_per_type_ = num_nodes_per_type;
  relation_graphs_ = CastToUnitGraphs(rel_graphs);
}

bool HeteroGraph::IsMultigraph() const {
196
197
198
199
200
201
  for (const auto &hg : relation_graphs_) {
    if (hg->IsMultigraph()) {
      return true;
    }
  }
  return false;
202
203
204
}

/*!
 * \brief Test, element-wise, whether the given IDs are below the vertex count of a type.
 *
 * \param vtype The vertex type.
 * \param vids The vertex IDs to test.
 * \return The result of aten::LT(vids, NumVertices(vtype)).
 * \note Only the upper bound is compared here.
 */
BoolArray HeteroGraph::HasVertices(dgl_type_t vtype, IdArray vids) const {
  CHECK(aten::IsValidIdArray(vids)) << "Invalid id array input";
  const auto num_nodes = NumVertices(vtype);
  return aten::LT(vids, num_nodes);
}

/*!
 * \brief Build the subgraph induced by the given vertices of every type.
 *
 * \param vids One vertex ID array per vertex type, indexed by vertex type.
 * \return The subgraph; induced_vertices is a copy of \a vids and induced_edges holds,
 *         per edge type, the edges reported by the relation graph's VertexSubgraph.
 */
HeteroSubgraph HeteroGraph::VertexSubgraph(const std::vector<IdArray>& vids) const {
  CHECK_EQ(vids.size(), NumVertexTypes())
    << "Invalid input: the input list size must be the same as the number of vertex types.";

  HeteroSubgraph result;
  result.induced_vertices = vids;
  result.induced_edges.resize(NumEdgeTypes());

  // The subgraph has exactly as many nodes of each type as IDs were requested.
  std::vector<int64_t> node_counts;
  node_counts.reserve(vids.size());
  for (const IdArray& ids : vids)
    node_counts.push_back(ids->shape[0]);

  std::vector<HeteroGraphPtr> rel_subgraphs(NumEdgeTypes());
  for (dgl_type_t etype = 0; etype < NumEdgeTypes(); ++etype) {
    const auto endpoints = meta_graph_->FindEdge(etype);
    // A relation whose endpoints share a type takes one ID array, otherwise two.
    std::vector<IdArray> rel_vids{vids[endpoints.first]};
    if (endpoints.first != endpoints.second)
      rel_vids.push_back(vids[endpoints.second]);

    const auto& rel_sg = GetRelationGraph(etype)->VertexSubgraph(rel_vids);
    rel_subgraphs[etype] = rel_sg.graph;
    result.induced_edges[etype] = rel_sg.induced_edges[0];
  }

  result.graph = HeteroGraphPtr(new HeteroGraph(
      meta_graph_, rel_subgraphs, node_counts));
  return result;
}

/*!
 * \brief Build the subgraph induced by the given edges of every type.
 *
 * \param eids One edge ID array per edge type, indexed by edge type.
 * \param preserve_nodes Selects EdgeSubgraphPreserveNodes when true and
 *        EdgeSubgraphNoPreserveNodes when false.
 * \return The edge-induced subgraph.
 */
HeteroSubgraph HeteroGraph::EdgeSubgraph(
    const std::vector<IdArray>& eids, bool preserve_nodes) const {
  return preserve_nodes ? EdgeSubgraphPreserveNodes(this, eids)
                        : EdgeSubgraphNoPreserveNodes(this, eids);
}

244
245
246
247
248
249
250
251
252
253
254
/*!
 * \brief Rebuild a heterograph from its relation graphs passed through
 *        UnitGraph::AsNumBits.
 *
 * \param g The graph; must be a HeteroGraph.
 * \param bits The bit width forwarded to UnitGraph::AsNumBits.
 * \return A new HeteroGraph with the same meta graph and vertex counts.
 */
HeteroGraphPtr HeteroGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) {
  const auto hgindex = std::dynamic_pointer_cast<HeteroGraph>(g);
  CHECK_NOTNULL(hgindex);

  std::vector<HeteroGraphPtr> converted;
  converted.reserve(hgindex->relation_graphs_.size());
  for (const auto& rel : hgindex->relation_graphs_)
    converted.push_back(UnitGraph::AsNumBits(rel, bits));

  return HeteroGraphPtr(new HeteroGraph(
      hgindex->meta_graph_, converted, hgindex->num_verts_per_type_));
}

255
/*!
 * \brief Copy a heterograph to the given context.
 *
 * \param g The graph; must be a HeteroGraph unless it already is on \a ctx.
 * \param ctx The target context.
 * \return \a g itself when it already reports \a ctx, otherwise a new HeteroGraph built
 *         from the relation graphs passed through UnitGraph::CopyTo.
 */
HeteroGraphPtr HeteroGraph::CopyTo(HeteroGraphPtr g, const DGLContext &ctx) {
  // Nothing to do when the graph already is on the requested context.
  if (ctx == g->Context())
    return g;

  const auto hgindex = std::dynamic_pointer_cast<HeteroGraph>(g);
  CHECK_NOTNULL(hgindex);

  std::vector<HeteroGraphPtr> copied;
  copied.reserve(hgindex->relation_graphs_.size());
  for (const auto& rel : hgindex->relation_graphs_)
    copied.push_back(UnitGraph::CopyTo(rel, ctx));

  return HeteroGraphPtr(new HeteroGraph(
      hgindex->meta_graph_, copied, hgindex->num_verts_per_type_));
}

269
270
271
272
273
274
275
276
277
278
/*! \brief Call PinMemory_() on every relation graph. */
void HeteroGraph::PinMemory_() {
  for (size_t i = 0; i < relation_graphs_.size(); ++i) {
    relation_graphs_[i]->PinMemory_();
  }
}

/*! \brief Call UnpinMemory_() on every relation graph. */
void HeteroGraph::UnpinMemory_() {
  for (size_t i = 0; i < relation_graphs_.size(); ++i) {
    relation_graphs_[i]->UnpinMemory_();
  }
}

279
280
281
282
283
/*!
 * \brief Forward a stream handle to RecordStream() of every relation graph.
 * \param stream The stream handle to forward.
 */
void HeteroGraph::RecordStream(DGLStreamHandle stream) {
  for (size_t i = 0; i < relation_graphs_.size(); ++i) {
    relation_graphs_[i]->RecordStream(stream);
  }
}

284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
/*!
 * \brief Name of the shared memory attached to this graph.
 * \return The name reported by shared_mem_, or an empty string when none is attached.
 */
std::string HeteroGraph::SharedMemName() const {
  if (!shared_mem_)
    return "";
  return shared_mem_->GetName();
}

/*!
 * \brief Copy a heterograph into shared memory under the given name.
 *
 * The meta information is written to the shared-memory stream in this order: bit width,
 * the coo/csr/csc flags, the meta graph, the per-type vertex counts, and finally the
 * node and edge type names. The requested sparse formats of every relation are copied
 * under "<name>_<etype>_coo", "<name>_<etype>_csr" and "<name>_<etype>_csc".
 *
 * \param g The graph; must be a HeteroGraph.
 * \param name The shared memory name.
 * \param ntypes Node type names stored alongside the graph.
 * \param etypes Edge type names stored alongside the graph.
 * \param fmts The formats to copy; "coo", "csr" and "csc" are looked up.
 * \return \a g itself when it already lives in shared memory called \a name, otherwise
 *         a new HeteroGraph backed by the shared memory.
 */
HeteroGraphPtr HeteroGraph::CopyToSharedMem(
      HeteroGraphPtr g, const std::string& name, const std::vector<std::string>& ntypes,
      const std::vector<std::string>& etypes, const std::set<std::string>& fmts) {
  // TODO(JJ): Raise error when calling shared_memory if graph index is on gpu
  auto hg = std::dynamic_pointer_cast<HeteroGraph>(g);
  CHECK_NOTNULL(hg);
  if (hg->SharedMemName() == name)
    return g;

  // Set up the shared-memory buffer that receives the meta information.
  auto mem = std::make_shared<SharedMemory>(name);
  auto mem_buf = mem->CreateNew(SHARED_MEM_METAINFO_SIZE_MAX);
  dmlc::MemoryFixedSizeStream strm(mem_buf, SHARED_MEM_METAINFO_SIZE_MAX);
  SharedMemManager shm(name, &strm);

  const bool has_coo = fmts.count("coo") != 0;
  const bool has_csr = fmts.count("csr") != 0;
  const bool has_csc = fmts.count("csc") != 0;
  shm.Write(g->NumBits());
  shm.Write(has_coo);
  shm.Write(has_csr);
  shm.Write(has_csc);
  shm.Write(ImmutableGraph::ToImmutable(hg->meta_graph_));
  shm.Write(hg->num_verts_per_type_);

  std::vector<HeteroGraphPtr> shared_rels(g->NumEdgeTypes());
  for (dgl_type_t etype = 0; etype < g->NumEdgeTypes(); ++etype) {
    const auto endpoint_types = g->GetEndpointTypes(etype);
    const int num_vtypes = (endpoint_types.first == endpoint_types.second) ? 1 : 2;

    // Only the requested formats are copied; the others stay default-constructed.
    const std::string prefix = name + "_" + std::to_string(etype);
    aten::COOMatrix coo;
    aten::CSRMatrix csr, csc;
    if (has_coo)
      coo = shm.CopyToSharedMem(hg->GetCOOMatrix(etype), prefix + "_coo");
    if (has_csr)
      csr = shm.CopyToSharedMem(hg->GetCSRMatrix(etype), prefix + "_csr");
    if (has_csc)
      csc = shm.CopyToSharedMem(hg->GetCSCMatrix(etype), prefix + "_csc");

    shared_rels[etype] = UnitGraph::CreateUnitGraphFrom(
        num_vtypes, csc, csr, coo, has_csc, has_csr, has_coo);
  }

  auto ret = std::shared_ptr<HeteroGraph>(
      new HeteroGraph(hg->meta_graph_, shared_rels, hg->num_verts_per_type_));
  // Keep the shared-memory object alive for as long as the graph is.
  ret->shared_mem_ = mem;

  shm.Write(ntypes);
  shm.Write(etypes);
  return ret;
}

std::tuple<HeteroGraphPtr, std::vector<std::string>, std::vector<std::string>>
    HeteroGraph::CreateFromSharedMem(const std::string &name) {
345
346
347
348
  bool exist = SharedMemory::Exist(name);
  if (!exist) {
    return std::make_tuple(nullptr, std::vector<std::string>(), std::vector<std::string>());
  }
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
  auto mem = std::make_shared<SharedMemory>(name);
  auto mem_buf = mem->Open(SHARED_MEM_METAINFO_SIZE_MAX);
  dmlc::MemoryFixedSizeStream strm(mem_buf, SHARED_MEM_METAINFO_SIZE_MAX);
  SharedMemManager shm(name, &strm);

  uint8_t nbits;
  CHECK(shm.Read(&nbits)) << "invalid nbits (unit8_t)";

  bool has_coo, has_csr, has_csc;
  CHECK(shm.Read(&has_coo)) << "invalid nbits (unit8_t)";
  CHECK(shm.Read(&has_csr)) << "invalid csr (unit8_t)";
  CHECK(shm.Read(&has_csc)) << "invalid csc (unit8_t)";

  auto meta_imgraph = Serializer::make_shared<ImmutableGraph>();
  CHECK(shm.Read(&meta_imgraph)) << "Invalid meta graph";
  GraphPtr metagraph = meta_imgraph;

  std::vector<int64_t> num_verts_per_type;
  CHECK(shm.Read(&num_verts_per_type)) << "Invalid number of vertices per type";

  std::vector<HeteroGraphPtr> relgraphs(metagraph->NumEdges());
  for (dgl_type_t etype = 0 ; etype < metagraph->NumEdges() ; ++etype) {
371
372
    auto src_dst = metagraph->FindEdge(etype);
    int num_vtypes = (src_dst.first == src_dst.second) ? 1 : 2;
373
374
375
376
377
378
379
380
381
382
383
384
385
    aten::COOMatrix coo;
    aten::CSRMatrix csr, csc;
    std::string prefix = name + "_" + std::to_string(etype);
    if (has_coo) {
      shm.CreateFromSharedMem(&coo, prefix + "_coo");
    }
    if (has_csr) {
      shm.CreateFromSharedMem(&csr, prefix + "_csr");
    }
    if (has_csc) {
      shm.CreateFromSharedMem(&csc, prefix + "_csc");
    }

386
387
    relgraphs[etype] = UnitGraph::CreateUnitGraphFrom(
        num_vtypes, csc, csr, coo, has_csc, has_csr, has_coo);
388
389
390
391
392
393
394
395
396
397
398
399
  }

  auto ret = std::make_shared<HeteroGraph>(metagraph, relgraphs, num_verts_per_type);
  ret->shared_mem_ = mem;

  std::vector<std::string> ntypes;
  std::vector<std::string> etypes;
  CHECK(shm.Read(&ntypes)) << "invalid ntypes";
  CHECK(shm.Read(&etypes)) << "invalid etypes";
  return std::make_tuple(ret, ntypes, etypes);
}

400
/*!
 * \brief Build a heterograph whose relation graphs are restricted to the given formats.
 *
 * \param formats The format code forwarded to UnitGraph::GetGraphInFormat.
 * \return A new HeteroGraph with the same meta graph and vertex counts.
 */
HeteroGraphPtr HeteroGraph::GetGraphInFormat(dgl_format_code_t formats) const {
  std::vector<HeteroGraphPtr> format_rels(NumEdgeTypes());
  for (dgl_type_t etype = 0; etype < NumEdgeTypes(); ++etype) {
    // Fail with a CHECK instead of dereferencing null if the cast ever does not hold.
    auto relgraph = CHECK_NOTNULL(
        std::dynamic_pointer_cast<UnitGraph>(GetRelationGraph(etype)));
    format_rels[etype] = relgraph->GetGraphInFormat(formats);
  }
  return HeteroGraphPtr(new HeteroGraph(
    meta_graph_, format_rels, NumVerticesPerType()));
}

410
411
412
413
414
/*!
 * \brief Flatten the relations of the given edge types into one graph.
 *
 * Dispatches to FlattenImpl<int32_t> when NumBits() is 32 and to FlattenImpl<int64_t>
 * for any other value.
 *
 * \param etypes The edge types to flatten.
 * \return The flattened graph with its induced type and ID arrays.
 */
FlattenedHeteroGraphPtr HeteroGraph::Flatten(
    const std::vector<dgl_type_t>& etypes) const {
  return (NumBits() == 32) ? FlattenImpl<int32_t>(etypes)
                           : FlattenImpl<int64_t>(etypes);
}

/*!
 * \brief Flatten the relations of the given edge types into one graph.
 *
 * The source node types and the destination node types touched by \a etypes are each
 * laid out contiguously in ascending type order, and the edges of every relation are
 * shifted by the offsets of their endpoint types and concatenated. The result has one
 * vertex type when the sorted source and destination type sets are equal, and two
 * otherwise.
 *
 * \tparam IdType The element type of the induced source/destination ID vectors.
 * \param etypes The edge types to flatten.
 * \return The flattened graph with its induced type and ID arrays.
 */
template <class IdType>
FlattenedHeteroGraphPtr HeteroGraph::FlattenImpl(const std::vector<dgl_type_t>& etypes) const {
  std::unordered_map<dgl_type_t, size_t> srctype_offsets, dsttype_offsets;
  size_t src_nodes = 0, dst_nodes = 0;
  std::vector<dgl_type_t> induced_srctype, induced_dsttype;
  std::vector<IdType> induced_srcid, induced_dstid;
  std::vector<dgl_type_t> srctype_set, dsttype_set;

  // XXXtype_offsets contain the mapping from node type to number of nodes after this
  // loop.
  size_t total_src = 0, total_dst = 0;
  for (dgl_type_t etype : etypes) {
    auto src_dsttype = meta_graph_->FindEdge(etype);
    dgl_type_t srctype = src_dsttype.first;
    dgl_type_t dsttype = src_dsttype.second;
    size_t num_srctype_nodes = NumVertices(srctype);
    size_t num_dsttype_nodes = NumVertices(dsttype);

    if (srctype_offsets.count(srctype) == 0) {
      srctype_offsets[srctype] = num_srctype_nodes;
      srctype_set.push_back(srctype);
      total_src += num_srctype_nodes;
    }
    if (dsttype_offsets.count(dsttype) == 0) {
      dsttype_offsets[dsttype] = num_dsttype_nodes;
      dsttype_set.push_back(dsttype);
      total_dst += num_dsttype_nodes;
    }
  }
  // Sort the node types so that we can compare the sets and decide whether a homogeneous
  // graph should be returned.
  std::sort(srctype_set.begin(), srctype_set.end());
  std::sort(dsttype_set.begin(), dsttype_set.end());
  bool homograph = (srctype_set.size() == dsttype_set.size()) &&
    std::equal(srctype_set.begin(), srctype_set.end(), dsttype_set.begin());

  // The final sizes are known, so allocate the induced vectors once.
  induced_srctype.reserve(total_src);
  induced_srcid.reserve(total_src);
  induced_dsttype.reserve(total_dst);
  induced_dstid.reserve(total_dst);

  // XXXtype_offsets contain the mapping from node type to node ID offsets after these
  // two loops.
  for (size_t i = 0; i < srctype_set.size(); ++i) {
    dgl_type_t ntype = srctype_set[i];
    size_t num_nodes = srctype_offsets[ntype];
    srctype_offsets[ntype] = src_nodes;
    src_nodes += num_nodes;
    for (size_t j = 0; j < num_nodes; ++j) {
      induced_srctype.push_back(ntype);
      induced_srcid.push_back(static_cast<IdType>(j));
    }
  }
  for (size_t i = 0; i < dsttype_set.size(); ++i) {
    dgl_type_t ntype = dsttype_set[i];
    size_t num_nodes = dsttype_offsets[ntype];
    dsttype_offsets[ntype] = dst_nodes;
    dst_nodes += num_nodes;
    for (size_t j = 0; j < num_nodes; ++j) {
      induced_dsttype.push_back(ntype);
      induced_dstid.push_back(static_cast<IdType>(j));
    }
  }

  // TODO(minjie): Using concat operations cause many fragmented memory.
  //   Need to optimize it in the future.
  std::vector<IdArray> src_arrs, dst_arrs, eid_arrs, induced_etypes;
  src_arrs.reserve(etypes.size());
  dst_arrs.reserve(etypes.size());
  eid_arrs.reserve(etypes.size());
  induced_etypes.reserve(etypes.size());
  for (dgl_type_t etype : etypes) {
    auto src_dsttype = meta_graph_->FindEdge(etype);
    dgl_type_t srctype = src_dsttype.first;
    dgl_type_t dsttype = src_dsttype.second;
    size_t srctype_offset = srctype_offsets[srctype];
    size_t dsttype_offset = dsttype_offsets[dsttype];

    EdgeArray edges = Edges(etype);
    size_t num_edges = NumEdges(etype);
    // Shift the endpoints into the flattened ID spaces.
    src_arrs.push_back(edges.src + srctype_offset);
    dst_arrs.push_back(edges.dst + dsttype_offset);
    eid_arrs.push_back(edges.id);
    induced_etypes.push_back(aten::Full(etype, num_edges, NumBits(), Context()));
  }

  HeteroGraphPtr gptr = UnitGraph::CreateFromCOO(
      homograph ? 1 : 2,
      src_nodes,
      dst_nodes,
      aten::Concat(src_arrs),
      aten::Concat(dst_arrs));

  // Sanity check
  CHECK_EQ(gptr->Context(), Context());
  CHECK_EQ(gptr->NumBits(), NumBits());

  // Own the result from the start so that it is released if any call below throws
  // instead of leaking a raw allocation.
  FlattenedHeteroGraphPtr result(new FlattenedHeteroGraph);
  result->graph = HeteroGraphRef(HeteroGraphPtr(new HeteroGraph(gptr->meta_graph(), {gptr})));
  result->induced_srctype = aten::VecToIdArray(induced_srctype).CopyTo(Context());
  result->induced_srctype_set = aten::VecToIdArray(srctype_set).CopyTo(Context());
  result->induced_srcid = aten::VecToIdArray(induced_srcid).CopyTo(Context());
  result->induced_etype = aten::Concat(induced_etypes);
  result->induced_etype_set = aten::VecToIdArray(etypes).CopyTo(Context());
  result->induced_eid = aten::Concat(eid_arrs);
  result->induced_dsttype = aten::VecToIdArray(induced_dsttype).CopyTo(Context());
  result->induced_dsttype_set = aten::VecToIdArray(dsttype_set).CopyTo(Context());
  result->induced_dstid = aten::VecToIdArray(induced_dstid).CopyTo(Context());
  return result;
}

523
524
525
526
527
528
/*! \brief Magic number written first by HeteroGraph::Save and verified by HeteroGraph::Load. */
constexpr uint64_t kDGLSerialize_HeteroGraph = 0xDD589FBE35224ABF;

/*!
 * \brief Restore this graph from a stream.
 *
 * Reads, in order: the magic number, the meta graph, the relation graphs and the
 * per-type vertex counts. Any failed read, or a magic number other than
 * kDGLSerialize_HeteroGraph, fails a CHECK.
 *
 * \param fs The input stream.
 * \return Always true when the function returns.
 */
bool HeteroGraph::Load(dmlc::Stream* fs) {
  // Guard against data that is not a serialized HeteroGraph.
  uint64_t magicNum;
  CHECK(fs->Read(&magicNum)) << "Invalid Magic Number";
  CHECK_EQ(magicNum, kDGLSerialize_HeteroGraph) << "Invalid HeteroGraph Data";

  // The meta graph is read as an ImmutableGraph and then stored as the meta graph.
  auto meta_imgraph = Serializer::make_shared<ImmutableGraph>();
  CHECK(fs->Read(&meta_imgraph)) << "Invalid meta graph";
  meta_graph_ = meta_imgraph;

  CHECK(fs->Read(&relation_graphs_)) << "Invalid relation_graphs_";
  CHECK(fs->Read(&num_verts_per_type_)) << "Invalid num_verts_per_type_";
  return true;
}

/*!
 * \brief Serialize this graph to a stream.
 *
 * Writes, in order: the magic number kDGLSerialize_HeteroGraph, the meta graph converted
 * with ImmutableGraph::ToImmutable, the relation graphs and the per-type vertex counts.
 *
 * \param fs The output stream.
 */
void HeteroGraph::Save(dmlc::Stream* fs) const {
  fs->Write(kDGLSerialize_HeteroGraph);
  fs->Write(ImmutableGraph::ToImmutable(meta_graph()));
  fs->Write(relation_graphs_);
  fs->Write(num_verts_per_type_);
}

545
546
547
548
549
550
551
552
/*!
 * \brief View a graph with one node type and one edge type as an immutable graph.
 *
 * Fails a CHECK when the graph has a different number of node or edge types, or when
 * relation graph 0 is not a UnitGraph.
 *
 * \return The result of UnitGraph::AsImmutableGraph on relation graph 0.
 */
GraphPtr HeteroGraph::AsImmutableGraph() const {
  CHECK(NumVertexTypes() == 1) << "graph has more than one node types";
  CHECK(NumEdgeTypes() == 1) << "graph has more than one edge types";
  const auto only_relation = CHECK_NOTNULL(
      std::dynamic_pointer_cast<UnitGraph>(GetRelationGraph(0)));
  return only_relation->AsImmutableGraph();
}

553
554
555
556
557
558
559
560
561
562
563
564
/*!
 * \brief Build the line graph of a graph with one node type and one edge type.
 *
 * \param backtracking Forwarded to UnitGraph::LineGraph.
 * \return A new HeteroGraph that reuses this graph's meta graph and wraps the line graph
 *         of relation graph 0.
 */
HeteroGraphPtr HeteroGraph::LineGraph(bool backtracking) const {
  CHECK_EQ(1, meta_graph_->NumEdges()) << "Only support Homogeneous graph now (one edge type)";
  CHECK_EQ(1, meta_graph_->NumVertices()) << "Only support Homogeneous graph now (one node type)";
  CHECK_EQ(1, relation_graphs_.size()) << "Only support Homogeneous graph now";

  const auto &line_graph = relation_graphs_[0]->LineGraph(backtracking);
  // The single vertex type of the result has as many nodes as the line graph reports.
  const std::vector<int64_t> node_counts(
      1, static_cast<int64_t>(line_graph->NumVertices(0)));
  const std::vector<HeteroGraphPtr> rel_graph = {line_graph};
  return HeteroGraphPtr(new HeteroGraph(meta_graph_, rel_graph, node_counts));
}

565
}  // namespace dgl