/*!
 *  Copyright (c) 2019 by Contributors
 * \file graph/heterograph.cc
 * \brief Heterograph implementation
 */
#include "./heterograph.h"
Minjie Wang's avatar
Minjie Wang committed
7
#include <dgl/array.h>
8
#include <dgl/immutable_graph.h>
9
#include <dgl/graph_serializer.h>
10
11
#include <dmlc/memory_io.h>
#include <memory>
12
13
14
#include <vector>
#include <tuple>
#include <utility>
15
16
17
18
19
20

using namespace dgl::runtime;

namespace dgl {
namespace {

21
22
using dgl::ImmutableGraph;

23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
/*!
 * \brief Build an edge-induced subgraph that keeps the node counts of the parent graph.
 *
 * Every relation graph is sliced on its own with preserve_nodes set to true, and the
 * result is assembled with the parent's meta graph and per-type vertex counts.
 *
 * \param hg The parent heterograph.
 * \param eids Edge IDs to keep; one array per edge type.
 * \return The subgraph together with its induced vertex and edge arrays.
 */
HeteroSubgraph EdgeSubgraphPreserveNodes(
    const HeteroGraph* hg, const std::vector<IdArray>& eids) {
  CHECK_EQ(eids.size(), hg->NumEdgeTypes())
    << "Invalid input: the input list size must be the same as the number of edge type.";
  const auto num_etypes = hg->NumEdgeTypes();

  HeteroSubgraph result;
  result.induced_edges = eids;
  result.induced_vertices.resize(hg->NumVertexTypes());

  // With preserve_nodes, each bipartite relation can be handled independently.
  std::vector<HeteroGraphPtr> rel_subgraphs(num_etypes);
  for (dgl_type_t et = 0; et < num_etypes; ++et) {
    const auto endpoints = hg->meta_graph()->FindEdge(et);
    const auto rel_sg = hg->GetRelationGraph(et)->EdgeSubgraph({eids[et]}, true);
    rel_subgraphs[et] = rel_sg.graph;
    // Slot 0 is recorded for the source type, slot 1 for the destination type.
    result.induced_vertices[endpoints.first] = rel_sg.induced_vertices[0];
    result.induced_vertices[endpoints.second] = rel_sg.induced_vertices[1];
  }

  result.graph = HeteroGraphPtr(new HeteroGraph(
      hg->meta_graph(), rel_subgraphs, hg->NumVerticesPerType()));
  return result;
}

/*!
 * \brief Build an edge-induced subgraph that keeps only nodes incident to the chosen edges.
 *
 * Relation graphs that touch the same vertex type must share one ID space for that
 * type, so the relations cannot simply be relabeled one by one. Example:
 *
 *   Meta graph: A -> B -> C
 *   UnitGraph graphs:
 *     * A -> B: (0, 0), (0, 1)
 *     * B -> C: (1, 0), (1, 1)
 *
 * If only (0, 0) is kept for A->B and only (1, 0) for B->C, node type B must still
 * have two nodes in the result; relabeling B->C alone would wrongly turn node #1 of
 * type B into node #0.
 *
 * The procedure is:
 *  (1) For each relation, look up the endpoints of the requested edges.
 *  (2) Group those endpoint arrays by vertex type (as source or destination).
 *  (3) Run aten::Relabel_ per vertex type; the returned mapping is the induced vertex
 *      set of that type.
 *  (4) Build each relation graph from the relabeled endpoints.
 *
 * \param hg The parent heterograph.
 * \param eids Edge IDs to keep; one array per edge type.
 * \return The subgraph together with its induced vertex and edge arrays.
 */
HeteroSubgraph EdgeSubgraphNoPreserveNodes(
    const HeteroGraph* hg, const std::vector<IdArray>& eids) {
  // TODO(minjie): In general, all relabeling should be separated with subgraph
  //   operations.
  CHECK_EQ(eids.size(), hg->NumEdgeTypes())
    << "Invalid input: the input list size must be the same as the number of edge type.";
  const auto num_etypes = hg->NumEdgeTypes();
  const auto num_vtypes = hg->NumVertexTypes();

  HeteroSubgraph result;
  result.induced_edges = eids;
  result.induced_vertices.resize(num_vtypes);

  // Steps (1) and (2): gather the endpoints of the selected edges per vertex type.
  std::vector<EdgeArray> selected(num_etypes);
  std::vector<std::vector<IdArray>> incident_nodes(num_vtypes);
  for (dgl_type_t et = 0; et < num_etypes; ++et) {
    const auto endpoints = hg->meta_graph()->FindEdge(et);
    selected[et] = hg->GetRelationGraph(et)->FindEdges(0, eids[et]);
    incident_nodes[endpoints.first].push_back(selected[et].src);
    incident_nodes[endpoints.second].push_back(selected[et].dst);
  }

  // Step (3): relabel per vertex type.
  // NOTE(review): step (4) reads the arrays in `selected`, so this presumably relies
  // on Relabel_ rewriting the arrays it is given in place - confirm against aten.
  std::vector<int64_t> node_counts(num_vtypes);
  for (dgl_type_t vt = 0; vt < num_vtypes; ++vt) {
    result.induced_vertices[vt] = aten::Relabel_(incident_nodes[vt]);
    node_counts[vt] = result.induced_vertices[vt]->shape[0];
  }

  // Step (4): rebuild every relation graph over the relabeled ID spaces.
  std::vector<HeteroGraphPtr> rel_subgraphs(num_etypes);
  for (dgl_type_t et = 0; et < num_etypes; ++et) {
    const auto endpoints = hg->meta_graph()->FindEdge(et);
    const dgl_type_t src_vt = endpoints.first;
    const dgl_type_t dst_vt = endpoints.second;
    rel_subgraphs[et] = UnitGraph::CreateFromCOO(
      (src_vt == dst_vt)? 1 : 2,
      node_counts[src_vt],
      node_counts[dst_vt],
      selected[et].src,
      selected[et].dst);
  }

  result.graph = HeteroGraphPtr(new HeteroGraph(
      hg->meta_graph(), rel_subgraphs, std::move(node_counts)));
  return result;
}

117
/*!
 * \brief Validate the inputs used to construct a HeteroGraph.
 *
 * Checks that there is one relation graph per meta-graph edge, that the list is not
 * empty, that every relation graph has exactly one edge type, and that all relation
 * graphs report the same context as the first one.
 *
 * \param meta_graph The meta graph.
 * \param rel_graphs The relation graphs, indexed by edge type.
 */
void HeteroGraphSanityCheck(GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs) {
  CHECK_EQ(meta_graph->NumEdges(), rel_graphs.size());
  CHECK(!rel_graphs.empty()) << "Empty heterograph is not allowed.";

  // First pass: every relation graph carries a single edge type.
  for (size_t i = 0; i < rel_graphs.size(); ++i) {
    const auto& rg = rel_graphs[i];
    CHECK_EQ(rg->NumEdgeTypes(), 1) << "Each relation graph must have only one edge type.";
  }

  // Second pass: all relation graphs live on the context of the first one.
  const auto ctx = rel_graphs.front()->Context();
  for (size_t i = 0; i < rel_graphs.size(); ++i) {
    const auto& rg = rel_graphs[i];
    CHECK_EQ(rg->Context(), ctx) << "Each relation graph must have the same context.";
  }
}

/*!
 * \brief Derive the number of vertices of each vertex type from the relation graphs.
 *
 * Walks every meta-graph edge, reads the source/destination vertex counts from the
 * matching relation graph, and checks that all relations touching the same vertex
 * type agree. A vertex type that no meta-graph edge touches keeps the initial -1.
 *
 * \param meta_graph The meta graph.
 * \param rel_graphs The relation graphs, indexed by edge type.
 * \return Vertex counts indexed by vertex type.
 */
std::vector<int64_t>
InferNumVerticesPerType(GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs) {
  // -1 marks "not seen yet".
  std::vector<int64_t> num_verts_per_type(meta_graph->NumVertices(), -1);

  EdgeArray meta_edges = meta_graph->Edges();
  const dgl_type_t* src_of = static_cast<dgl_type_t *>(meta_edges.src->data);
  const dgl_type_t* dst_of = static_cast<dgl_type_t *>(meta_edges.dst->data);
  const dgl_type_t* etype_of = static_cast<dgl_type_t *>(meta_edges.id->data);

  const size_t num_meta_edges = meta_graph->NumEdges();
  for (size_t idx = 0; idx < num_meta_edges; ++idx) {
    const dgl_type_t srctype = src_of[idx];
    const dgl_type_t dsttype = dst_of[idx];
    const auto& rg = rel_graphs[etype_of[idx]];
    // A relation graph with a single vertex type uses slot 0 for both endpoints.
    const int dst_slot = (rg->NumVertexTypes() == 1) ? 0 : 1;

    // Source side.
    size_t nv = rg->NumVertices(0);
    if (num_verts_per_type[srctype] >= 0) {
      CHECK_EQ(num_verts_per_type[srctype], nv)
        << "Mismatch number of vertices for vertex type " << srctype;
    } else {
      num_verts_per_type[srctype] = nv;
    }

    // Destination side.
    nv = rg->NumVertices(dst_slot);
    if (num_verts_per_type[dsttype] >= 0) {
      CHECK_EQ(num_verts_per_type[dsttype], nv)
        << "Mismatch number of vertices for vertex type " << dsttype;
    } else {
      num_verts_per_type[dsttype] = nv;
    }
  }
  return num_verts_per_type;
}
166

167
168
std::vector<UnitGraphPtr> CastToUnitGraphs(const std::vector<HeteroGraphPtr>& rel_graphs) {
  std::vector<UnitGraphPtr> relation_graphs(rel_graphs.size());
169
170
171
  for (size_t i = 0; i < rel_graphs.size(); ++i) {
    HeteroGraphPtr relg = rel_graphs[i];
    if (std::dynamic_pointer_cast<UnitGraph>(relg)) {
172
      relation_graphs[i] = std::dynamic_pointer_cast<UnitGraph>(relg);
173
    } else {
174
      relation_graphs[i] = CHECK_NOTNULL(
175
176
177
          std::dynamic_pointer_cast<UnitGraph>(relg->GetRelationGraph(0)));
    }
  }
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
  return relation_graphs;
}

}  // namespace

/*!
 * \brief Construct a heterograph from a meta graph and its relation graphs.
 *
 * \param meta_graph The meta graph; each of its edges corresponds to one relation graph.
 * \param rel_graphs The relation graphs, indexed by edge type.
 * \param num_nodes_per_type Vertex counts per vertex type. When empty, the counts are
 *        inferred from the relation graphs.
 */
HeteroGraph::HeteroGraph(
    GraphPtr meta_graph,
    const std::vector<HeteroGraphPtr>& rel_graphs,
    const std::vector<int64_t>& num_nodes_per_type) : BaseHeteroGraph(meta_graph) {
  // Validate before anything else: InferNumVerticesPerType indexes rel_graphs by the
  // meta graph's edge types, which is only safe once the sizes are known to match.
  HeteroGraphSanityCheck(meta_graph, rel_graphs);
  if (num_nodes_per_type.empty()) {
    num_verts_per_type_ = InferNumVerticesPerType(meta_graph, rel_graphs);
  } else {
    num_verts_per_type_ = num_nodes_per_type;
  }
  relation_graphs_ = CastToUnitGraphs(rel_graphs);
}

bool HeteroGraph::IsMultigraph() const {
196
197
198
199
200
201
  for (const auto &hg : relation_graphs_) {
    if (hg->IsMultigraph()) {
      return true;
    }
  }
  return false;
202
203
204
}

/*!
 * \brief Check the given vertex IDs against the number of vertices of a type.
 *
 * The result is aten::LT of the IDs and NumVertices(vtype); no other bound is
 * checked in this function.
 *
 * \param vtype The vertex type.
 * \param vids The vertex IDs to test.
 * \return The array produced by aten::LT.
 */
BoolArray HeteroGraph::HasVertices(dgl_type_t vtype, IdArray vids) const {
  CHECK(aten::IsValidIdArray(vids)) << "Invalid id array input";
  const auto num_nodes = NumVertices(vtype);
  return aten::LT(vids, num_nodes);
}

/*!
 * \brief Build the vertex-induced subgraph.
 *
 * Each relation graph is restricted to the given vertices of its endpoint types; the
 * vertex counts of the result are the lengths of the given ID arrays.
 *
 * \param vids Vertex IDs to keep; one array per vertex type.
 * \return The subgraph together with its induced vertex and edge arrays.
 */
HeteroSubgraph HeteroGraph::VertexSubgraph(const std::vector<IdArray>& vids) const {
  CHECK_EQ(vids.size(), NumVertexTypes())
    << "Invalid input: the input list size must be the same as the number of vertex types.";
  const auto num_vtypes = NumVertexTypes();
  const auto num_etypes = NumEdgeTypes();

  HeteroSubgraph result;
  result.induced_vertices = vids;
  result.induced_edges.resize(num_etypes);

  std::vector<int64_t> node_counts(num_vtypes);
  for (dgl_type_t vt = 0; vt < num_vtypes; ++vt) {
    node_counts[vt] = vids[vt]->shape[0];
  }

  std::vector<HeteroGraphPtr> rel_subgraphs(num_etypes);
  for (dgl_type_t et = 0; et < num_etypes; ++et) {
    const auto endpoints = meta_graph_->FindEdge(et);
    // A relation over a single vertex type takes one ID array, otherwise two.
    std::vector<IdArray> rel_vids = {vids[endpoints.first]};
    if (endpoints.first != endpoints.second) {
      rel_vids.push_back(vids[endpoints.second]);
    }
    const auto rel_sg = GetRelationGraph(et)->VertexSubgraph(rel_vids);
    rel_subgraphs[et] = rel_sg.graph;
    result.induced_edges[et] = rel_sg.induced_edges[0];
  }

  result.graph = HeteroGraphPtr(new HeteroGraph(
      meta_graph_, rel_subgraphs, std::move(node_counts)));
  return result;
}

/*!
 * \brief Build the edge-induced subgraph.
 * \param eids Edge IDs to keep; one array per edge type.
 * \param preserve_nodes Selects between the node-preserving and the relabeling variant.
 * \return The subgraph together with its induced vertex and edge arrays.
 */
HeteroSubgraph HeteroGraph::EdgeSubgraph(
    const std::vector<IdArray>& eids, bool preserve_nodes) const {
  return preserve_nodes
    ? EdgeSubgraphPreserveNodes(this, eids)
    : EdgeSubgraphNoPreserveNodes(this, eids);
}

244
245
246
247
248
249
250
251
252
253
254
/*!
 * \brief Create a heterograph whose relation graphs are converted with UnitGraph::AsNumBits.
 * \param g The input graph; must be a HeteroGraph.
 * \param bits The bit width forwarded to UnitGraph::AsNumBits.
 * \return A new heterograph with the same meta graph and vertex counts.
 */
HeteroGraphPtr HeteroGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) {
  auto hgindex = std::dynamic_pointer_cast<HeteroGraph>(g);
  CHECK_NOTNULL(hgindex);
  std::vector<HeteroGraphPtr> converted;
  converted.reserve(hgindex->relation_graphs_.size());
  for (const auto& rel : hgindex->relation_graphs_) {
    converted.push_back(UnitGraph::AsNumBits(rel, bits));
  }
  return HeteroGraphPtr(new HeteroGraph(hgindex->meta_graph_, converted,
                                        hgindex->num_verts_per_type_));
}

255
256
/*!
 * \brief Copy a heterograph to the given context.
 * \param g The input graph; must be a HeteroGraph unless it is already on ctx.
 * \param ctx The target context.
 * \param stream The stream handle forwarded to UnitGraph::CopyTo.
 * \return g itself if it already reports ctx as its context, otherwise a new
 *         heterograph built from the copied relation graphs.
 */
HeteroGraphPtr HeteroGraph::CopyTo(HeteroGraphPtr g, const DLContext &ctx,
                                   const DGLStreamHandle &stream) {
  // Nothing to do when the graph already lives on the requested context.
  if (ctx == g->Context())
    return g;

  auto hgindex = std::dynamic_pointer_cast<HeteroGraph>(g);
  CHECK_NOTNULL(hgindex);
  std::vector<HeteroGraphPtr> copied;
  copied.reserve(hgindex->relation_graphs_.size());
  for (const auto& rel : hgindex->relation_graphs_) {
    copied.push_back(UnitGraph::CopyTo(rel, ctx, stream));
  }
  return HeteroGraphPtr(new HeteroGraph(hgindex->meta_graph_, copied,
                                        hgindex->num_verts_per_type_));
}

270
271
272
273
274
275
276
277
278
279
/*! \brief Call PinMemory_ on every relation graph. */
void HeteroGraph::PinMemory_() {
  for (const auto& rel : relation_graphs_) {
    rel->PinMemory_();
  }
}

/*! \brief Call UnpinMemory_ on every relation graph. */
void HeteroGraph::UnpinMemory_() {
  for (const auto& rel : relation_graphs_) {
    rel->UnpinMemory_();
  }
}

280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
/*! \brief Return the name of the attached shared memory, or "" if none is attached. */
std::string HeteroGraph::SharedMemName() const {
  if (!shared_mem_) {
    return "";
  }
  return shared_mem_->GetName();
}

/*!
 * \brief Copy a heterograph into shared memory under the given name.
 *
 * The metadata stream receives, in this order: the bit width, the coo/csr/csc flags,
 * the meta graph, the per-type vertex counts, the per-relation matrices of the
 * requested formats, and finally the node and edge type names.
 *
 * \param g The input graph; must be a HeteroGraph.
 * \param name The shared memory name.
 * \param ntypes Node type names to store alongside the graph.
 * \param etypes Edge type names to store alongside the graph.
 * \param fmts The sparse formats to copy ("coo", "csr", "csc").
 * \return g itself if it already uses shared memory of that name, otherwise the new
 *         shared-memory-backed heterograph.
 */
HeteroGraphPtr HeteroGraph::CopyToSharedMem(
      HeteroGraphPtr g, const std::string& name, const std::vector<std::string>& ntypes,
      const std::vector<std::string>& etypes, const std::set<std::string>& fmts) {
  // TODO(JJ): Raise error when calling shared_memory if graph index is on gpu
  auto hg = std::dynamic_pointer_cast<HeteroGraph>(g);
  CHECK_NOTNULL(hg);
  if (hg->SharedMemName() == name) {
    return g;
  }

  // Set up the metadata buffer in shared memory.
  auto mem = std::make_shared<SharedMemory>(name);
  auto mem_buf = mem->CreateNew(SHARED_MEM_METAINFO_SIZE_MAX);
  dmlc::MemoryFixedSizeStream strm(mem_buf, SHARED_MEM_METAINFO_SIZE_MAX);
  SharedMemManager shm(name, &strm);

  const bool has_coo = fmts.count("coo") != 0;
  const bool has_csr = fmts.count("csr") != 0;
  const bool has_csc = fmts.count("csc") != 0;

  // Header; CreateFromSharedMem reads these back in the same order.
  shm.Write(g->NumBits());
  shm.Write(has_coo);
  shm.Write(has_csr);
  shm.Write(has_csc);
  shm.Write(ImmutableGraph::ToImmutable(hg->meta_graph_));
  shm.Write(hg->num_verts_per_type_);

  const auto num_etypes = g->NumEdgeTypes();
  std::vector<HeteroGraphPtr> relgraphs(num_etypes);
  for (dgl_type_t etype = 0; etype < num_etypes; ++etype) {
    const auto endpoints = g->GetEndpointTypes(etype);
    const int num_vtypes = (endpoints.first == endpoints.second) ? 1 : 2;

    // Formats that were not requested stay default-constructed.
    aten::COOMatrix coo;
    aten::CSRMatrix csr, csc;
    const std::string prefix = name + "_" + std::to_string(etype);
    if (has_coo)
      coo = shm.CopyToSharedMem(hg->GetCOOMatrix(etype), prefix + "_coo");
    if (has_csr)
      csr = shm.CopyToSharedMem(hg->GetCSRMatrix(etype), prefix + "_csr");
    if (has_csc)
      csc = shm.CopyToSharedMem(hg->GetCSCMatrix(etype), prefix + "_csc");

    relgraphs[etype] = UnitGraph::CreateUnitGraphFrom(
        num_vtypes, csc, csr, coo, has_csc, has_csr, has_coo);
  }

  auto ret = std::shared_ptr<HeteroGraph>(
      new HeteroGraph(hg->meta_graph_, relgraphs, hg->num_verts_per_type_));
  ret->shared_mem_ = mem;

  // Trailer: the type names.
  shm.Write(ntypes);
  shm.Write(etypes);
  return ret;
}

std::tuple<HeteroGraphPtr, std::vector<std::string>, std::vector<std::string>>
    HeteroGraph::CreateFromSharedMem(const std::string &name) {
341
342
343
344
  bool exist = SharedMemory::Exist(name);
  if (!exist) {
    return std::make_tuple(nullptr, std::vector<std::string>(), std::vector<std::string>());
  }
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
  auto mem = std::make_shared<SharedMemory>(name);
  auto mem_buf = mem->Open(SHARED_MEM_METAINFO_SIZE_MAX);
  dmlc::MemoryFixedSizeStream strm(mem_buf, SHARED_MEM_METAINFO_SIZE_MAX);
  SharedMemManager shm(name, &strm);

  uint8_t nbits;
  CHECK(shm.Read(&nbits)) << "invalid nbits (unit8_t)";

  bool has_coo, has_csr, has_csc;
  CHECK(shm.Read(&has_coo)) << "invalid nbits (unit8_t)";
  CHECK(shm.Read(&has_csr)) << "invalid csr (unit8_t)";
  CHECK(shm.Read(&has_csc)) << "invalid csc (unit8_t)";

  auto meta_imgraph = Serializer::make_shared<ImmutableGraph>();
  CHECK(shm.Read(&meta_imgraph)) << "Invalid meta graph";
  GraphPtr metagraph = meta_imgraph;

  std::vector<int64_t> num_verts_per_type;
  CHECK(shm.Read(&num_verts_per_type)) << "Invalid number of vertices per type";

  std::vector<HeteroGraphPtr> relgraphs(metagraph->NumEdges());
  for (dgl_type_t etype = 0 ; etype < metagraph->NumEdges() ; ++etype) {
367
368
    auto src_dst = metagraph->FindEdge(etype);
    int num_vtypes = (src_dst.first == src_dst.second) ? 1 : 2;
369
370
371
372
373
374
375
376
377
378
379
380
381
    aten::COOMatrix coo;
    aten::CSRMatrix csr, csc;
    std::string prefix = name + "_" + std::to_string(etype);
    if (has_coo) {
      shm.CreateFromSharedMem(&coo, prefix + "_coo");
    }
    if (has_csr) {
      shm.CreateFromSharedMem(&csr, prefix + "_csr");
    }
    if (has_csc) {
      shm.CreateFromSharedMem(&csc, prefix + "_csc");
    }

382
383
    relgraphs[etype] = UnitGraph::CreateUnitGraphFrom(
        num_vtypes, csc, csr, coo, has_csc, has_csr, has_coo);
384
385
386
387
388
389
390
391
392
393
394
395
  }

  auto ret = std::make_shared<HeteroGraph>(metagraph, relgraphs, num_verts_per_type);
  ret->shared_mem_ = mem;

  std::vector<std::string> ntypes;
  std::vector<std::string> etypes;
  CHECK(shm.Read(&ntypes)) << "invalid ntypes";
  CHECK(shm.Read(&etypes)) << "invalid etypes";
  return std::make_tuple(ret, ntypes, etypes);
}

396
/*!
 * \brief Create a heterograph whose relation graphs are restricted to the given formats.
 * \param formats The format code forwarded to UnitGraph::GetGraphInFormat.
 * \return A new heterograph with the same meta graph and vertex counts.
 */
HeteroGraphPtr HeteroGraph::GetGraphInFormat(dgl_format_code_t formats) const {
  std::vector<HeteroGraphPtr> format_rels(NumEdgeTypes());
  for (dgl_type_t etype = 0; etype < NumEdgeTypes(); ++etype) {
    // Fail with a check error instead of dereferencing a null pointer if the
    // relation graph is not a UnitGraph.
    auto relgraph = CHECK_NOTNULL(
        std::dynamic_pointer_cast<UnitGraph>(GetRelationGraph(etype)));
    format_rels[etype] = relgraph->GetGraphInFormat(formats);
  }
  return HeteroGraphPtr(new HeteroGraph(
    meta_graph_, format_rels, NumVerticesPerType()));
}

406
407
408
409
410
/*!
 * \brief Flatten the given edge types into a single graph.
 *
 * Dispatches to FlattenImpl with a 32-bit ID type when NumBits() is 32 and with a
 * 64-bit ID type otherwise.
 *
 * \param etypes The edge types to flatten.
 * \return The flattened graph together with its induced type and ID arrays.
 */
FlattenedHeteroGraphPtr HeteroGraph::Flatten(
    const std::vector<dgl_type_t>& etypes) const {
  return (NumBits() == 32)
    ? FlattenImpl<int32_t>(etypes)
    : FlattenImpl<int64_t>(etypes);
}

/*!
 * \brief Flatten the given edge types into a single graph, using IdType for node IDs.
 *
 * Source node types and destination node types are each laid out contiguously in
 * ascending type order, and every edge endpoint is shifted by the offset of its type.
 * The result has one vertex type when the set of source types equals the set of
 * destination types, and two otherwise.
 *
 * \tparam IdType Integer type of the induced source/destination ID arrays.
 * \param etypes The edge types to flatten.
 * \return The flattened graph together with, per flattened node and edge, its
 *         original type and type-local ID.
 */
template <class IdType>
FlattenedHeteroGraphPtr HeteroGraph::FlattenImpl(const std::vector<dgl_type_t>& etypes) const {
  std::unordered_map<dgl_type_t, size_t> srctype_offsets, dsttype_offsets;
  std::vector<dgl_type_t> induced_srctype, induced_dsttype;
  std::vector<IdType> induced_srcid, induced_dstid;
  std::vector<dgl_type_t> srctype_set, dsttype_set;

  // After this loop XXXtype_offsets map each node type to its number of nodes, and
  // XXXtype_set list the distinct node types in first-seen order.
  for (dgl_type_t etype : etypes) {
    const auto src_dsttype = meta_graph_->FindEdge(etype);
    const dgl_type_t srctype = src_dsttype.first;
    const dgl_type_t dsttype = src_dsttype.second;

    if (srctype_offsets.count(srctype) == 0) {
      srctype_offsets[srctype] = NumVertices(srctype);
      srctype_set.push_back(srctype);
    }
    if (dsttype_offsets.count(dsttype) == 0) {
      dsttype_offsets[dsttype] = NumVertices(dsttype);
      dsttype_set.push_back(dsttype);
    }
  }

  // Sort the node types so that the two sets can be compared to decide whether a
  // homogeneous graph should be returned.
  std::sort(srctype_set.begin(), srctype_set.end());
  std::sort(dsttype_set.begin(), dsttype_set.end());
  const bool homograph = (srctype_set.size() == dsttype_set.size()) &&
    std::equal(srctype_set.begin(), srctype_set.end(), dsttype_set.begin());

  // Replaces the node count stored for each type by the type's starting offset (in
  // sorted type order) and records, for every flattened node, its original type and
  // type-local ID. Returns the total number of nodes.
  const auto assign_offsets = [](
      const std::vector<dgl_type_t>& type_set,
      std::unordered_map<dgl_type_t, size_t>* offsets,
      std::vector<dgl_type_t>* induced_type,
      std::vector<IdType>* induced_id) {
    size_t total = 0;
    for (const dgl_type_t ntype : type_set) {
      size_t& slot = offsets->at(ntype);
      const size_t num_nodes = slot;
      slot = total;
      total += num_nodes;
      for (size_t j = 0; j < num_nodes; ++j) {
        induced_type->push_back(ntype);
        induced_id->push_back(static_cast<IdType>(j));
      }
    }
    return total;
  };
  const size_t src_nodes = assign_offsets(
      srctype_set, &srctype_offsets, &induced_srctype, &induced_srcid);
  const size_t dst_nodes = assign_offsets(
      dsttype_set, &dsttype_offsets, &induced_dsttype, &induced_dstid);

  // TODO(minjie): Using concat operations cause many fragmented memory.
  //   Need to optimize it in the future.
  std::vector<IdArray> src_arrs, dst_arrs, eid_arrs, induced_etypes;
  src_arrs.reserve(etypes.size());
  dst_arrs.reserve(etypes.size());
  eid_arrs.reserve(etypes.size());
  induced_etypes.reserve(etypes.size());
  for (dgl_type_t etype : etypes) {
    const auto src_dsttype = meta_graph_->FindEdge(etype);
    const size_t srctype_offset = srctype_offsets.at(src_dsttype.first);
    const size_t dsttype_offset = dsttype_offsets.at(src_dsttype.second);

    EdgeArray edges = Edges(etype);
    const size_t num_edges = NumEdges(etype);
    // Shift the type-local endpoint IDs into the flattened ID spaces.
    src_arrs.push_back(edges.src + srctype_offset);
    dst_arrs.push_back(edges.dst + dsttype_offset);
    eid_arrs.push_back(edges.id);
    induced_etypes.push_back(aten::Full(etype, num_edges, NumBits(), Context()));
  }

  HeteroGraphPtr gptr = UnitGraph::CreateFromCOO(
      homograph ? 1 : 2,
      src_nodes,
      dst_nodes,
      aten::Concat(src_arrs),
      aten::Concat(dst_arrs));

  // Sanity check
  CHECK_EQ(gptr->Context(), Context());
  CHECK_EQ(gptr->NumBits(), NumBits());

  // Own the result from the moment it is allocated so that it is released if any of
  // the array operations below fails.
  FlattenedHeteroGraphPtr result(new FlattenedHeteroGraph);
  result->graph = HeteroGraphRef(HeteroGraphPtr(new HeteroGraph(gptr->meta_graph(), {gptr})));
  result->induced_srctype = aten::VecToIdArray(induced_srctype).CopyTo(Context());
  result->induced_srctype_set = aten::VecToIdArray(srctype_set).CopyTo(Context());
  result->induced_srcid = aten::VecToIdArray(induced_srcid).CopyTo(Context());
  result->induced_etype = aten::Concat(induced_etypes);
  result->induced_etype_set = aten::VecToIdArray(etypes).CopyTo(Context());
  result->induced_eid = aten::Concat(eid_arrs);
  result->induced_dsttype = aten::VecToIdArray(induced_dsttype).CopyTo(Context());
  result->induced_dsttype_set = aten::VecToIdArray(dsttype_set).CopyTo(Context());
  result->induced_dstid = aten::VecToIdArray(induced_dstid).CopyTo(Context());
  return result;
}

519
520
521
522
523
524
/*!
 * \brief Magic number placed at the start of a serialized HeteroGraph. Save writes it
 *        first and Load rejects any stream that does not begin with it.
 */
constexpr uint64_t kDGLSerialize_HeteroGraph = 0xDD589FBE35224ABF;

/*!
 * \brief Restore this heterograph from a stream.
 *
 * Expects, in order: the magic number, the meta graph, the relation graphs and the
 * per-type vertex counts. Any read failure or a wrong magic number is a check failure.
 *
 * \param fs The input stream.
 * \return Always true; failures do not return.
 */
bool HeteroGraph::Load(dmlc::Stream* fs) {
  // Header.
  uint64_t magicNum;
  CHECK(fs->Read(&magicNum)) << "Invalid Magic Number";
  CHECK_EQ(magicNum, kDGLSerialize_HeteroGraph) << "Invalid HeteroGraph Data";

  // Meta graph, stored as an ImmutableGraph.
  auto meta_imgraph = Serializer::make_shared<ImmutableGraph>();
  CHECK(fs->Read(&meta_imgraph)) << "Invalid meta graph";
  meta_graph_ = meta_imgraph;

  // Payload.
  CHECK(fs->Read(&relation_graphs_)) << "Invalid relation_graphs_";
  CHECK(fs->Read(&num_verts_per_type_)) << "Invalid num_verts_per_type_";
  return true;
}

/*!
 * \brief Serialize this heterograph to a stream.
 *
 * Writes, in order: the magic number, the meta graph converted to an ImmutableGraph,
 * the relation graphs and the per-type vertex counts.
 *
 * \param fs The output stream.
 */
void HeteroGraph::Save(dmlc::Stream* fs) const {
  fs->Write(kDGLSerialize_HeteroGraph);
  fs->Write(ImmutableGraph::ToImmutable(meta_graph()));
  fs->Write(relation_graphs_);
  fs->Write(num_verts_per_type_);
}

541
542
543
544
545
546
547
548
/*!
 * \brief Convert a heterograph with one node type and one edge type to an immutable graph.
 * \return The result of AsImmutableGraph on the single relation graph.
 */
GraphPtr HeteroGraph::AsImmutableGraph() const {
  // Only graphs with exactly one node type and one edge type qualify.
  CHECK(NumVertexTypes() == 1) << "graph has more than one node types";
  CHECK(NumEdgeTypes() == 1) << "graph has more than one edge types";

  const UnitGraphPtr unit_graph = CHECK_NOTNULL(
      std::dynamic_pointer_cast<UnitGraph>(GetRelationGraph(0)));
  return unit_graph->AsImmutableGraph();
}

549
550
551
552
553
554
555
556
557
558
559
560
/*!
 * \brief Build the line graph of a heterograph with one node type and one edge type.
 * \param backtracking Forwarded to UnitGraph::LineGraph.
 * \return A heterograph that wraps the line graph and reuses this graph's meta graph.
 */
HeteroGraphPtr HeteroGraph::LineGraph(bool backtracking) const {
  CHECK_EQ(1, meta_graph_->NumEdges()) << "Only support Homogeneous graph now (one edge type)";
  CHECK_EQ(1, meta_graph_->NumVertices()) << "Only support Homogeneous graph now (one node type)";
  CHECK_EQ(1, relation_graphs_.size()) << "Only support Homogeneous graph now";

  const auto line_graph = relation_graphs_[0]->LineGraph(backtracking);
  const std::vector<HeteroGraphPtr> rels = {line_graph};
  // The vertex count of the result is taken from the line graph's type-0 vertices.
  const std::vector<int64_t> node_counts = {
    static_cast<int64_t>(line_graph->NumVertices(0))};
  return HeteroGraphPtr(new HeteroGraph(meta_graph_, rels, node_counts));
}

561
}  // namespace dgl