/**
 *  Copyright (c) 2019 by Contributors
 * @file graph/heterograph.cc
 * @brief Heterograph implementation
 */
#include "./heterograph.h"
7

Minjie Wang's avatar
Minjie Wang committed
8
#include <dgl/array.h>
9
#include <dgl/graph_serializer.h>
10
#include <dgl/immutable_graph.h>
11
#include <dmlc/memory_io.h>
12

13
#include <memory>
14
15
#include <tuple>
#include <utility>
16
#include <vector>
17
18
19
20
21
22

using namespace dgl::runtime;

namespace dgl {
namespace {

23
24
using dgl::ImmutableGraph;

25
26
27
/**
 * @brief Edge-induced subgraph of a heterograph that keeps the parent's nodes.
 *
 * Every relation graph is sliced on its own through its EdgeSubgraph method
 * (called with preserve_nodes = true). The result graph is built with the
 * parent's per-type vertex counts.
 *
 * @param hg The parent heterograph.
 * @param eids Edge IDs to keep; must hold one array per edge type.
 * @return The subgraph with its induced vertex and edge arrays.
 */
HeteroSubgraph EdgeSubgraphPreserveNodes(
    const HeteroGraph* hg, const std::vector<IdArray>& eids) {
  CHECK_EQ(eids.size(), hg->NumEdgeTypes())
      << "Invalid input: the input list size must be the same as the number of "
         "edge type.";

  HeteroSubgraph subg;
  subg.induced_edges = eids;
  subg.induced_vertices.resize(hg->NumVertexTypes());

  // When preserve_nodes is true, simply compute EdgeSubgraph for each
  // relation graph independently.
  std::vector<HeteroGraphPtr> rel_subgraphs(hg->NumEdgeTypes());
  for (dgl_type_t etype = 0; etype < hg->NumEdgeTypes(); ++etype) {
    const auto endpoints = hg->meta_graph()->FindEdge(etype);
    const auto& rel_sg =
        hg->GetRelationGraph(etype)->EdgeSubgraph({eids[etype]}, true);
    rel_subgraphs[etype] = rel_sg.graph;
    // Slot 0 is recorded for the source type, slot 1 for the destination type.
    subg.induced_vertices[endpoints.first] = rel_sg.induced_vertices[0];
    subg.induced_vertices[endpoints.second] = rel_sg.induced_vertices[1];
  }

  subg.graph = HeteroGraphPtr(new HeteroGraph(
      hg->meta_graph(), rel_subgraphs, hg->NumVerticesPerType()));
  return subg;
}

/**
 * @brief Edge-induced subgraph of a heterograph that keeps only the nodes
 *        incident to the selected edges.
 *
 * Node IDs are relabeled per vertex type with aten::Relabel_, so that all
 * relation graphs touching the same vertex type share one ID space.
 *
 * @param hg The parent heterograph.
 * @param eids Edge IDs to keep; must hold one array per edge type.
 * @return The subgraph with its induced vertex and edge arrays.
 */
HeteroSubgraph EdgeSubgraphNoPreserveNodes(
    const HeteroGraph* hg, const std::vector<IdArray>& eids) {
  // TODO(minjie): In general, all relabeling should be separated with subgraph
  //   operations.
  CHECK_EQ(eids.size(), hg->NumEdgeTypes())
      << "Invalid input: the input list size must be the same as the number of "
         "edge type.";

  HeteroSubgraph subg;
  subg.induced_edges = eids;
  subg.induced_vertices.resize(hg->NumVertexTypes());

  // NOTE(minjie): EdgeSubgraph when preserve_nodes is false is quite
  // complicated in heterograph. This is because we need to make sure bipartite
  // graphs that incident on the same vertex type must have the same ID space.
  // For example, suppose we have following heterograph:
  //
  // Meta graph: A -> B -> C
  // UnitGraph graphs:
  // * A -> B: (0, 0), (0, 1)
  // * B -> C: (1, 0), (1, 1)
  //
  // Suppose for A->B, we only keep edge (0, 0), while for B->C we only keep
  // (1, 0). We need to make sure that in the result subgraph, node type B
  // still has two nodes. This means we cannot simply compute EdgeSubgraph for
  // B->C which will relabel node#1 of type B to be node #0.
  //
  // One implementation is as follows:
  // (1) For each bipartite graph, slice out the edges using the given eids.
  // (2) Make a dictionary map<vtype, vector<IdArray>>, where the key is the
  //     vertex type and the value is the incident nodes from the bipartite
  //     graphs that has the vertex type as either srctype or dsttype.
  // (3) Then for each vertex type, use aten::Relabel_ on its vector<IdArray>.
  //     aten::Relabel_ computes the union of the vertex sets and relabel
  //     the unique elements from zero. The returned mapping array is the final
  //     induced vertex set for that vertex type.
  // (4) Use the relabeled edges to construct the bipartite graph.

  // Steps (1) & (2): slice the edges and group endpoint arrays by vertex type.
  std::vector<std::vector<IdArray>> incident_nodes(hg->NumVertexTypes());
  std::vector<EdgeArray> kept_edges(hg->NumEdgeTypes());
  for (dgl_type_t etype = 0; etype < hg->NumEdgeTypes(); ++etype) {
    const auto endpoints = hg->meta_graph()->FindEdge(etype);
    kept_edges[etype] = hg->GetRelationGraph(etype)->FindEdges(0, eids[etype]);
    incident_nodes[endpoints.first].push_back(kept_edges[etype].src);
    incident_nodes[endpoints.second].push_back(kept_edges[etype].dst);
  }

  // Step (3): relabel per vertex type and record the resulting node counts.
  std::vector<int64_t> num_vertices_per_type(hg->NumVertexTypes());
  for (dgl_type_t vtype = 0; vtype < hg->NumVertexTypes(); ++vtype) {
    subg.induced_vertices[vtype] = aten::Relabel_(incident_nodes[vtype]);
    num_vertices_per_type[vtype] = subg.induced_vertices[vtype]->shape[0];
  }

  // Step (4): rebuild each relation graph from the sliced edge arrays.
  std::vector<HeteroGraphPtr> rel_subgraphs(hg->NumEdgeTypes());
  for (dgl_type_t etype = 0; etype < hg->NumEdgeTypes(); ++etype) {
    const auto endpoints = hg->meta_graph()->FindEdge(etype);
    const dgl_type_t src_vtype = endpoints.first;
    const dgl_type_t dst_vtype = endpoints.second;
    const int num_vtypes = (src_vtype == dst_vtype) ? 1 : 2;
    rel_subgraphs[etype] = UnitGraph::CreateFromCOO(
        num_vtypes, subg.induced_vertices[src_vtype]->shape[0],
        subg.induced_vertices[dst_vtype]->shape[0], kept_edges[etype].src,
        kept_edges[etype].dst);
  }

  subg.graph = HeteroGraphPtr(new HeteroGraph(
      hg->meta_graph(), rel_subgraphs, std::move(num_vertices_per_type)));
  return subg;
}

121
122
/**
 * @brief Validate the inputs used to construct a HeteroGraph.
 *
 * Checks, in this order: the meta graph has as many edges as there are
 * relation graphs; the list is not empty; every relation graph has exactly one
 * edge type; every relation graph reports the same context as the first one.
 *
 * @param meta_graph The meta graph.
 * @param rel_graphs The relation graphs.
 */
void HeteroGraphSanityCheck(
    GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs) {
  CHECK_EQ(meta_graph->NumEdges(), rel_graphs.size());
  CHECK(!rel_graphs.empty()) << "Empty heterograph is not allowed.";

  // All relation graphs must have only one edge type.
  for (size_t i = 0; i < rel_graphs.size(); ++i) {
    const auto& rg = rel_graphs[i];
    CHECK_EQ(rg->NumEdgeTypes(), 1)
        << "Each relation graph must have only one edge type.";
  }

  // All relation graphs must live on the context of the first one.
  auto ctx = rel_graphs.front()->Context();
  for (size_t i = 0; i < rel_graphs.size(); ++i) {
    const auto& rg = rel_graphs[i];
    CHECK_EQ(rg->Context(), ctx)
        << "Each relation graph must have the same context.";
  }
}

138
139
/**
 * @brief Derive the number of vertices of every vertex type from the relation
 *        graphs.
 *
 * Walks the meta-graph edges. For each one, the source and destination counts
 * of the matching relation graph are recorded, or compared against the count
 * already recorded for that vertex type. A vertex type that no meta-graph edge
 * touches keeps the initial value -1.
 *
 * @param meta_graph The meta graph.
 * @param rel_graphs The relation graphs, indexed by edge type.
 * @return Number of vertices per vertex type.
 */
std::vector<int64_t> InferNumVerticesPerType(
    GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs) {
  // -1 marks a vertex type whose count has not been recorded yet.
  std::vector<int64_t> num_verts_per_type(meta_graph->NumVertices(), -1);

  const EdgeArray etype_array = meta_graph->Edges();
  const dgl_type_t* srctypes =
      static_cast<const dgl_type_t*>(etype_array.src->data);
  const dgl_type_t* dsttypes =
      static_cast<const dgl_type_t*>(etype_array.dst->data);
  const dgl_type_t* etypes =
      static_cast<const dgl_type_t*>(etype_array.id->data);

  for (size_t idx = 0; idx < meta_graph->NumEdges(); ++idx) {
    const dgl_type_t srctype = srctypes[idx];
    const dgl_type_t dsttype = dsttypes[idx];
    const auto& rg = rel_graphs[etypes[idx]];
    // Inside a relation graph the source side is vertex type 0; the
    // destination side is type 1 unless the graph has a single vertex type.
    const int sty = 0;
    const int dty = (rg->NumVertexTypes() == 1) ? 0 : 1;

    // # nodes of source type
    size_t nv = rg->NumVertices(sty);
    if (num_verts_per_type[srctype] >= 0) {
      CHECK_EQ(num_verts_per_type[srctype], nv)
          << "Mismatch number of vertices for vertex type " << srctype;
    } else {
      num_verts_per_type[srctype] = nv;
    }

    // # nodes of destination type
    nv = rg->NumVertices(dty);
    if (num_verts_per_type[dsttype] >= 0) {
      CHECK_EQ(num_verts_per_type[dsttype], nv)
          << "Mismatch number of vertices for vertex type " << dsttype;
    } else {
      num_verts_per_type[dsttype] = nv;
    }
  }
  return num_verts_per_type;
}
173

174
175
std::vector<UnitGraphPtr> CastToUnitGraphs(
    const std::vector<HeteroGraphPtr>& rel_graphs) {
176
  std::vector<UnitGraphPtr> relation_graphs(rel_graphs.size());
177
178
179
  for (size_t i = 0; i < rel_graphs.size(); ++i) {
    HeteroGraphPtr relg = rel_graphs[i];
    if (std::dynamic_pointer_cast<UnitGraph>(relg)) {
180
      relation_graphs[i] = std::dynamic_pointer_cast<UnitGraph>(relg);
181
    } else {
182
      relation_graphs[i] = CHECK_NOTNULL(
183
184
185
          std::dynamic_pointer_cast<UnitGraph>(relg->GetRelationGraph(0)));
    }
  }
186
187
188
189
190
191
  return relation_graphs;
}

}  // namespace

/**
 * @brief Construct a heterograph from a meta graph and its relation graphs.
 *
 * @param meta_graph The meta graph; one edge per relation graph.
 * @param rel_graphs The relation graphs, indexed by edge type.
 * @param num_nodes_per_type Number of nodes of each vertex type. When empty,
 *        the counts are inferred from the relation graphs.
 */
HeteroGraph::HeteroGraph(
    GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs,
    const std::vector<int64_t>& num_nodes_per_type)
    : BaseHeteroGraph(meta_graph) {
  // Validate before anything else: InferNumVerticesPerType indexes rel_graphs
  // by edge type, so the size check has to pass before it may run.
  HeteroGraphSanityCheck(meta_graph, rel_graphs);
  if (num_nodes_per_type.empty()) {
    num_verts_per_type_ = InferNumVerticesPerType(meta_graph, rel_graphs);
  } else {
    num_verts_per_type_ = num_nodes_per_type;
  }
  relation_graphs_ = CastToUnitGraphs(rel_graphs);
}

bool HeteroGraph::IsMultigraph() const {
204
  for (const auto& hg : relation_graphs_) {
205
206
207
208
209
    if (hg->IsMultigraph()) {
      return true;
    }
  }
  return false;
210
211
212
}

/**
 * @brief Test a batch of vertex IDs against the vertex count of a type.
 * @param vtype The vertex type.
 * @param vids The vertex IDs; must pass aten::IsValidIdArray.
 * @return The result of aten::LT(vids, NumVertices(vtype)).
 */
BoolArray HeteroGraph::HasVertices(dgl_type_t vtype, IdArray vids) const {
  CHECK(aten::IsValidIdArray(vids)) << "Invalid id array input";
  const auto num_nodes = NumVertices(vtype);
  return aten::LT(vids, num_nodes);
}

217
218
/**
 * @brief Vertex-induced subgraph of the heterograph.
 *
 * Each relation graph is sliced with the ID arrays of its source and
 * destination vertex types. A relation whose two endpoints share one vertex
 * type is given a single ID array.
 *
 * @param vids Vertex IDs to keep; must hold one array per vertex type.
 * @return The subgraph with its induced vertex and edge arrays.
 */
HeteroSubgraph HeteroGraph::VertexSubgraph(
    const std::vector<IdArray>& vids) const {
  CHECK_EQ(vids.size(), NumVertexTypes())
      << "Invalid input: the input list size must be the same as the number of "
         "vertex types.";

  HeteroSubgraph subg;
  subg.induced_vertices = vids;
  subg.induced_edges.resize(NumEdgeTypes());

  // The subgraph has exactly as many nodes per type as IDs were given.
  std::vector<int64_t> num_vertices_per_type;
  num_vertices_per_type.reserve(vids.size());
  for (const IdArray& ids : vids) {
    num_vertices_per_type.push_back(ids->shape[0]);
  }

  std::vector<HeteroGraphPtr> rel_subgraphs(NumEdgeTypes());
  for (dgl_type_t etype = 0; etype < NumEdgeTypes(); ++etype) {
    const auto endpoints = meta_graph_->FindEdge(etype);
    const dgl_type_t src_vtype = endpoints.first;
    const dgl_type_t dst_vtype = endpoints.second;

    std::vector<IdArray> rel_vids;
    rel_vids.push_back(vids[src_vtype]);
    if (src_vtype != dst_vtype) {
      rel_vids.push_back(vids[dst_vtype]);
    }

    const auto& rel_sg = GetRelationGraph(etype)->VertexSubgraph(rel_vids);
    rel_subgraphs[etype] = rel_sg.graph;
    subg.induced_edges[etype] = rel_sg.induced_edges[0];
  }

  subg.graph = HeteroGraphPtr(new HeteroGraph(
      meta_graph_, rel_subgraphs, std::move(num_vertices_per_type)));
  return subg;
}

/**
 * @brief Edge-induced subgraph of the heterograph.
 * @param eids Edge IDs to keep, one array per edge type.
 * @param preserve_nodes If true, dispatch to EdgeSubgraphPreserveNodes;
 *        otherwise to EdgeSubgraphNoPreserveNodes.
 * @return The subgraph produced by the selected helper.
 */
HeteroSubgraph HeteroGraph::EdgeSubgraph(
    const std::vector<IdArray>& eids, bool preserve_nodes) const {
  return preserve_nodes ? EdgeSubgraphPreserveNodes(this, eids)
                        : EdgeSubgraphNoPreserveNodes(this, eids);
}

255
256
257
258
259
260
261
/**
 * @brief Rebuild a heterograph with every relation graph passed through
 *        UnitGraph::AsNumBits.
 * @param g The graph to convert; must be a HeteroGraph.
 * @param bits The bit width forwarded to UnitGraph::AsNumBits.
 * @return A new HeteroGraph with the same meta graph and vertex counts.
 */
HeteroGraphPtr HeteroGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) {
  auto hgindex = std::dynamic_pointer_cast<HeteroGraph>(g);
  CHECK_NOTNULL(hgindex);

  std::vector<HeteroGraphPtr> converted;
  converted.reserve(hgindex->relation_graphs_.size());
  for (const auto& rel : hgindex->relation_graphs_) {
    converted.push_back(UnitGraph::AsNumBits(rel, bits));
  }
  return HeteroGraphPtr(new HeteroGraph(
      hgindex->meta_graph_, converted, hgindex->num_verts_per_type_));
}

266
/**
 * @brief Copy a heterograph to the given context.
 * @param g The graph to copy; must be a HeteroGraph unless it already lives
 *        on ctx.
 * @param ctx The target context.
 * @return g itself when it already reports ctx as its context; otherwise a
 *         new HeteroGraph whose relation graphs went through
 *         UnitGraph::CopyTo.
 */
HeteroGraphPtr HeteroGraph::CopyTo(HeteroGraphPtr g, const DGLContext& ctx) {
  if (ctx == g->Context()) {
    return g;
  }
  auto hgindex = std::dynamic_pointer_cast<HeteroGraph>(g);
  CHECK_NOTNULL(hgindex);

  std::vector<HeteroGraphPtr> copied;
  copied.reserve(hgindex->relation_graphs_.size());
  for (const auto& rel : hgindex->relation_graphs_) {
    copied.push_back(UnitGraph::CopyTo(rel, ctx));
  }
  return HeteroGraphPtr(new HeteroGraph(
      hgindex->meta_graph_, copied, hgindex->num_verts_per_type_));
}

280
void HeteroGraph::PinMemory_() {
281
  for (auto g : relation_graphs_) g->PinMemory_();
282
283
284
}

void HeteroGraph::UnpinMemory_() {
285
  for (auto g : relation_graphs_) g->UnpinMemory_();
286
287
}

288
/**
 * @brief Call RecordStream(stream) on every relation graph.
 * @param stream The stream handle forwarded to each relation graph.
 */
void HeteroGraph::RecordStream(DGLStreamHandle stream) {
  for (const auto& rel : relation_graphs_) {
    rel->RecordStream(stream);
  }
}

292
293
294
295
296
/**
 * @brief Name of the shared memory attached to this graph.
 * @return shared_mem_->GetName() when shared_mem_ is set, otherwise an empty
 *         string.
 */
std::string HeteroGraph::SharedMemName() const {
  if (shared_mem_) {
    return shared_mem_->GetName();
  }
  return "";
}

/**
 * @brief Copy a heterograph into shared memory under the given name.
 *
 * Values are written to the meta-info stream in this order: bit width, the
 * coo/csr/csc presence flags, the meta graph, the per-type vertex counts,
 * then (after the relation matrices have been copied) ntypes and etypes.
 *
 * @param g The graph to copy; must be a HeteroGraph.
 * @param name Shared-memory name; also the prefix of the per-relation names.
 * @param ntypes Node type names, stored after the graph data.
 * @param etypes Edge type names, stored after the graph data.
 * @param fmts Formats to copy; "coo", "csr" and "csc" are looked up.
 * @return g itself if it already reports this shared-memory name; otherwise a
 *         new HeteroGraph built from the copied matrices.
 */
HeteroGraphPtr HeteroGraph::CopyToSharedMem(
    HeteroGraphPtr g, const std::string& name,
    const std::vector<std::string>& ntypes,
    const std::vector<std::string>& etypes, const std::set<std::string>& fmts) {
  // TODO(JJ): Raise error when calling shared_memory if graph index is on gpu
  auto hg = std::dynamic_pointer_cast<HeteroGraph>(g);
  CHECK_NOTNULL(hg);
  if (hg->SharedMemName() == name) {
    return g;
  }

  // Copy buffer to share memory
  auto mem = std::make_shared<SharedMemory>(name);
  auto mem_buf = mem->CreateNew(SHARED_MEM_METAINFO_SIZE_MAX);
  dmlc::MemoryFixedSizeStream strm(mem_buf, SHARED_MEM_METAINFO_SIZE_MAX);
  SharedMemManager shm(name, &strm);

  bool has_coo = fmts.count("coo") > 0;
  bool has_csr = fmts.count("csr") > 0;
  bool has_csc = fmts.count("csc") > 0;
  shm.Write(g->NumBits());
  shm.Write(has_coo);
  shm.Write(has_csr);
  shm.Write(has_csc);
  shm.Write(ImmutableGraph::ToImmutable(hg->meta_graph_));
  shm.Write(hg->num_verts_per_type_);

  std::vector<HeteroGraphPtr> relgraphs(g->NumEdgeTypes());
  for (dgl_type_t etype = 0; etype < g->NumEdgeTypes(); ++etype) {
    const auto endpoint_types = g->GetEndpointTypes(etype);
    const int num_vtypes =
        (endpoint_types.first != endpoint_types.second) ? 2 : 1;
    const std::string prefix = name + "_" + std::to_string(etype);

    // Formats that were not requested stay default-constructed.
    aten::COOMatrix coo;
    aten::CSRMatrix csr, csc;
    if (has_coo) {
      coo = shm.CopyToSharedMem(hg->GetCOOMatrix(etype), prefix + "_coo");
    }
    if (has_csr) {
      csr = shm.CopyToSharedMem(hg->GetCSRMatrix(etype), prefix + "_csr");
    }
    if (has_csc) {
      csc = shm.CopyToSharedMem(hg->GetCSCMatrix(etype), prefix + "_csc");
    }
    relgraphs[etype] = UnitGraph::CreateUnitGraphFrom(
        num_vtypes, csc, csr, coo, has_csc, has_csr, has_coo);
  }

  auto ret = std::shared_ptr<HeteroGraph>(
      new HeteroGraph(hg->meta_graph_, relgraphs, hg->num_verts_per_type_));
  // Attach the shared-memory handle to the returned graph.
  ret->shared_mem_ = mem;

  shm.Write(ntypes);
  shm.Write(etypes);
  return ret;
}

std::tuple<HeteroGraphPtr, std::vector<std::string>, std::vector<std::string>>
352
HeteroGraph::CreateFromSharedMem(const std::string& name) {
353
354
  bool exist = SharedMemory::Exist(name);
  if (!exist) {
355
356
    return std::make_tuple(
        nullptr, std::vector<std::string>(), std::vector<std::string>());
357
  }
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
  auto mem = std::make_shared<SharedMemory>(name);
  auto mem_buf = mem->Open(SHARED_MEM_METAINFO_SIZE_MAX);
  dmlc::MemoryFixedSizeStream strm(mem_buf, SHARED_MEM_METAINFO_SIZE_MAX);
  SharedMemManager shm(name, &strm);

  uint8_t nbits;
  CHECK(shm.Read(&nbits)) << "invalid nbits (unit8_t)";

  bool has_coo, has_csr, has_csc;
  CHECK(shm.Read(&has_coo)) << "invalid nbits (unit8_t)";
  CHECK(shm.Read(&has_csr)) << "invalid csr (unit8_t)";
  CHECK(shm.Read(&has_csc)) << "invalid csc (unit8_t)";

  auto meta_imgraph = Serializer::make_shared<ImmutableGraph>();
  CHECK(shm.Read(&meta_imgraph)) << "Invalid meta graph";
  GraphPtr metagraph = meta_imgraph;

  std::vector<int64_t> num_verts_per_type;
  CHECK(shm.Read(&num_verts_per_type)) << "Invalid number of vertices per type";

  std::vector<HeteroGraphPtr> relgraphs(metagraph->NumEdges());
379
  for (dgl_type_t etype = 0; etype < metagraph->NumEdges(); ++etype) {
380
381
    auto src_dst = metagraph->FindEdge(etype);
    int num_vtypes = (src_dst.first == src_dst.second) ? 1 : 2;
382
383
384
385
386
387
388
389
390
391
392
393
394
    aten::COOMatrix coo;
    aten::CSRMatrix csr, csc;
    std::string prefix = name + "_" + std::to_string(etype);
    if (has_coo) {
      shm.CreateFromSharedMem(&coo, prefix + "_coo");
    }
    if (has_csr) {
      shm.CreateFromSharedMem(&csr, prefix + "_csr");
    }
    if (has_csc) {
      shm.CreateFromSharedMem(&csc, prefix + "_csc");
    }

395
396
    relgraphs[etype] = UnitGraph::CreateUnitGraphFrom(
        num_vtypes, csc, csr, coo, has_csc, has_csr, has_coo);
397
398
  }

399
400
  auto ret =
      std::make_shared<HeteroGraph>(metagraph, relgraphs, num_verts_per_type);
401
402
403
404
405
406
407
408
409
  ret->shared_mem_ = mem;

  std::vector<std::string> ntypes;
  std::vector<std::string> etypes;
  CHECK(shm.Read(&ntypes)) << "invalid ntypes";
  CHECK(shm.Read(&etypes)) << "invalid etypes";
  return std::make_tuple(ret, ntypes, etypes);
}

410
/**
 * @brief Rebuild the heterograph with every relation graph passed through
 *        UnitGraph::GetGraphInFormat.
 * @param formats The format code forwarded to each relation graph.
 * @return A new HeteroGraph with the same meta graph and vertex counts.
 */
HeteroGraphPtr HeteroGraph::GetGraphInFormat(dgl_format_code_t formats) const {
  std::vector<HeteroGraphPtr> format_rels(NumEdgeTypes());
  for (dgl_type_t etype = 0; etype < NumEdgeTypes(); ++etype) {
    // Fail with a check message instead of dereferencing a null pointer if
    // the relation graph is not a UnitGraph.
    auto relgraph = CHECK_NOTNULL(
        std::dynamic_pointer_cast<UnitGraph>(GetRelationGraph(etype)));
    format_rels[etype] = relgraph->GetGraphInFormat(formats);
  }
  return HeteroGraphPtr(
      new HeteroGraph(meta_graph_, format_rels, NumVerticesPerType()));
}

421
422
423
424
425
/**
 * @brief Flatten the given edge types into a single-edge-type graph.
 * @param etypes The edge types to flatten.
 * @return The result of FlattenImpl<int32_t> when NumBits() is 32, otherwise
 *         of FlattenImpl<int64_t>.
 */
FlattenedHeteroGraphPtr HeteroGraph::Flatten(
    const std::vector<dgl_type_t>& etypes) const {
  if (NumBits() == 32) {
    return FlattenImpl<int32_t>(etypes);
  }
  return FlattenImpl<int64_t>(etypes);
}

template <class IdType>
432
433
FlattenedHeteroGraphPtr HeteroGraph::FlattenImpl(
    const std::vector<dgl_type_t>& etypes) const {
Minjie Wang's avatar
Minjie Wang committed
434
435
  std::unordered_map<dgl_type_t, size_t> srctype_offsets, dsttype_offsets;
  size_t src_nodes = 0, dst_nodes = 0;
436
437
  std::vector<dgl_type_t> induced_srctype, induced_dsttype;
  std::vector<IdType> induced_srcid, induced_dstid;
Minjie Wang's avatar
Minjie Wang committed
438
439
  std::vector<dgl_type_t> srctype_set, dsttype_set;

440
441
  // XXXtype_offsets contain the mapping from node type and number of nodes
  // after this loop.
Minjie Wang's avatar
Minjie Wang committed
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
  for (dgl_type_t etype : etypes) {
    auto src_dsttype = meta_graph_->FindEdge(etype);
    dgl_type_t srctype = src_dsttype.first;
    dgl_type_t dsttype = src_dsttype.second;
    size_t num_srctype_nodes = NumVertices(srctype);
    size_t num_dsttype_nodes = NumVertices(dsttype);

    if (srctype_offsets.count(srctype) == 0) {
      srctype_offsets[srctype] = num_srctype_nodes;
      srctype_set.push_back(srctype);
    }
    if (dsttype_offsets.count(dsttype) == 0) {
      dsttype_offsets[dsttype] = num_dsttype_nodes;
      dsttype_set.push_back(dsttype);
    }
  }
458
459
  // Sort the node types so that we can compare the sets and decide whether a
  // homogeneous graph should be returned.
Minjie Wang's avatar
Minjie Wang committed
460
461
  std::sort(srctype_set.begin(), srctype_set.end());
  std::sort(dsttype_set.begin(), dsttype_set.end());
462
463
464
  bool homograph =
      (srctype_set.size() == dsttype_set.size()) &&
      std::equal(srctype_set.begin(), srctype_set.end(), dsttype_set.begin());
Minjie Wang's avatar
Minjie Wang committed
465

466
467
  // XXXtype_offsets contain the mapping from node type to node ID offsets after
  // these two loops.
Minjie Wang's avatar
Minjie Wang committed
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
  for (size_t i = 0; i < srctype_set.size(); ++i) {
    dgl_type_t ntype = srctype_set[i];
    size_t num_nodes = srctype_offsets[ntype];
    srctype_offsets[ntype] = src_nodes;
    src_nodes += num_nodes;
    for (size_t j = 0; j < num_nodes; ++j) {
      induced_srctype.push_back(ntype);
      induced_srcid.push_back(j);
    }
  }
  for (size_t i = 0; i < dsttype_set.size(); ++i) {
    dgl_type_t ntype = dsttype_set[i];
    size_t num_nodes = dsttype_offsets[ntype];
    dsttype_offsets[ntype] = dst_nodes;
    dst_nodes += num_nodes;
    for (size_t j = 0; j < num_nodes; ++j) {
      induced_dsttype.push_back(ntype);
      induced_dstid.push_back(j);
    }
  }
488

489
490
491
492
493
494
495
  // TODO(minjie): Using concat operations cause many fragmented memory.
  //   Need to optimize it in the future.
  std::vector<IdArray> src_arrs, dst_arrs, eid_arrs, induced_etypes;
  src_arrs.reserve(etypes.size());
  dst_arrs.reserve(etypes.size());
  eid_arrs.reserve(etypes.size());
  induced_etypes.reserve(etypes.size());
Minjie Wang's avatar
Minjie Wang committed
496
497
498
499
500
501
502
503
504
  for (dgl_type_t etype : etypes) {
    auto src_dsttype = meta_graph_->FindEdge(etype);
    dgl_type_t srctype = src_dsttype.first;
    dgl_type_t dsttype = src_dsttype.second;
    size_t srctype_offset = srctype_offsets[srctype];
    size_t dsttype_offset = dsttype_offsets[dsttype];

    EdgeArray edges = Edges(etype);
    size_t num_edges = NumEdges(etype);
505
506
507
    src_arrs.push_back(edges.src + srctype_offset);
    dst_arrs.push_back(edges.dst + dsttype_offset);
    eid_arrs.push_back(edges.id);
508
509
    induced_etypes.push_back(
        aten::Full(etype, num_edges, NumBits(), Context()));
Minjie Wang's avatar
Minjie Wang committed
510
511
512
  }

  HeteroGraphPtr gptr = UnitGraph::CreateFromCOO(
513
      homograph ? 1 : 2, src_nodes, dst_nodes, aten::Concat(src_arrs),
514
515
516
517
518
      aten::Concat(dst_arrs));

  // Sanity check
  CHECK_EQ(gptr->Context(), Context());
  CHECK_EQ(gptr->NumBits(), NumBits());
Minjie Wang's avatar
Minjie Wang committed
519
520

  FlattenedHeteroGraph* result = new FlattenedHeteroGraph;
521
522
523
524
525
526
  result->graph = HeteroGraphRef(
      HeteroGraphPtr(new HeteroGraph(gptr->meta_graph(), {gptr})));
  result->induced_srctype =
      aten::VecToIdArray(induced_srctype).CopyTo(Context());
  result->induced_srctype_set =
      aten::VecToIdArray(srctype_set).CopyTo(Context());
527
528
529
530
  result->induced_srcid = aten::VecToIdArray(induced_srcid).CopyTo(Context());
  result->induced_etype = aten::Concat(induced_etypes);
  result->induced_etype_set = aten::VecToIdArray(etypes).CopyTo(Context());
  result->induced_eid = aten::Concat(eid_arrs);
531
532
533
534
  result->induced_dsttype =
      aten::VecToIdArray(induced_dsttype).CopyTo(Context());
  result->induced_dsttype_set =
      aten::VecToIdArray(dsttype_set).CopyTo(Context());
535
  result->induced_dstid = aten::VecToIdArray(induced_dstid).CopyTo(Context());
Minjie Wang's avatar
Minjie Wang committed
536
  return FlattenedHeteroGraphPtr(result);
537
538
}

539
540
541
542
543
544
/** @brief Magic number written at the start of a serialized HeteroGraph. */
constexpr uint64_t kDGLSerialize_HeteroGraph = 0xDD589FBE35224ABF;

/**
 * @brief Load the graph from a stream written by HeteroGraph::Save.
 *
 * Reads, in order: the magic number, the meta graph, the relation graphs and
 * the per-type vertex counts. Every read is guarded by a CHECK.
 *
 * @param fs The input stream.
 * @return Always true; a failed read trips a CHECK instead.
 */
bool HeteroGraph::Load(dmlc::Stream* fs) {
  uint64_t magicNum = 0;
  CHECK(fs->Read(&magicNum)) << "Invalid Magic Number";
  CHECK_EQ(magicNum, kDGLSerialize_HeteroGraph) << "Invalid HeteroGraph Data";

  // The meta graph is stored as an ImmutableGraph.
  auto meta_imgraph = Serializer::make_shared<ImmutableGraph>();
  CHECK(fs->Read(&meta_imgraph)) << "Invalid meta graph";
  meta_graph_ = meta_imgraph;

  CHECK(fs->Read(&relation_graphs_)) << "Invalid relation_graphs_";
  CHECK(fs->Read(&num_verts_per_type_)) << "Invalid num_verts_per_type_";
  return true;
}

/**
 * @brief Serialize the graph to a stream.
 *
 * Writes, in order: the magic number, the meta graph converted with
 * ImmutableGraph::ToImmutable, the relation graphs and the per-type vertex
 * counts.
 *
 * @param fs The output stream.
 */
void HeteroGraph::Save(dmlc::Stream* fs) const {
  fs->Write(kDGLSerialize_HeteroGraph);
  fs->Write(ImmutableGraph::ToImmutable(meta_graph()));
  fs->Write(relation_graphs_);
  fs->Write(num_verts_per_type_);
}

561
562
563
/**
 * @brief View a graph with one node type and one edge type as an immutable
 *        graph.
 * @return The result of AsImmutableGraph() on relation graph 0, which must be
 *         a UnitGraph.
 */
GraphPtr HeteroGraph::AsImmutableGraph() const {
  CHECK(NumVertexTypes() == 1) << "graph has more than one node types";
  CHECK(NumEdgeTypes() == 1) << "graph has more than one edge types";
  const auto only_rel =
      CHECK_NOTNULL(std::dynamic_pointer_cast<UnitGraph>(GetRelationGraph(0)));
  return only_rel->AsImmutableGraph();
}

569
/**
 * @brief Build the line graph of a graph with one node type and one edge type.
 * @param backtracking Forwarded to the relation graph's LineGraph().
 * @return A new HeteroGraph that wraps the relation graph's line graph and
 *         reuses this graph's meta graph.
 */
HeteroGraphPtr HeteroGraph::LineGraph(bool backtracking) const {
  CHECK_EQ(1, meta_graph_->NumEdges())
      << "Only support Homogeneous graph now (one edge type)";
  CHECK_EQ(1, meta_graph_->NumVertices())
      << "Only support Homogeneous graph now (one node type)";
  CHECK_EQ(1, relation_graphs_.size()) << "Only support Homogeneous graph now";

  const UnitGraphPtr& only_rel = relation_graphs_[0];
  const auto& line_graph = only_rel->LineGraph(backtracking);

  std::vector<HeteroGraphPtr> rel_graph;
  rel_graph.push_back(line_graph);
  // The vertex count of the result is the line graph's count for type 0.
  std::vector<int64_t> num_nodes_per_type(
      1, static_cast<int64_t>(line_graph->NumVertices(0)));
  return HeteroGraphPtr(
      new HeteroGraph(meta_graph_, rel_graph, std::move(num_nodes_per_type)));
}

585
}  // namespace dgl