"vscode:/vscode.git/clone" did not exist on "dd7c9ebeaf09cf981fe3688bc454c197cb4eb635"
heterograph.cc 22.6 KB
Newer Older
/**
 *  Copyright (c) 2019 by Contributors
 * @file graph/heterograph.cc
 * @brief Heterograph implementation
 */
#include "./heterograph.h"
7

Minjie Wang's avatar
Minjie Wang committed
8
#include <dgl/array.h>
9
#include <dgl/graph_serializer.h>
10
#include <dgl/immutable_graph.h>
11
#include <dmlc/memory_io.h>
12

13
#include <memory>
14
15
#include <tuple>
#include <utility>
16
#include <vector>
17
18
19
20
21
22

using namespace dgl::runtime;

namespace dgl {
namespace {

23
24
using dgl::ImmutableGraph;

25
26
27
/**
 * @brief Build an edge-induced subgraph that keeps the parent's vertex counts.
 *
 * Every relation graph is sliced on its own with the edge IDs given for its
 * edge type; the resulting heterograph is created with the parent's
 * NumVerticesPerType().
 *
 * @param hg The parent heterograph.
 * @param eids Edge ID arrays, one per edge type (indexed by edge type).
 * @return The subgraph; induced_edges is a copy of @p eids and
 *         induced_vertices is taken from the per-relation subgraphs.
 */
HeteroSubgraph EdgeSubgraphPreserveNodes(
    const HeteroGraph* hg, const std::vector<IdArray>& eids) {
  CHECK_EQ(eids.size(), hg->NumEdgeTypes())
      << "Invalid input: the input list size must be the same as the number of "
         "edge type.";
  const auto num_etypes = hg->NumEdgeTypes();

  HeteroSubgraph result;
  result.induced_vertices.resize(hg->NumVertexTypes());
  result.induced_edges = eids;

  // With preserve_nodes == true it is enough to compute the edge subgraph of
  // each bipartite relation graph independently.
  std::vector<HeteroGraphPtr> rel_subgraphs(num_etypes);
  for (dgl_type_t et = 0; et < num_etypes; ++et) {
    const auto endpoints = hg->meta_graph()->FindEdge(et);
    const auto rel_sg =
        hg->GetRelationGraph(et)->EdgeSubgraph({eids[et]}, true);
    rel_subgraphs[et] = rel_sg.graph;
    // Slot 0 / slot 1 of the relation subgraph go to the source / destination
    // vertex type of this edge type, in that order.
    result.induced_vertices[endpoints.first] = rel_sg.induced_vertices[0];
    result.induced_vertices[endpoints.second] = rel_sg.induced_vertices[1];
  }

  result.graph = HeteroGraphPtr(new HeteroGraph(
      hg->meta_graph(), rel_subgraphs, hg->NumVerticesPerType()));
  return result;
}

/**
 * @brief Build an edge-induced subgraph whose vertices are relabeled per
 *        vertex type.
 *
 * @param hg The parent heterograph.
 * @param eids Edge ID arrays, one per edge type (indexed by edge type).
 * @return The subgraph; induced_edges is a copy of @p eids and
 *         induced_vertices[vtype] is the array returned by aten::Relabel_
 *         for that vertex type.
 */
HeteroSubgraph EdgeSubgraphNoPreserveNodes(
    const HeteroGraph* hg, const std::vector<IdArray>& eids) {
  // TODO(minjie): In general, all relabeling should be separated with subgraph
  //   operations.
  CHECK_EQ(eids.size(), hg->NumEdgeTypes())
      << "Invalid input: the input list size must be the same as the number of "
         "edge type.";
  const auto num_etypes = hg->NumEdgeTypes();
  const auto num_vtypes = hg->NumVertexTypes();

  HeteroSubgraph result;
  result.induced_vertices.resize(num_vtypes);
  result.induced_edges = eids;

  // NOTE(minjie): Without preserve_nodes the edge subgraph of a heterograph is
  // more involved, because all bipartite graphs incident on the same vertex
  // type must share one ID space. Example:
  //
  // Meta graph: A -> B -> C
  // UnitGraph graphs:
  // * A -> B: (0, 0), (0, 1)
  // * B -> C: (1, 0), (1, 1)
  //
  // Keep only (0, 0) of A->B and only (1, 0) of B->C. Node type B must still
  // have two nodes in the result, so we cannot just take the edge subgraph of
  // B->C on its own: that would relabel node #1 of type B to node #0.
  //
  // The implementation therefore proceeds as follows:
  // (1) For each bipartite graph, slice out the edges using the given eids.
  // (2) Collect, per vertex type, the incident node arrays of every bipartite
  //     graph that has this type as its source or destination type.
  // (3) For each vertex type, run aten::Relabel_ on its collected arrays.
  //     aten::Relabel_ computes the union of the vertex sets and relabels the
  //     unique elements from zero; the returned mapping array is the final
  //     induced vertex set of that vertex type.
  // (4) Build each bipartite graph from the relabeled edges.

  // step (1) & (2)
  std::vector<EdgeArray> picked_edges(num_etypes);
  std::vector<std::vector<IdArray>> incident_nodes(num_vtypes);
  for (dgl_type_t et = 0; et < num_etypes; ++et) {
    const auto endpoints = hg->meta_graph()->FindEdge(et);
    picked_edges[et] = hg->GetRelationGraph(et)->FindEdges(0, eids[et]);
    incident_nodes[endpoints.first].push_back(picked_edges[et].src);
    incident_nodes[endpoints.second].push_back(picked_edges[et].dst);
  }

  // step (3)
  std::vector<int64_t> num_vertices_per_type(num_vtypes);
  for (dgl_type_t vt = 0; vt < num_vtypes; ++vt) {
    result.induced_vertices[vt] = aten::Relabel_(incident_nodes[vt]);
    num_vertices_per_type[vt] = result.induced_vertices[vt]->shape[0];
  }

  // step (4)
  std::vector<HeteroGraphPtr> rel_subgraphs(num_etypes);
  for (dgl_type_t et = 0; et < num_etypes; ++et) {
    const auto endpoints = hg->meta_graph()->FindEdge(et);
    const dgl_type_t src_vt = endpoints.first;
    const dgl_type_t dst_vt = endpoints.second;
    // A relation whose endpoints share a vertex type has one vertex type.
    const int rel_num_vtypes = (src_vt == dst_vt) ? 1 : 2;
    rel_subgraphs[et] = UnitGraph::CreateFromCOO(
        rel_num_vtypes, num_vertices_per_type[src_vt],
        num_vertices_per_type[dst_vt], picked_edges[et].src,
        picked_edges[et].dst);
  }

  result.graph = HeteroGraphPtr(new HeteroGraph(
      hg->meta_graph(), rel_subgraphs, std::move(num_vertices_per_type)));
  return result;
}

121
122
/**
 * @brief Validate the inputs used to construct a heterograph.
 *
 * Fails a CHECK when the number of relation graphs differs from the number of
 * meta-graph edges, when no relation graph is given, when a relation graph has
 * more than one edge type, or when the relation graphs do not all report the
 * same context.
 *
 * @param meta_graph The meta graph.
 * @param rel_graphs The relation graphs.
 */
void HeteroGraphSanityCheck(
    GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs) {
  CHECK_EQ(meta_graph->NumEdges(), rel_graphs.size());
  CHECK(!rel_graphs.empty()) << "Empty heterograph is not allowed.";

  const size_t num_rels = rel_graphs.size();

  // Every relation graph must carry exactly one edge type.
  for (size_t i = 0; i < num_rels; ++i) {
    const HeteroGraphPtr& rg = rel_graphs[i];
    CHECK_EQ(rg->NumEdgeTypes(), 1)
        << "Each relation graph must have only one edge type.";
  }

  // The context of the first relation graph is the reference for all others.
  const auto ctx = rel_graphs.front()->Context();
  for (size_t i = 0; i < num_rels; ++i) {
    const HeteroGraphPtr& rg = rel_graphs[i];
    CHECK_EQ(rg->Context(), ctx)
        << "Each relation graph must have the same context.";
  }
}

138
139
/**
 * @brief Derive the number of vertices of each vertex type from the relation
 *        graphs.
 *
 * For every meta-graph edge, the vertex counts of the matching relation graph
 * are recorded for its source and destination vertex type. A CHECK fails if
 * two relation graphs disagree on the count of a shared vertex type, or if an
 * edge type has no corresponding entry in @p rel_graphs.
 *
 * @param meta_graph The meta graph.
 * @param rel_graphs The relation graphs, indexed by edge type.
 * @return One count per meta-graph vertex; entries of vertex types that no
 *         meta-graph edge touches stay at -1.
 */
std::vector<int64_t> InferNumVerticesPerType(
    GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs) {
  // -1 marks "not seen yet".
  std::vector<int64_t> num_verts_per_type(meta_graph->NumVertices(), -1);

  // Record a count for a vertex type, or verify it against the recorded one.
  const auto record = [&num_verts_per_type](dgl_type_t vtype, int64_t nv) {
    if (num_verts_per_type[vtype] < 0) {
      num_verts_per_type[vtype] = nv;
    } else {
      CHECK_EQ(num_verts_per_type[vtype], nv)
          << "Mismatch number of vertices for vertex type " << vtype;
    }
  };

  EdgeArray etype_array = meta_graph->Edges();
  const dgl_type_t* srctypes =
      static_cast<const dgl_type_t*>(etype_array.src->data);
  const dgl_type_t* dsttypes =
      static_cast<const dgl_type_t*>(etype_array.dst->data);
  const dgl_type_t* etypes =
      static_cast<const dgl_type_t*>(etype_array.id->data);

  const size_t num_meta_edges = meta_graph->NumEdges();
  for (size_t i = 0; i < num_meta_edges; ++i) {
    const dgl_type_t etype = etypes[i];
    // Guard the lookup: callers may pass fewer relation graphs than there are
    // meta-graph edges.
    CHECK_LT(etype, rel_graphs.size())
        << "Missing relation graph for edge type " << etype;
    const auto& rg = rel_graphs[etype];

    // A relation graph with a single vertex type uses slot 0 for both ends.
    const dgl_type_t sty = 0;
    const dgl_type_t dty = rg->NumVertexTypes() == 1 ? 0 : 1;

    // # nodes of source type
    record(srctypes[i], static_cast<int64_t>(rg->NumVertices(sty)));
    // # nodes of destination type
    record(dsttypes[i], static_cast<int64_t>(rg->NumVertices(dty)));
  }
  return num_verts_per_type;
}
173

174
175
std::vector<UnitGraphPtr> CastToUnitGraphs(
    const std::vector<HeteroGraphPtr>& rel_graphs) {
176
  std::vector<UnitGraphPtr> relation_graphs(rel_graphs.size());
177
178
179
  for (size_t i = 0; i < rel_graphs.size(); ++i) {
    HeteroGraphPtr relg = rel_graphs[i];
    if (std::dynamic_pointer_cast<UnitGraph>(relg)) {
180
      relation_graphs[i] = std::dynamic_pointer_cast<UnitGraph>(relg);
181
    } else {
182
      relation_graphs[i] = CHECK_NOTNULL(
183
184
185
          std::dynamic_pointer_cast<UnitGraph>(relg->GetRelationGraph(0)));
    }
  }
186
187
188
189
190
191
  return relation_graphs;
}

}  // namespace

/**
 * @brief Construct a heterograph from a meta graph and its relation graphs.
 *
 * The inputs are validated before anything reads them, so a relation-graph
 * list that does not match the meta graph is rejected by the sanity check
 * instead of being indexed during vertex-count inference.
 *
 * @param meta_graph The meta graph.
 * @param rel_graphs The relation graphs; must pass HeteroGraphSanityCheck.
 * @param num_nodes_per_type Number of vertices per vertex type. If empty, the
 *        counts are inferred from @p rel_graphs.
 */
HeteroGraph::HeteroGraph(
    GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs,
    const std::vector<int64_t>& num_nodes_per_type)
    : BaseHeteroGraph(meta_graph) {
  HeteroGraphSanityCheck(meta_graph, rel_graphs);
  if (num_nodes_per_type.empty()) {
    num_verts_per_type_ = InferNumVerticesPerType(meta_graph, rel_graphs);
  } else {
    num_verts_per_type_ = num_nodes_per_type;
  }
  relation_graphs_ = CastToUnitGraphs(rel_graphs);
}

bool HeteroGraph::IsMultigraph() const {
204
  for (const auto& hg : relation_graphs_) {
205
206
207
208
209
    if (hg->IsMultigraph()) {
      return true;
    }
  }
  return false;
210
211
212
}

/**
 * @brief Element-wise test of whether the given IDs are below the number of
 *        vertices of a vertex type.
 * @param vtype The vertex type.
 * @param vids The vertex IDs to test; must be a valid ID array.
 * @return The result of aten::LT(vids, NumVertices(vtype)).
 */
BoolArray HeteroGraph::HasVertices(dgl_type_t vtype, IdArray vids) const {
  CHECK(aten::IsValidIdArray(vids)) << "Invalid id array input";
  const auto num_nodes = NumVertices(vtype);
  return aten::LT(vids, num_nodes);
}

217
218
/**
 * @brief Build the subgraph induced by the given vertices.
 *
 * @param vids Vertex ID arrays, one per vertex type (indexed by vertex type).
 * @return The subgraph; induced_vertices is a copy of @p vids and
 *         induced_edges[etype] comes from the vertex subgraph of the
 *         corresponding relation graph.
 */
HeteroSubgraph HeteroGraph::VertexSubgraph(
    const std::vector<IdArray>& vids) const {
  CHECK_EQ(vids.size(), NumVertexTypes())
      << "Invalid input: the input list size must be the same as the number of "
         "vertex types.";
  HeteroSubgraph result;
  result.induced_vertices = vids;

  // The subgraph has as many vertices of each type as IDs were given for it.
  std::vector<int64_t> num_vertices_per_type;
  num_vertices_per_type.reserve(vids.size());
  for (const IdArray& ids : vids) {
    num_vertices_per_type.push_back(ids->shape[0]);
  }

  const auto num_etypes = NumEdgeTypes();
  result.induced_edges.resize(num_etypes);
  std::vector<HeteroGraphPtr> rel_subgraphs(num_etypes);
  for (dgl_type_t et = 0; et < num_etypes; ++et) {
    const auto endpoints = meta_graph_->FindEdge(et);
    const dgl_type_t src_vt = endpoints.first;
    const dgl_type_t dst_vt = endpoints.second;

    // A relation whose endpoints share a vertex type takes a single ID array.
    std::vector<IdArray> rel_vids;
    rel_vids.push_back(vids[src_vt]);
    if (src_vt != dst_vt) {
      rel_vids.push_back(vids[dst_vt]);
    }

    const auto rel_sg = GetRelationGraph(et)->VertexSubgraph(rel_vids);
    rel_subgraphs[et] = rel_sg.graph;
    result.induced_edges[et] = rel_sg.induced_edges[0];
  }

  result.graph = HeteroGraphPtr(new HeteroGraph(
      meta_graph_, rel_subgraphs, std::move(num_vertices_per_type)));
  return result;
}

/**
 * @brief Build the subgraph induced by the given edges.
 * @param eids Edge ID arrays, one per edge type (indexed by edge type).
 * @param preserve_nodes Selects between EdgeSubgraphPreserveNodes (true) and
 *        EdgeSubgraphNoPreserveNodes (false).
 * @return The subgraph produced by the selected helper.
 */
HeteroSubgraph HeteroGraph::EdgeSubgraph(
    const std::vector<IdArray>& eids, bool preserve_nodes) const {
  return preserve_nodes ? EdgeSubgraphPreserveNodes(this, eids)
                        : EdgeSubgraphNoPreserveNodes(this, eids);
}

255
256
257
258
259
260
261
/**
 * @brief Create a heterograph whose relation graphs are each passed through
 *        UnitGraph::AsNumBits.
 * @param g The input graph; must be a HeteroGraph (CHECK otherwise).
 * @param bits Forwarded to UnitGraph::AsNumBits for every relation graph.
 * @return A new heterograph with the same meta graph and vertex counts.
 */
HeteroGraphPtr HeteroGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) {
  const auto hgindex = std::dynamic_pointer_cast<HeteroGraph>(g);
  CHECK_NOTNULL(hgindex);

  std::vector<HeteroGraphPtr> converted;
  converted.reserve(hgindex->relation_graphs_.size());
  for (const auto& rel : hgindex->relation_graphs_) {
    converted.push_back(UnitGraph::AsNumBits(rel, bits));
  }
  return HeteroGraphPtr(new HeteroGraph(
      hgindex->meta_graph_, converted, hgindex->num_verts_per_type_));
}

266
/**
 * @brief Copy a heterograph to the given context.
 * @param g The input graph; must be a HeteroGraph unless it already reports
 *        @p ctx as its context.
 * @param ctx The target context.
 * @return @p g itself when its context already equals @p ctx; otherwise a new
 *         heterograph built from UnitGraph::CopyTo of each relation graph.
 */
HeteroGraphPtr HeteroGraph::CopyTo(HeteroGraphPtr g, const DGLContext& ctx) {
  // Nothing to do when the graph already lives in the requested context.
  if (ctx == g->Context()) return g;

  const auto hgindex = std::dynamic_pointer_cast<HeteroGraph>(g);
  CHECK_NOTNULL(hgindex);

  std::vector<HeteroGraphPtr> copied;
  copied.reserve(hgindex->relation_graphs_.size());
  for (const auto& rel : hgindex->relation_graphs_) {
    copied.push_back(UnitGraph::CopyTo(rel, ctx));
  }
  return HeteroGraphPtr(new HeteroGraph(
      hgindex->meta_graph_, copied, hgindex->num_verts_per_type_));
}

280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
/**
 * @brief Return a heterograph whose relation graphs are all pinned.
 * @param g The input graph; must be a HeteroGraph (CHECK otherwise).
 * @return @p g itself when every relation graph already reports IsPinned();
 *         otherwise a new heterograph in which each unpinned relation graph
 *         is replaced by the result of its PinMemory().
 */
HeteroGraphPtr HeteroGraph::PinMemory(HeteroGraphPtr g) {
  const auto casted_ptr = std::dynamic_pointer_cast<HeteroGraph>(g);
  CHECK_NOTNULL(casted_ptr);
  const auto& rels = casted_ptr->relation_graphs_;

  // All underlying relation graphs are pinned: hand back the input directly.
  const bool all_pinned = std::all_of(
      rels.begin(), rels.end(),
      [](const auto& underlying_g) { return underlying_g->IsPinned(); });
  if (all_pinned) return g;

  std::vector<HeteroGraphPtr> pinned_rels;
  pinned_rels.reserve(rels.size());
  for (const auto& rel : rels) {
    if (rel->IsPinned()) {
      pinned_rels.push_back(rel);
    } else {
      pinned_rels.push_back(rel->PinMemory());
    }
  }
  return HeteroGraphPtr(new HeteroGraph(
      casted_ptr->meta_graph_, pinned_rels, casted_ptr->num_verts_per_type_));
}

305
void HeteroGraph::PinMemory_() {
306
  for (auto g : relation_graphs_) g->PinMemory_();
307
308
309
}

void HeteroGraph::UnpinMemory_() {
310
  for (auto g : relation_graphs_) g->UnpinMemory_();
311
312
}

313
/**
 * @brief Call RecordStream(stream) on every relation graph.
 * @param stream The stream handle forwarded to each relation graph.
 */
void HeteroGraph::RecordStream(DGLStreamHandle stream) {
  for (const auto& rel : relation_graphs_) {
    rel->RecordStream(stream);
  }
}

317
318
319
320
321
/**
 * @brief Name of the shared memory attached to this graph.
 * @return shared_mem_->GetName(), or an empty string when shared_mem_ is not
 *         set.
 */
std::string HeteroGraph::SharedMemName() const {
  if (!shared_mem_) {
    return std::string();
  }
  return shared_mem_->GetName();
}

/**
 * @brief Copy a heterograph into shared memory under the given name.
 *
 * The meta information is written in this order: number of bits, the
 * coo/csr/csc flags, the meta graph, the vertex counts, the per-edge-type
 * matrices (as handled by SharedMemManager::CopyToSharedMem), then the node
 * and edge type names.
 *
 * @param g The input graph; must be a HeteroGraph (CHECK otherwise).
 * @param name The shared memory name; also the prefix of per-matrix names.
 * @param ntypes Node type names to store.
 * @param etypes Edge type names to store.
 * @param fmts Formats to copy; "coo", "csr" and "csc" are looked up.
 * @return @p g itself when its SharedMemName() already equals @p name;
 *         otherwise a new heterograph built from the copied matrices.
 */
HeteroGraphPtr HeteroGraph::CopyToSharedMem(
    HeteroGraphPtr g, const std::string& name,
    const std::vector<std::string>& ntypes,
    const std::vector<std::string>& etypes, const std::set<std::string>& fmts) {
  // TODO(JJ): Raise error when calling shared_memory if graph index is on gpu
  auto hg = std::dynamic_pointer_cast<HeteroGraph>(g);
  CHECK_NOTNULL(hg);
  if (hg->SharedMemName() == name) return g;

  // Copy buffer to share memory
  auto mem = std::make_shared<SharedMemory>(name);
  auto mem_buf = mem->CreateNew(SHARED_MEM_METAINFO_SIZE_MAX);
  dmlc::MemoryFixedSizeStream strm(mem_buf, SHARED_MEM_METAINFO_SIZE_MAX);
  SharedMemManager shm(name, &strm);

  const bool has_coo = fmts.count("coo") != 0;
  const bool has_csr = fmts.count("csr") != 0;
  const bool has_csc = fmts.count("csc") != 0;

  // Header of the meta information.
  shm.Write(g->NumBits());
  shm.Write(has_coo);
  shm.Write(has_csr);
  shm.Write(has_csc);
  shm.Write(ImmutableGraph::ToImmutable(hg->meta_graph_));
  shm.Write(hg->num_verts_per_type_);

  const auto num_etypes = g->NumEdgeTypes();
  std::vector<HeteroGraphPtr> relgraphs(num_etypes);
  for (dgl_type_t etype = 0; etype < num_etypes; ++etype) {
    const auto endpoint_types = g->GetEndpointTypes(etype);
    const int num_vtypes =
        (endpoint_types.first == endpoint_types.second) ? 1 : 2;

    // Matrices of formats that were not requested stay default-constructed.
    aten::COOMatrix coo;
    aten::CSRMatrix csr, csc;
    const std::string key_prefix = name + "_" + std::to_string(etype);
    if (has_coo) {
      coo = shm.CopyToSharedMem(hg->GetCOOMatrix(etype), key_prefix + "_coo");
    }
    if (has_csr) {
      csr = shm.CopyToSharedMem(hg->GetCSRMatrix(etype), key_prefix + "_csr");
    }
    if (has_csc) {
      csc = shm.CopyToSharedMem(hg->GetCSCMatrix(etype), key_prefix + "_csc");
    }
    relgraphs[etype] = UnitGraph::CreateUnitGraphFrom(
        num_vtypes, csc, csr, coo, has_csc, has_csr, has_coo);
  }

  auto ret = std::make_shared<HeteroGraph>(
      hg->meta_graph_, relgraphs, hg->num_verts_per_type_);
  ret->shared_mem_ = mem;

  // The type names go last in the meta information.
  shm.Write(ntypes);
  shm.Write(etypes);
  return ret;
}

std::tuple<HeteroGraphPtr, std::vector<std::string>, std::vector<std::string>>
377
HeteroGraph::CreateFromSharedMem(const std::string& name) {
378
379
  bool exist = SharedMemory::Exist(name);
  if (!exist) {
380
381
    return std::make_tuple(
        nullptr, std::vector<std::string>(), std::vector<std::string>());
382
  }
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
  auto mem = std::make_shared<SharedMemory>(name);
  auto mem_buf = mem->Open(SHARED_MEM_METAINFO_SIZE_MAX);
  dmlc::MemoryFixedSizeStream strm(mem_buf, SHARED_MEM_METAINFO_SIZE_MAX);
  SharedMemManager shm(name, &strm);

  uint8_t nbits;
  CHECK(shm.Read(&nbits)) << "invalid nbits (unit8_t)";

  bool has_coo, has_csr, has_csc;
  CHECK(shm.Read(&has_coo)) << "invalid nbits (unit8_t)";
  CHECK(shm.Read(&has_csr)) << "invalid csr (unit8_t)";
  CHECK(shm.Read(&has_csc)) << "invalid csc (unit8_t)";

  auto meta_imgraph = Serializer::make_shared<ImmutableGraph>();
  CHECK(shm.Read(&meta_imgraph)) << "Invalid meta graph";
  GraphPtr metagraph = meta_imgraph;

  std::vector<int64_t> num_verts_per_type;
  CHECK(shm.Read(&num_verts_per_type)) << "Invalid number of vertices per type";

  std::vector<HeteroGraphPtr> relgraphs(metagraph->NumEdges());
404
  for (dgl_type_t etype = 0; etype < metagraph->NumEdges(); ++etype) {
405
406
    auto src_dst = metagraph->FindEdge(etype);
    int num_vtypes = (src_dst.first == src_dst.second) ? 1 : 2;
407
408
409
410
411
412
413
414
415
416
417
418
419
    aten::COOMatrix coo;
    aten::CSRMatrix csr, csc;
    std::string prefix = name + "_" + std::to_string(etype);
    if (has_coo) {
      shm.CreateFromSharedMem(&coo, prefix + "_coo");
    }
    if (has_csr) {
      shm.CreateFromSharedMem(&csr, prefix + "_csr");
    }
    if (has_csc) {
      shm.CreateFromSharedMem(&csc, prefix + "_csc");
    }

420
421
    relgraphs[etype] = UnitGraph::CreateUnitGraphFrom(
        num_vtypes, csc, csr, coo, has_csc, has_csr, has_coo);
422
423
  }

424
425
  auto ret =
      std::make_shared<HeteroGraph>(metagraph, relgraphs, num_verts_per_type);
426
427
428
429
430
431
432
433
434
  ret->shared_mem_ = mem;

  std::vector<std::string> ntypes;
  std::vector<std::string> etypes;
  CHECK(shm.Read(&ntypes)) << "invalid ntypes";
  CHECK(shm.Read(&etypes)) << "invalid etypes";
  return std::make_tuple(ret, ntypes, etypes);
}

435
/**
 * @brief Create a heterograph whose relation graphs are each obtained via
 *        UnitGraph::GetGraphInFormat.
 * @param formats The format code forwarded to every relation graph.
 * @return A new heterograph with the same meta graph and vertex counts.
 */
HeteroGraphPtr HeteroGraph::GetGraphInFormat(dgl_format_code_t formats) const {
  const auto num_etypes = NumEdgeTypes();
  std::vector<HeteroGraphPtr> format_rels(num_etypes);
  for (dgl_type_t etype = 0; etype < num_etypes; ++etype) {
    // Verify the downcast before dereferencing it.
    const auto relgraph = CHECK_NOTNULL(
        std::dynamic_pointer_cast<UnitGraph>(GetRelationGraph(etype)));
    format_rels[etype] = relgraph->GetGraphInFormat(formats);
  }
  return HeteroGraphPtr(
      new HeteroGraph(meta_graph_, format_rels, NumVerticesPerType()));
}

446
447
448
449
450
/**
 * @brief Flatten the relations of the given edge types into one graph.
 *
 * Dispatches to FlattenImpl<int32_t> when NumBits() is 32 and to
 * FlattenImpl<int64_t> otherwise.
 *
 * @param etypes The edge types to flatten.
 * @return The flattened graph produced by FlattenImpl.
 */
FlattenedHeteroGraphPtr HeteroGraph::Flatten(
    const std::vector<dgl_type_t>& etypes) const {
  return NumBits() == 32 ? FlattenImpl<int32_t>(etypes)
                         : FlattenImpl<int64_t>(etypes);
}

template <class IdType>
457
458
FlattenedHeteroGraphPtr HeteroGraph::FlattenImpl(
    const std::vector<dgl_type_t>& etypes) const {
Minjie Wang's avatar
Minjie Wang committed
459
460
  std::unordered_map<dgl_type_t, size_t> srctype_offsets, dsttype_offsets;
  size_t src_nodes = 0, dst_nodes = 0;
461
462
  std::vector<dgl_type_t> induced_srctype, induced_dsttype;
  std::vector<IdType> induced_srcid, induced_dstid;
Minjie Wang's avatar
Minjie Wang committed
463
464
  std::vector<dgl_type_t> srctype_set, dsttype_set;

465
466
  // XXXtype_offsets contain the mapping from node type and number of nodes
  // after this loop.
Minjie Wang's avatar
Minjie Wang committed
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
  for (dgl_type_t etype : etypes) {
    auto src_dsttype = meta_graph_->FindEdge(etype);
    dgl_type_t srctype = src_dsttype.first;
    dgl_type_t dsttype = src_dsttype.second;
    size_t num_srctype_nodes = NumVertices(srctype);
    size_t num_dsttype_nodes = NumVertices(dsttype);

    if (srctype_offsets.count(srctype) == 0) {
      srctype_offsets[srctype] = num_srctype_nodes;
      srctype_set.push_back(srctype);
    }
    if (dsttype_offsets.count(dsttype) == 0) {
      dsttype_offsets[dsttype] = num_dsttype_nodes;
      dsttype_set.push_back(dsttype);
    }
  }
483
484
  // Sort the node types so that we can compare the sets and decide whether a
  // homogeneous graph should be returned.
Minjie Wang's avatar
Minjie Wang committed
485
486
  std::sort(srctype_set.begin(), srctype_set.end());
  std::sort(dsttype_set.begin(), dsttype_set.end());
487
488
489
  bool homograph =
      (srctype_set.size() == dsttype_set.size()) &&
      std::equal(srctype_set.begin(), srctype_set.end(), dsttype_set.begin());
Minjie Wang's avatar
Minjie Wang committed
490

491
492
  // XXXtype_offsets contain the mapping from node type to node ID offsets after
  // these two loops.
Minjie Wang's avatar
Minjie Wang committed
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
  for (size_t i = 0; i < srctype_set.size(); ++i) {
    dgl_type_t ntype = srctype_set[i];
    size_t num_nodes = srctype_offsets[ntype];
    srctype_offsets[ntype] = src_nodes;
    src_nodes += num_nodes;
    for (size_t j = 0; j < num_nodes; ++j) {
      induced_srctype.push_back(ntype);
      induced_srcid.push_back(j);
    }
  }
  for (size_t i = 0; i < dsttype_set.size(); ++i) {
    dgl_type_t ntype = dsttype_set[i];
    size_t num_nodes = dsttype_offsets[ntype];
    dsttype_offsets[ntype] = dst_nodes;
    dst_nodes += num_nodes;
    for (size_t j = 0; j < num_nodes; ++j) {
      induced_dsttype.push_back(ntype);
      induced_dstid.push_back(j);
    }
  }
513

514
515
516
517
518
519
520
  // TODO(minjie): Using concat operations cause many fragmented memory.
  //   Need to optimize it in the future.
  std::vector<IdArray> src_arrs, dst_arrs, eid_arrs, induced_etypes;
  src_arrs.reserve(etypes.size());
  dst_arrs.reserve(etypes.size());
  eid_arrs.reserve(etypes.size());
  induced_etypes.reserve(etypes.size());
Minjie Wang's avatar
Minjie Wang committed
521
522
523
524
525
526
527
528
529
  for (dgl_type_t etype : etypes) {
    auto src_dsttype = meta_graph_->FindEdge(etype);
    dgl_type_t srctype = src_dsttype.first;
    dgl_type_t dsttype = src_dsttype.second;
    size_t srctype_offset = srctype_offsets[srctype];
    size_t dsttype_offset = dsttype_offsets[dsttype];

    EdgeArray edges = Edges(etype);
    size_t num_edges = NumEdges(etype);
530
531
532
    src_arrs.push_back(edges.src + srctype_offset);
    dst_arrs.push_back(edges.dst + dsttype_offset);
    eid_arrs.push_back(edges.id);
533
534
    induced_etypes.push_back(
        aten::Full(etype, num_edges, NumBits(), Context()));
Minjie Wang's avatar
Minjie Wang committed
535
536
537
  }

  HeteroGraphPtr gptr = UnitGraph::CreateFromCOO(
538
      homograph ? 1 : 2, src_nodes, dst_nodes, aten::Concat(src_arrs),
539
540
541
542
543
      aten::Concat(dst_arrs));

  // Sanity check
  CHECK_EQ(gptr->Context(), Context());
  CHECK_EQ(gptr->NumBits(), NumBits());
Minjie Wang's avatar
Minjie Wang committed
544
545

  FlattenedHeteroGraph* result = new FlattenedHeteroGraph;
546
547
548
549
550
551
  result->graph = HeteroGraphRef(
      HeteroGraphPtr(new HeteroGraph(gptr->meta_graph(), {gptr})));
  result->induced_srctype =
      aten::VecToIdArray(induced_srctype).CopyTo(Context());
  result->induced_srctype_set =
      aten::VecToIdArray(srctype_set).CopyTo(Context());
552
553
554
555
  result->induced_srcid = aten::VecToIdArray(induced_srcid).CopyTo(Context());
  result->induced_etype = aten::Concat(induced_etypes);
  result->induced_etype_set = aten::VecToIdArray(etypes).CopyTo(Context());
  result->induced_eid = aten::Concat(eid_arrs);
556
557
558
559
  result->induced_dsttype =
      aten::VecToIdArray(induced_dsttype).CopyTo(Context());
  result->induced_dsttype_set =
      aten::VecToIdArray(dsttype_set).CopyTo(Context());
560
  result->induced_dstid = aten::VecToIdArray(induced_dstid).CopyTo(Context());
Minjie Wang's avatar
Minjie Wang committed
561
  return FlattenedHeteroGraphPtr(result);
562
563
}

564
565
566
567
568
569
/** @brief Magic number written first by Save() and verified first by Load(). */
constexpr uint64_t kDGLSerialize_HeteroGraph = 0xDD589FBE35224ABF;

/**
 * @brief Load the heterograph from a stream.
 *
 * Reads, in order: the magic number, the meta graph, the relation graphs and
 * the vertex counts. Any failed read or a wrong magic number fails a CHECK.
 *
 * @param fs The stream to read from.
 * @return Always true; failures are reported through CHECK.
 */
bool HeteroGraph::Load(dmlc::Stream* fs) {
  // Verify the magic number before touching any member.
  uint64_t magicNum = 0;
  CHECK(fs->Read(&magicNum)) << "Invalid Magic Number";
  CHECK_EQ(magicNum, kDGLSerialize_HeteroGraph) << "Invalid HeteroGraph Data";

  // The meta graph is stored as an ImmutableGraph.
  auto meta_imgraph = Serializer::make_shared<ImmutableGraph>();
  CHECK(fs->Read(&meta_imgraph)) << "Invalid meta graph";
  meta_graph_ = meta_imgraph;

  CHECK(fs->Read(&relation_graphs_)) << "Invalid relation_graphs_";
  CHECK(fs->Read(&num_verts_per_type_)) << "Invalid num_verts_per_type_";
  return true;
}

/**
 * @brief Save the heterograph to a stream.
 *
 * Writes, in order: the magic number, the meta graph converted with
 * ImmutableGraph::ToImmutable, the relation graphs and the vertex counts.
 *
 * @param fs The stream to write to.
 */
void HeteroGraph::Save(dmlc::Stream* fs) const {
  fs->Write(kDGLSerialize_HeteroGraph);
  fs->Write(ImmutableGraph::ToImmutable(meta_graph()));
  fs->Write(relation_graphs_);
  fs->Write(num_verts_per_type_);
}

586
587
588
/**
 * @brief Convert a heterograph with one vertex type and one edge type into an
 *        immutable graph.
 * @return The result of AsImmutableGraph() of relation graph 0, which must be
 *         a UnitGraph (CHECK otherwise).
 */
GraphPtr HeteroGraph::AsImmutableGraph() const {
  CHECK(NumVertexTypes() == 1) << "graph has more than one node types";
  CHECK(NumEdgeTypes() == 1) << "graph has more than one edge types";
  const UnitGraphPtr only_rel =
      CHECK_NOTNULL(std::dynamic_pointer_cast<UnitGraph>(GetRelationGraph(0)));
  return only_rel->AsImmutableGraph();
}

594
/**
 * @brief Build the line graph of a graph with one vertex type and one edge
 *        type.
 * @param backtracking Forwarded to UnitGraph::LineGraph.
 * @return A heterograph that reuses this graph's meta graph and wraps the
 *         line graph of relation graph 0.
 */
HeteroGraphPtr HeteroGraph::LineGraph(bool backtracking) const {
  CHECK_EQ(1, meta_graph_->NumEdges())
      << "Only support Homogeneous graph now (one edge type)";
  CHECK_EQ(1, meta_graph_->NumVertices())
      << "Only support Homogeneous graph now (one node type)";
  CHECK_EQ(1, relation_graphs_.size()) << "Only support Homogeneous graph now";

  const UnitGraphPtr only_rel = relation_graphs_[0];
  const auto line_graph = only_rel->LineGraph(backtracking);

  std::vector<HeteroGraphPtr> line_rels;
  line_rels.push_back(line_graph);
  // The single vertex type of the result has as many vertices as the line
  // graph reports for vertex type 0.
  std::vector<int64_t> line_num_nodes(
      1, static_cast<int64_t>(line_graph->NumVertices(0)));
  return HeteroGraphPtr(
      new HeteroGraph(meta_graph_, line_rels, std::move(line_num_nodes)));
}

610
}  // namespace dgl