heterograph.cc 22.7 KB
Newer Older
sangwzh's avatar
sangwzh committed
1
// !!! This is a file automatically generated by hipify!!!
/**
 *  Copyright (c) 2019 by Contributors
 * @file graph/heterograph.cc
 * @brief Heterograph implementation
 */
sangwzh's avatar
sangwzh committed
7
#include "heterograph.h"

#include <dgl/array.h>
#include <dgl/graph_serializer.h>
#include <dgl/immutable_graph.h>
#include <dmlc/memory_io.h>

#include <algorithm>
#include <memory>
#include <set>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>
18
19
20
21
22
23

using namespace dgl::runtime;

namespace dgl {
namespace {

24
25
using dgl::ImmutableGraph;

26
27
28
/**
 * @brief Edge-induced subgraph that keeps every vertex of the parent graph.
 *
 * Each relation graph is sliced independently with
 * `EdgeSubgraph(..., preserve_nodes=true)`, so vertex IDs are not relabeled and
 * the result reuses the parent's per-type vertex counts.
 *
 * @param hg   The parent heterograph.
 * @param eids One edge-ID array per edge type (size must equal NumEdgeTypes()).
 * @return The subgraph together with its induced vertices/edges.
 */
HeteroSubgraph EdgeSubgraphPreserveNodes(
    const HeteroGraph* hg, const std::vector<IdArray>& eids) {
  CHECK_EQ(eids.size(), hg->NumEdgeTypes())
      << "Invalid input: the input list size must be the same as the number of "
         "edge type.";
  const dgl_type_t num_etypes = hg->NumEdgeTypes();

  HeteroSubgraph subg;
  subg.induced_edges = eids;
  subg.induced_vertices.resize(hg->NumVertexTypes());

  std::vector<HeteroGraphPtr> sub_rels(num_etypes);
  for (dgl_type_t etype = 0; etype < num_etypes; ++etype) {
    const auto endpoints = hg->meta_graph()->FindEdge(etype);
    const auto& rel_subg =
        hg->GetRelationGraph(etype)->EdgeSubgraph({eids[etype]}, true);
    sub_rels[etype] = rel_subg.graph;
    // Slot 0 is the source side and slot 1 the destination side of the
    // relation subgraph; for a same-type relation the second write wins.
    subg.induced_vertices[endpoints.first] = rel_subg.induced_vertices[0];
    subg.induced_vertices[endpoints.second] = rel_subg.induced_vertices[1];
  }

  subg.graph = HeteroGraphPtr(
      new HeteroGraph(hg->meta_graph(), sub_rels, hg->NumVerticesPerType()));
  return subg;
}

/**
 * @brief Edge-induced subgraph that keeps only vertices touched by the edges.
 *
 * @param hg   The parent heterograph.
 * @param eids One edge-ID array per edge type (size must equal NumEdgeTypes()).
 * @return The relabeled subgraph; induced_vertices[vtype] maps new vertex IDs
 *         of each type back to parent IDs.
 */
HeteroSubgraph EdgeSubgraphNoPreserveNodes(
    const HeteroGraph* hg, const std::vector<IdArray>& eids) {
  // TODO(minjie): In general, all relabeling should be separated with subgraph
  //   operations.
  CHECK_EQ(eids.size(), hg->NumEdgeTypes())
      << "Invalid input: the input list size must be the same as the number of "
         "edge type.";

  // NOTE(minjie): EdgeSubgraph when preserve_nodes is false is quite
  // complicated in heterograph, because bipartite graphs incident on the same
  // vertex type must share the same ID space. For example, given:
  //
  //   Meta graph: A -> B -> C
  //   UnitGraph graphs:
  //   * A -> B: (0, 0), (0, 1)
  //   * B -> C: (1, 0), (1, 1)
  //
  // suppose for A->B we only keep edge (0, 0), while for B->C we only keep
  // (1, 0). In the result subgraph node type B must still have two nodes, so we
  // cannot simply compute EdgeSubgraph for B->C, which would relabel node #1 of
  // type B to node #0.
  //
  // Implementation:
  // (1) For each bipartite graph, slice out the edges using the given eids.
  // (2) Collect, per vertex type, the incident node arrays from every bipartite
  //     graph that has that vertex type as its source or destination type.
  // (3) For each vertex type, run aten::Relabel_ on its arrays. It computes the
  //     union of the vertex sets and relabels the unique elements from zero;
  //     the returned mapping array is the induced vertex set for that type.
  //     NOTE(review): Relabel_ presumably rewrites its inputs in place (trailing
  //     underscore); those inputs share storage with `sliced`, which is how
  //     step (4) sees relabeled IDs.
  // (4) Build each bipartite graph from the relabeled edges.
  const dgl_type_t num_vtypes = hg->NumVertexTypes();
  const dgl_type_t num_etypes = hg->NumEdgeTypes();

  HeteroSubgraph subg;
  subg.induced_edges = eids;
  subg.induced_vertices.resize(num_vtypes);

  // Steps (1) & (2).
  std::vector<EdgeArray> sliced(num_etypes);
  std::vector<std::vector<IdArray>> incident(num_vtypes);
  for (dgl_type_t etype = 0; etype < num_etypes; ++etype) {
    const auto endpoints = hg->meta_graph()->FindEdge(etype);
    sliced[etype] = hg->GetRelationGraph(etype)->FindEdges(0, eids[etype]);
    incident[endpoints.first].push_back(sliced[etype].src);
    incident[endpoints.second].push_back(sliced[etype].dst);
  }

  // Step (3).
  std::vector<int64_t> sub_num_vertices(num_vtypes);
  for (dgl_type_t vtype = 0; vtype < num_vtypes; ++vtype) {
    subg.induced_vertices[vtype] = aten::Relabel_(incident[vtype]);
    sub_num_vertices[vtype] = subg.induced_vertices[vtype]->shape[0];
  }

  // Step (4).
  std::vector<HeteroGraphPtr> sub_rels(num_etypes);
  for (dgl_type_t etype = 0; etype < num_etypes; ++etype) {
    const auto endpoints = hg->meta_graph()->FindEdge(etype);
    const int num_rel_vtypes = (endpoints.first == endpoints.second) ? 1 : 2;
    sub_rels[etype] = UnitGraph::CreateFromCOO(
        num_rel_vtypes, sub_num_vertices[endpoints.first],
        sub_num_vertices[endpoints.second], sliced[etype].src,
        sliced[etype].dst);
  }

  subg.graph = HeteroGraphPtr(
      new HeteroGraph(hg->meta_graph(), sub_rels, sub_num_vertices));
  return subg;
}

122
123
/**
 * @brief Validate relation graphs against a meta graph.
 *
 * Requires one relation graph per meta-graph edge, at least one relation
 * graph, exactly one edge type per relation graph, and a common device context
 * (that of the first relation graph). Fails a CHECK otherwise.
 */
void HeteroGraphSanityCheck(
    GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs) {
  CHECK_EQ(meta_graph->NumEdges(), rel_graphs.size());
  CHECK(!rel_graphs.empty()) << "Empty heterograph is not allowed.";

  for (const HeteroGraphPtr& rel : rel_graphs) {
    CHECK_EQ(rel->NumEdgeTypes(), 1)
        << "Each relation graph must have only one edge type.";
  }

  const auto expected_ctx = rel_graphs.front()->Context();
  for (const HeteroGraphPtr& rel : rel_graphs) {
    CHECK_EQ(rel->Context(), expected_ctx)
        << "Each relation graph must have the same context.";
  }
}

139
140
/**
 * @brief Infer the number of vertices of every vertex type from relation graphs.
 *
 * For every meta-graph edge (srctype -> dsttype, id etype), the source and
 * destination vertex counts of rel_graphs[etype] are recorded for srctype and
 * dsttype. When several relation graphs touch the same vertex type, their
 * counts must agree (CHECK failure otherwise).
 *
 * @param meta_graph The meta graph; its edge ids index into @p rel_graphs.
 * @param rel_graphs Relation graphs indexed by edge type.
 * @return Vertex count per vertex type. A vertex type that no edge type is
 *         incident on is left at -1.
 */
std::vector<int64_t> InferNumVerticesPerType(
    GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs) {
  std::vector<int64_t> num_verts_per_type(meta_graph->NumVertices(), -1);

  // Record `nv` for `vtype`, or verify it against the previously recorded one.
  // Both sides are int64_t so CHECK_EQ does not mix signedness.
  auto record = [&num_verts_per_type](dgl_type_t vtype, int64_t nv) {
    if (num_verts_per_type[vtype] < 0) {
      num_verts_per_type[vtype] = nv;
    } else {
      CHECK_EQ(num_verts_per_type[vtype], nv)
          << "Mismatch number of vertices for vertex type " << vtype;
    }
  };

  // The meta graph's edge arrays are read directly as dgl_type_t.
  // NOTE(review): assumes they are 64-bit CPU arrays — confirm for all callers.
  const EdgeArray etype_array = meta_graph->Edges();
  const dgl_type_t* srctypes =
      static_cast<const dgl_type_t*>(etype_array.src->data);
  const dgl_type_t* dsttypes =
      static_cast<const dgl_type_t*>(etype_array.dst->data);
  const dgl_type_t* etypes = static_cast<const dgl_type_t*>(etype_array.id->data);

  const size_t num_meta_edges = meta_graph->NumEdges();
  for (size_t i = 0; i < num_meta_edges; ++i) {
    const auto& rg = rel_graphs[etypes[i]];
    // Source vertices are vertex type 0 of the relation graph; destination is
    // type 1 unless the relation graph has a single vertex type.
    const dgl_type_t dst_slot = rg->NumVertexTypes() == 1 ? 0 : 1;
    record(srctypes[i], static_cast<int64_t>(rg->NumVertices(0)));
    record(dsttypes[i], static_cast<int64_t>(rg->NumVertices(dst_slot)));
  }
  return num_verts_per_type;
}
174

175
176
std::vector<UnitGraphPtr> CastToUnitGraphs(
    const std::vector<HeteroGraphPtr>& rel_graphs) {
177
  std::vector<UnitGraphPtr> relation_graphs(rel_graphs.size());
178
179
180
  for (size_t i = 0; i < rel_graphs.size(); ++i) {
    HeteroGraphPtr relg = rel_graphs[i];
    if (std::dynamic_pointer_cast<UnitGraph>(relg)) {
181
      relation_graphs[i] = std::dynamic_pointer_cast<UnitGraph>(relg);
182
    } else {
183
      relation_graphs[i] = CHECK_NOTNULL(
184
185
186
          std::dynamic_pointer_cast<UnitGraph>(relg->GetRelationGraph(0)));
    }
  }
187
188
189
190
191
192
  return relation_graphs;
}

}  // namespace

/**
 * @brief Build a heterograph from a meta graph and its relation graphs.
 *
 * @param meta_graph         Meta graph; edge id i corresponds to rel_graphs[i].
 * @param rel_graphs         One relation graph per meta-graph edge.
 * @param num_nodes_per_type Vertex count per vertex type. When empty, counts
 *                           are inferred from @p rel_graphs; otherwise they are
 *                           used as given.
 */
HeteroGraph::HeteroGraph(
    GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs,
    const std::vector<int64_t>& num_nodes_per_type)
    : BaseHeteroGraph(meta_graph) {
  // Validate first: InferNumVerticesPerType indexes rel_graphs by meta-graph
  // edge ids, which is only safe once the sizes are known to match.
  HeteroGraphSanityCheck(meta_graph, rel_graphs);
  if (num_nodes_per_type.empty()) {
    num_verts_per_type_ = InferNumVerticesPerType(meta_graph, rel_graphs);
  } else {
    num_verts_per_type_ = num_nodes_per_type;
  }
  relation_graphs_ = CastToUnitGraphs(rel_graphs);
}

bool HeteroGraph::IsMultigraph() const {
205
  for (const auto& hg : relation_graphs_) {
206
207
208
209
210
    if (hg->IsMultigraph()) {
      return true;
    }
  }
  return false;
211
212
213
}

/**
 * @brief Element-wise test of whether each id in @p vids is below the vertex
 *        count of type @p vtype.
 *
 * NOTE(review): only the upper bound is checked; a negative id would be
 * reported as present. Confirm callers never pass negative ids.
 */
BoolArray HeteroGraph::HasVertices(dgl_type_t vtype, IdArray vids) const {
  CHECK(aten::IsValidIdArray(vids)) << "Invalid id array input";
  const auto num_vertices = NumVertices(vtype);
  return aten::LT(vids, num_vertices);
}

218
219
/**
 * @brief Vertex-induced subgraph.
 *
 * @param vids One vertex-ID array per vertex type (size must equal
 *             NumVertexTypes()).
 * @return The subgraph; induced_vertices is @p vids and induced_edges holds,
 *         per edge type, the parent edge IDs kept by the relation subgraph.
 */
HeteroSubgraph HeteroGraph::VertexSubgraph(
    const std::vector<IdArray>& vids) const {
  CHECK_EQ(vids.size(), NumVertexTypes())
      << "Invalid input: the input list size must be the same as the number of "
         "vertex types.";

  HeteroSubgraph subg;
  subg.induced_vertices = vids;

  // Each vertex type keeps exactly the requested vertices.
  std::vector<int64_t> sub_num_vertices;
  sub_num_vertices.reserve(vids.size());
  for (const IdArray& type_vids : vids) {
    sub_num_vertices.push_back(type_vids->shape[0]);
  }

  const dgl_type_t num_etypes = NumEdgeTypes();
  subg.induced_edges.resize(num_etypes);
  std::vector<HeteroGraphPtr> sub_rels(num_etypes);
  for (dgl_type_t etype = 0; etype < num_etypes; ++etype) {
    const auto endpoints = meta_graph_->FindEdge(etype);
    // A same-type relation takes one vertex list; otherwise source list then
    // destination list.
    std::vector<IdArray> rel_vids{vids[endpoints.first]};
    if (endpoints.first != endpoints.second) {
      rel_vids.push_back(vids[endpoints.second]);
    }
    const auto& rel_subg = GetRelationGraph(etype)->VertexSubgraph(rel_vids);
    sub_rels[etype] = rel_subg.graph;
    subg.induced_edges[etype] = rel_subg.induced_edges[0];
  }

  subg.graph = HeteroGraphPtr(
      new HeteroGraph(meta_graph_, sub_rels, sub_num_vertices));
  return subg;
}

/**
 * @brief Edge-induced subgraph.
 *
 * @param eids           One edge-ID array per edge type.
 * @param preserve_nodes Keep all parent vertices (true) or only the vertices
 *                       incident on the selected edges, relabeled (false).
 */
HeteroSubgraph HeteroGraph::EdgeSubgraph(
    const std::vector<IdArray>& eids, bool preserve_nodes) const {
  return preserve_nodes ? EdgeSubgraphPreserveNodes(this, eids)
                        : EdgeSubgraphNoPreserveNodes(this, eids);
}

256
257
258
259
260
261
262
/**
 * @brief Return a copy of heterograph @p g whose relation graphs use @p bits
 *        -bit IDs (via UnitGraph::AsNumBits).
 *
 * @param g    Must be a HeteroGraph (CHECK_NOTNULL on the cast).
 * @param bits Target ID width.
 */
HeteroGraphPtr HeteroGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) {
  auto hgindex = std::dynamic_pointer_cast<HeteroGraph>(g);
  CHECK_NOTNULL(hgindex);
  std::vector<HeteroGraphPtr> rel_graphs;
  rel_graphs.reserve(hgindex->relation_graphs_.size());
  // Distinct loop name so the parameter `g` is not shadowed; iterate by
  // reference to avoid a shared_ptr copy per relation graph.
  for (const auto& rel_graph : hgindex->relation_graphs_) {
    rel_graphs.push_back(UnitGraph::AsNumBits(rel_graph, bits));
  }
  return HeteroGraphPtr(new HeteroGraph(
      hgindex->meta_graph_, rel_graphs, hgindex->num_verts_per_type_));
}

267
/**
 * @brief Copy heterograph @p g to device context @p ctx.
 *
 * Returns @p g itself when it is already on @p ctx; otherwise builds a new
 * HeteroGraph whose relation graphs are UnitGraph::CopyTo(rel, ctx).
 *
 * @param g Must be a HeteroGraph when a copy is needed (CHECK_NOTNULL).
 */
HeteroGraphPtr HeteroGraph::CopyTo(HeteroGraphPtr g, const DGLContext& ctx) {
  if (ctx == g->Context()) {
    return g;
  }
  auto hgindex = std::dynamic_pointer_cast<HeteroGraph>(g);
  CHECK_NOTNULL(hgindex);
  std::vector<HeteroGraphPtr> rel_graphs;
  rel_graphs.reserve(hgindex->relation_graphs_.size());
  // Distinct loop name so the parameter `g` is not shadowed.
  for (const auto& rel_graph : hgindex->relation_graphs_) {
    rel_graphs.push_back(UnitGraph::CopyTo(rel_graph, ctx));
  }
  return HeteroGraphPtr(new HeteroGraph(
      hgindex->meta_graph_, rel_graphs, hgindex->num_verts_per_type_));
}

281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
/**
 * @brief Return a heterograph whose relation graphs are all pinned.
 *
 * If every relation graph of @p g already reports IsPinned(), @p g is returned
 * unchanged. Otherwise a new HeteroGraph is built: already-pinned relation
 * graphs are shared, the others are replaced by their PinMemory() result.
 *
 * @param g Must be a HeteroGraph (CHECK_NOTNULL on the cast).
 */
HeteroGraphPtr HeteroGraph::PinMemory(HeteroGraphPtr g) {
  auto casted_ptr = std::dynamic_pointer_cast<HeteroGraph>(g);
  CHECK_NOTNULL(casted_ptr);
  // Borrow instead of copying the whole vector of shared_ptrs.
  const auto& relation_graphs = casted_ptr->relation_graphs_;

  const bool all_pinned = std::all_of(
      relation_graphs.begin(), relation_graphs.end(),
      [](const auto& rel_graph) { return rel_graph->IsPinned(); });
  if (all_pinned) return g;

  std::vector<HeteroGraphPtr> pinned_relation_graphs;
  pinned_relation_graphs.reserve(relation_graphs.size());
  for (const auto& rel_graph : relation_graphs) {
    if (rel_graph->IsPinned()) {
      pinned_relation_graphs.push_back(rel_graph);
    } else {
      pinned_relation_graphs.push_back(rel_graph->PinMemory());
    }
  }
  return HeteroGraphPtr(new HeteroGraph(
      casted_ptr->meta_graph_, pinned_relation_graphs,
      casted_ptr->num_verts_per_type_));
}

306
void HeteroGraph::PinMemory_() {
307
  for (auto g : relation_graphs_) g->PinMemory_();
308
309
310
}

void HeteroGraph::UnpinMemory_() {
311
  for (auto g : relation_graphs_) g->UnpinMemory_();
312
313
}

314
/** @brief Forward RecordStream(@p stream) to every relation graph. */
void HeteroGraph::RecordStream(DGLStreamHandle stream) {
  // Iterate by reference: no shared_ptr refcount traffic per relation graph.
  for (const auto& rel_graph : relation_graphs_) rel_graph->RecordStream(stream);
}

318
319
320
321
322
/**
 * @brief Name of the shared-memory segment backing this graph, or an empty
 *        string when the graph has no shared_mem_.
 */
std::string HeteroGraph::SharedMemName() const {
  if (!shared_mem_) return "";
  return shared_mem_->GetName();
}

/**
 * @brief Copy heterograph @p g into a shared-memory segment called @p name.
 *
 * Returns @p g unchanged if it is already backed by a segment of that name.
 * Otherwise the metadata stream receives, in order: NumBits, the coo/csr/csc
 * flags, the meta graph, the per-type vertex counts, then (after the per-edge-
 * type matrices in @p fmts are copied under "<name>_<etype>_<fmt>") the node
 * type and edge type names. CreateFromSharedMem reads them back in this order.
 *
 * @param g      Must be a HeteroGraph (CHECK_NOTNULL on the cast).
 * @param name   Shared-memory segment name.
 * @param ntypes Node type names stored alongside the graph.
 * @param etypes Edge type names stored alongside the graph.
 * @param fmts   Sparse formats to materialize: any of "coo", "csr", "csc".
 */
HeteroGraphPtr HeteroGraph::CopyToSharedMem(
    HeteroGraphPtr g, const std::string& name,
    const std::vector<std::string>& ntypes,
    const std::vector<std::string>& etypes, const std::set<std::string>& fmts) {
  // TODO(JJ): Raise error when calling shared_memory if graph index is on gpu
  auto hg = std::dynamic_pointer_cast<HeteroGraph>(g);
  CHECK_NOTNULL(hg);
  if (hg->SharedMemName() == name) return g;

  // Copy buffer to share memory
  auto segment = std::make_shared<SharedMemory>(name);
  auto segment_buf = segment->CreateNew(SHARED_MEM_METAINFO_SIZE_MAX);
  dmlc::MemoryFixedSizeStream meta_stream(
      segment_buf, SHARED_MEM_METAINFO_SIZE_MAX);
  SharedMemManager shm(name, &meta_stream);

  const bool has_coo = fmts.count("coo") > 0;
  const bool has_csr = fmts.count("csr") > 0;
  const bool has_csc = fmts.count("csc") > 0;
  shm.Write(g->NumBits());
  shm.Write(has_coo);
  shm.Write(has_csr);
  shm.Write(has_csc);
  shm.Write(ImmutableGraph::ToImmutable(hg->meta_graph_));
  shm.Write(hg->num_verts_per_type_);

  const dgl_type_t num_etypes = g->NumEdgeTypes();
  std::vector<HeteroGraphPtr> relgraphs(num_etypes);
  for (dgl_type_t etype = 0; etype < num_etypes; ++etype) {
    const auto endpoints = g->GetEndpointTypes(etype);
    const int num_vtypes = (endpoints.first == endpoints.second) ? 1 : 2;
    const std::string prefix = name + "_" + std::to_string(etype);

    aten::COOMatrix coo;
    aten::CSRMatrix csr, csc;
    if (has_coo)
      coo = shm.CopyToSharedMem(hg->GetCOOMatrix(etype), prefix + "_coo");
    if (has_csr)
      csr = shm.CopyToSharedMem(hg->GetCSRMatrix(etype), prefix + "_csr");
    if (has_csc)
      csc = shm.CopyToSharedMem(hg->GetCSCMatrix(etype), prefix + "_csc");

    relgraphs[etype] = UnitGraph::CreateUnitGraphFrom(
        num_vtypes, csc, csr, coo, has_csc, has_csr, has_coo);
  }

  auto ret = std::make_shared<HeteroGraph>(
      hg->meta_graph_, relgraphs, hg->num_verts_per_type_);
  ret->shared_mem_ = segment;

  shm.Write(ntypes);
  shm.Write(etypes);
  return ret;
}

std::tuple<HeteroGraphPtr, std::vector<std::string>, std::vector<std::string>>
378
HeteroGraph::CreateFromSharedMem(const std::string& name) {
379
380
  bool exist = SharedMemory::Exist(name);
  if (!exist) {
381
382
    return std::make_tuple(
        nullptr, std::vector<std::string>(), std::vector<std::string>());
383
  }
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
  auto mem = std::make_shared<SharedMemory>(name);
  auto mem_buf = mem->Open(SHARED_MEM_METAINFO_SIZE_MAX);
  dmlc::MemoryFixedSizeStream strm(mem_buf, SHARED_MEM_METAINFO_SIZE_MAX);
  SharedMemManager shm(name, &strm);

  uint8_t nbits;
  CHECK(shm.Read(&nbits)) << "invalid nbits (unit8_t)";

  bool has_coo, has_csr, has_csc;
  CHECK(shm.Read(&has_coo)) << "invalid nbits (unit8_t)";
  CHECK(shm.Read(&has_csr)) << "invalid csr (unit8_t)";
  CHECK(shm.Read(&has_csc)) << "invalid csc (unit8_t)";

  auto meta_imgraph = Serializer::make_shared<ImmutableGraph>();
  CHECK(shm.Read(&meta_imgraph)) << "Invalid meta graph";
  GraphPtr metagraph = meta_imgraph;

  std::vector<int64_t> num_verts_per_type;
  CHECK(shm.Read(&num_verts_per_type)) << "Invalid number of vertices per type";

  std::vector<HeteroGraphPtr> relgraphs(metagraph->NumEdges());
405
  for (dgl_type_t etype = 0; etype < metagraph->NumEdges(); ++etype) {
406
407
    auto src_dst = metagraph->FindEdge(etype);
    int num_vtypes = (src_dst.first == src_dst.second) ? 1 : 2;
408
409
410
411
412
413
414
415
416
417
418
419
420
    aten::COOMatrix coo;
    aten::CSRMatrix csr, csc;
    std::string prefix = name + "_" + std::to_string(etype);
    if (has_coo) {
      shm.CreateFromSharedMem(&coo, prefix + "_coo");
    }
    if (has_csr) {
      shm.CreateFromSharedMem(&csr, prefix + "_csr");
    }
    if (has_csc) {
      shm.CreateFromSharedMem(&csc, prefix + "_csc");
    }

421
422
    relgraphs[etype] = UnitGraph::CreateUnitGraphFrom(
        num_vtypes, csc, csr, coo, has_csc, has_csr, has_coo);
423
424
  }

425
426
  auto ret =
      std::make_shared<HeteroGraph>(metagraph, relgraphs, num_verts_per_type);
427
428
429
430
431
432
433
434
435
  ret->shared_mem_ = mem;

  std::vector<std::string> ntypes;
  std::vector<std::string> etypes;
  CHECK(shm.Read(&ntypes)) << "invalid ntypes";
  CHECK(shm.Read(&etypes)) << "invalid etypes";
  return std::make_tuple(ret, ntypes, etypes);
}

436
/**
 * @brief Return a heterograph whose relation graphs are restricted to the
 *        sparse formats in @p formats (via UnitGraph::GetGraphInFormat).
 */
HeteroGraphPtr HeteroGraph::GetGraphInFormat(dgl_format_code_t formats) const {
  std::vector<HeteroGraphPtr> format_rels(NumEdgeTypes());
  for (dgl_type_t etype = 0; etype < NumEdgeTypes(); ++etype) {
    auto relgraph =
        std::dynamic_pointer_cast<UnitGraph>(GetRelationGraph(etype));
    // Fail loudly rather than dereference a null pointer if the relation graph
    // is not a UnitGraph.
    CHECK_NOTNULL(relgraph);
    format_rels[etype] = relgraph->GetGraphInFormat(formats);
  }
  return HeteroGraphPtr(
      new HeteroGraph(meta_graph_, format_rels, NumVerticesPerType()));
}

447
448
449
450
451
/**
 * @brief Flatten the edge types in @p etypes into one relation graph.
 *
 * Dispatches on the graph's ID width: int32_t IDs when NumBits() is 32,
 * int64_t otherwise.
 */
FlattenedHeteroGraphPtr HeteroGraph::Flatten(
    const std::vector<dgl_type_t>& etypes) const {
  return NumBits() == 32 ? FlattenImpl<int32_t>(etypes)
                         : FlattenImpl<int64_t>(etypes);
}

template <class IdType>
458
459
FlattenedHeteroGraphPtr HeteroGraph::FlattenImpl(
    const std::vector<dgl_type_t>& etypes) const {
Minjie Wang's avatar
Minjie Wang committed
460
461
  std::unordered_map<dgl_type_t, size_t> srctype_offsets, dsttype_offsets;
  size_t src_nodes = 0, dst_nodes = 0;
462
463
  std::vector<dgl_type_t> induced_srctype, induced_dsttype;
  std::vector<IdType> induced_srcid, induced_dstid;
Minjie Wang's avatar
Minjie Wang committed
464
465
  std::vector<dgl_type_t> srctype_set, dsttype_set;

466
467
  // XXXtype_offsets contain the mapping from node type and number of nodes
  // after this loop.
Minjie Wang's avatar
Minjie Wang committed
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
  for (dgl_type_t etype : etypes) {
    auto src_dsttype = meta_graph_->FindEdge(etype);
    dgl_type_t srctype = src_dsttype.first;
    dgl_type_t dsttype = src_dsttype.second;
    size_t num_srctype_nodes = NumVertices(srctype);
    size_t num_dsttype_nodes = NumVertices(dsttype);

    if (srctype_offsets.count(srctype) == 0) {
      srctype_offsets[srctype] = num_srctype_nodes;
      srctype_set.push_back(srctype);
    }
    if (dsttype_offsets.count(dsttype) == 0) {
      dsttype_offsets[dsttype] = num_dsttype_nodes;
      dsttype_set.push_back(dsttype);
    }
  }
484
485
  // Sort the node types so that we can compare the sets and decide whether a
  // homogeneous graph should be returned.
Minjie Wang's avatar
Minjie Wang committed
486
487
  std::sort(srctype_set.begin(), srctype_set.end());
  std::sort(dsttype_set.begin(), dsttype_set.end());
488
489
490
  bool homograph =
      (srctype_set.size() == dsttype_set.size()) &&
      std::equal(srctype_set.begin(), srctype_set.end(), dsttype_set.begin());
Minjie Wang's avatar
Minjie Wang committed
491

492
493
  // XXXtype_offsets contain the mapping from node type to node ID offsets after
  // these two loops.
Minjie Wang's avatar
Minjie Wang committed
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
  for (size_t i = 0; i < srctype_set.size(); ++i) {
    dgl_type_t ntype = srctype_set[i];
    size_t num_nodes = srctype_offsets[ntype];
    srctype_offsets[ntype] = src_nodes;
    src_nodes += num_nodes;
    for (size_t j = 0; j < num_nodes; ++j) {
      induced_srctype.push_back(ntype);
      induced_srcid.push_back(j);
    }
  }
  for (size_t i = 0; i < dsttype_set.size(); ++i) {
    dgl_type_t ntype = dsttype_set[i];
    size_t num_nodes = dsttype_offsets[ntype];
    dsttype_offsets[ntype] = dst_nodes;
    dst_nodes += num_nodes;
    for (size_t j = 0; j < num_nodes; ++j) {
      induced_dsttype.push_back(ntype);
      induced_dstid.push_back(j);
    }
  }
514

515
516
517
518
519
520
521
  // TODO(minjie): Using concat operations cause many fragmented memory.
  //   Need to optimize it in the future.
  std::vector<IdArray> src_arrs, dst_arrs, eid_arrs, induced_etypes;
  src_arrs.reserve(etypes.size());
  dst_arrs.reserve(etypes.size());
  eid_arrs.reserve(etypes.size());
  induced_etypes.reserve(etypes.size());
Minjie Wang's avatar
Minjie Wang committed
522
523
524
525
526
527
528
529
530
  for (dgl_type_t etype : etypes) {
    auto src_dsttype = meta_graph_->FindEdge(etype);
    dgl_type_t srctype = src_dsttype.first;
    dgl_type_t dsttype = src_dsttype.second;
    size_t srctype_offset = srctype_offsets[srctype];
    size_t dsttype_offset = dsttype_offsets[dsttype];

    EdgeArray edges = Edges(etype);
    size_t num_edges = NumEdges(etype);
531
532
533
    src_arrs.push_back(edges.src + srctype_offset);
    dst_arrs.push_back(edges.dst + dsttype_offset);
    eid_arrs.push_back(edges.id);
534
535
    induced_etypes.push_back(
        aten::Full(etype, num_edges, NumBits(), Context()));
Minjie Wang's avatar
Minjie Wang committed
536
537
538
  }

  HeteroGraphPtr gptr = UnitGraph::CreateFromCOO(
539
      homograph ? 1 : 2, src_nodes, dst_nodes, aten::Concat(src_arrs),
540
541
542
543
544
      aten::Concat(dst_arrs));

  // Sanity check
  CHECK_EQ(gptr->Context(), Context());
  CHECK_EQ(gptr->NumBits(), NumBits());
Minjie Wang's avatar
Minjie Wang committed
545
546

  FlattenedHeteroGraph* result = new FlattenedHeteroGraph;
547
548
549
550
551
552
  result->graph = HeteroGraphRef(
      HeteroGraphPtr(new HeteroGraph(gptr->meta_graph(), {gptr})));
  result->induced_srctype =
      aten::VecToIdArray(induced_srctype).CopyTo(Context());
  result->induced_srctype_set =
      aten::VecToIdArray(srctype_set).CopyTo(Context());
553
554
555
556
  result->induced_srcid = aten::VecToIdArray(induced_srcid).CopyTo(Context());
  result->induced_etype = aten::Concat(induced_etypes);
  result->induced_etype_set = aten::VecToIdArray(etypes).CopyTo(Context());
  result->induced_eid = aten::Concat(eid_arrs);
557
558
559
560
  result->induced_dsttype =
      aten::VecToIdArray(induced_dsttype).CopyTo(Context());
  result->induced_dsttype_set =
      aten::VecToIdArray(dsttype_set).CopyTo(Context());
561
  result->induced_dstid = aten::VecToIdArray(induced_dstid).CopyTo(Context());
Minjie Wang's avatar
Minjie Wang committed
562
  return FlattenedHeteroGraphPtr(result);
563
564
}

565
566
567
568
569
570
/** @brief Magic number written first by HeteroGraph::Save and checked by Load. */
constexpr uint64_t kDGLSerialize_HeteroGraph = 0xDD589FBE35224ABFull;

/**
 * @brief Deserialize this heterograph from @p fs.
 *
 * Expects the layout produced by Save(): magic number, meta graph, relation
 * graphs, per-type vertex counts. Any read failure or a wrong magic number is a
 * CHECK failure; otherwise returns true.
 */
bool HeteroGraph::Load(dmlc::Stream* fs) {
  uint64_t magic;
  CHECK(fs->Read(&magic)) << "Invalid Magic Number";
  CHECK_EQ(magic, kDGLSerialize_HeteroGraph) << "Invalid HeteroGraph Data";

  auto meta = Serializer::make_shared<ImmutableGraph>();
  CHECK(fs->Read(&meta)) << "Invalid meta graph";
  meta_graph_ = meta;

  CHECK(fs->Read(&relation_graphs_)) << "Invalid relation_graphs_";
  CHECK(fs->Read(&num_verts_per_type_)) << "Invalid num_verts_per_type_";
  return true;
}

/**
 * @brief Serialize this heterograph to @p fs.
 *
 * Layout: magic number, meta graph (converted to ImmutableGraph), relation
 * graphs, per-type vertex counts.
 */
void HeteroGraph::Save(dmlc::Stream* fs) const {
  fs->Write(kDGLSerialize_HeteroGraph);
  const auto immutable_meta = ImmutableGraph::ToImmutable(meta_graph());
  fs->Write(immutable_meta);
  fs->Write(relation_graphs_);
  fs->Write(num_verts_per_type_);
}

587
588
589
/**
 * @brief Convert a single-node-type, single-edge-type heterograph into an
 *        ImmutableGraph via its only relation graph.
 *
 * Fails a CHECK if there is more than one node or edge type, or if relation
 * graph 0 is not a UnitGraph.
 */
GraphPtr HeteroGraph::AsImmutableGraph() const {
  // CHECK_EQ (instead of CHECK(a == b)) reports the offending counts.
  CHECK_EQ(NumVertexTypes(), 1) << "graph has more than one node types";
  CHECK_EQ(NumEdgeTypes(), 1) << "graph has more than one edge types";
  // Cast once; CHECK_NOTNULL(expr) would otherwise evaluate the cast twice.
  auto unit_graph = std::dynamic_pointer_cast<UnitGraph>(GetRelationGraph(0));
  CHECK_NOTNULL(unit_graph);
  return unit_graph->AsImmutableGraph();
}

595
/**
 * @brief Line graph of a homogeneous heterograph.
 *
 * Only supported with exactly one node type and one edge type. The line graph
 * is computed on the single relation graph and wrapped with the same meta
 * graph; its vertex count comes from the line graph's vertex type 0.
 *
 * @param backtracking Forwarded to UnitGraph::LineGraph.
 */
HeteroGraphPtr HeteroGraph::LineGraph(bool backtracking) const {
  CHECK_EQ(1, meta_graph_->NumEdges())
      << "Only support Homogeneous graph now (one edge type)";
  CHECK_EQ(1, meta_graph_->NumVertices())
      << "Only support Homogeneous graph now (one node type)";
  CHECK_EQ(1, relation_graphs_.size()) << "Only support Homogeneous graph now";

  const UnitGraphPtr& unit = relation_graphs_[0];
  const auto line_graph = unit->LineGraph(backtracking);
  const int64_t num_line_nodes =
      static_cast<int64_t>(line_graph->NumVertices(0));
  return HeteroGraphPtr(new HeteroGraph(
      meta_graph_, std::vector<HeteroGraphPtr>{line_graph},
      std::vector<int64_t>{num_line_nodes}));
}

611
}  // namespace dgl