heterograph.cc 21.1 KB
Newer Older
1
2
3
4
5
6
/*!
 *  Copyright (c) 2019 by Contributors
 * \file graph/heterograph.cc
 * \brief Heterograph implementation
 */
#include "./heterograph.h"
Minjie Wang's avatar
Minjie Wang committed
7
#include <dgl/array.h>
8
#include <dgl/immutable_graph.h>
9
#include <dgl/graph_serializer.h>
10
11
#include <dmlc/memory_io.h>
#include <memory>
12
13
14
#include <vector>
#include <tuple>
#include <utility>
15
16
17
18
19
20

using namespace dgl::runtime;

namespace dgl {
namespace {

21
22
using dgl::ImmutableGraph;

23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
/*!
 * \brief Edge-induced subgraph that keeps the node counts of the parent graph.
 *
 * Every relation graph is sliced on its own (second argument of EdgeSubgraph set to
 * true) and the slices are reassembled under the parent's metagraph and per-type
 * vertex counts.
 *
 * \param hg The parent heterograph.
 * \param eids The edges to keep; one ID array per edge type.
 * \return The subgraph plus the induced vertex and edge arrays.
 */
HeteroSubgraph EdgeSubgraphPreserveNodes(
    const HeteroGraph* hg, const std::vector<IdArray>& eids) {
  const auto num_etypes = hg->NumEdgeTypes();
  CHECK_EQ(eids.size(), num_etypes)
    << "Invalid input: the input list size must be the same as the number of edge type.";

  HeteroSubgraph subg;
  subg.induced_edges = eids;
  subg.induced_vertices.resize(hg->NumVertexTypes());

  const auto metagraph = hg->meta_graph();
  std::vector<HeteroGraphPtr> rel_subgraphs(num_etypes);
  for (dgl_type_t et = 0; et < num_etypes; ++et) {
    const auto endpoints = metagraph->FindEdge(et);
    const auto& rel_sub = hg->GetRelationGraph(et)->EdgeSubgraph({eids[et]}, true);
    rel_subgraphs[et] = rel_sub.graph;
    // Source slot is written first, destination second; when both endpoints share a
    // vertex type the destination entry is the one that remains.
    subg.induced_vertices[endpoints.first] = rel_sub.induced_vertices[0];
    subg.induced_vertices[endpoints.second] = rel_sub.induced_vertices[1];
  }

  subg.graph = HeteroGraphPtr(new HeteroGraph(
      metagraph, rel_subgraphs, hg->NumVerticesPerType()));
  return subg;
}

/*!
 * \brief Edge-induced subgraph whose node IDs are relabeled per vertex type.
 *
 * \param hg The parent heterograph.
 * \param eids The edges to keep; one ID array per edge type.
 * \return The subgraph plus the induced vertex and edge arrays.
 */
HeteroSubgraph EdgeSubgraphNoPreserveNodes(
    const HeteroGraph* hg, const std::vector<IdArray>& eids) {
  // TODO(minjie): In general, all relabeling should be separated with subgraph
  //   operations.
  const auto num_etypes = hg->NumEdgeTypes();
  const auto num_vtypes = hg->NumVertexTypes();
  CHECK_EQ(eids.size(), num_etypes)
    << "Invalid input: the input list size must be the same as the number of edge type.";

  HeteroSubgraph subg;
  subg.induced_edges = eids;
  subg.induced_vertices.resize(num_vtypes);

  // NOTE(minjie): EdgeSubgraph when preserve_nodes is false is quite complicated in
  // heterograph, because all bipartite graphs incident on the same vertex type must
  // share one ID space. For example, take the following heterograph:
  //
  // Meta graph: A -> B -> C
  // UnitGraph graphs:
  // * A -> B: (0, 0), (0, 1)
  // * B -> C: (1, 0), (1, 1)
  //
  // Suppose we keep only (0, 0) for A->B and only (1, 0) for B->C. Node type B must
  // still have two nodes in the result, so we cannot simply compute EdgeSubgraph for
  // B->C alone: that would relabel node #1 of type B to node #0.
  //
  // The implementation therefore proceeds as follows:
  // (1) For each bipartite graph, slice out the edges using the given eids.
  // (2) Build, per vertex type, the list of incident-node arrays coming from every
  //     bipartite graph that has this type as source or destination.
  // (3) Call aten::Relabel_ on each list. It computes the union of the vertex sets
  //     and relabels the unique elements from zero; the returned mapping array is the
  //     final induced vertex set for that vertex type.
  // (4) Use the relabeled edges to construct the bipartite graphs.
  const auto metagraph = hg->meta_graph();

  // Steps (1) & (2). The arrays stored in `incident` are the same src/dst handles
  // that are kept in `kept_edges`.
  std::vector<EdgeArray> kept_edges(num_etypes);
  std::vector<std::vector<IdArray>> incident(num_vtypes);
  for (dgl_type_t et = 0; et < num_etypes; ++et) {
    const auto endpoints = metagraph->FindEdge(et);
    kept_edges[et] = hg->GetRelationGraph(et)->FindEdges(0, eids[et]);
    incident[endpoints.first].push_back(kept_edges[et].src);
    incident[endpoints.second].push_back(kept_edges[et].dst);
  }

  // Step (3).
  std::vector<int64_t> node_counts(num_vtypes);
  for (dgl_type_t vt = 0; vt < num_vtypes; ++vt) {
    subg.induced_vertices[vt] = aten::Relabel_(incident[vt]);
    node_counts[vt] = subg.induced_vertices[vt]->shape[0];
  }

  // Step (4).
  std::vector<HeteroGraphPtr> rel_subgraphs(num_etypes);
  for (dgl_type_t et = 0; et < num_etypes; ++et) {
    const auto endpoints = metagraph->FindEdge(et);
    const dgl_type_t src_vt = endpoints.first;
    const dgl_type_t dst_vt = endpoints.second;
    rel_subgraphs[et] = UnitGraph::CreateFromCOO(
      (src_vt == dst_vt) ? 1 : 2,
      node_counts[src_vt],
      node_counts[dst_vt],
      kept_edges[et].src,
      kept_edges[et].dst);
  }

  subg.graph = HeteroGraphPtr(new HeteroGraph(
      metagraph, rel_subgraphs, std::move(node_counts)));
  return subg;
}

117
/*!
 * \brief Validate the relation graphs supplied for a heterograph.
 *
 * Fails a CHECK when the number of relation graphs differs from the number of
 * metagraph edges, when the list is empty, when a relation graph has more than one
 * edge type, or when the relation graphs do not all share one context.
 *
 * \param meta_graph The metagraph.
 * \param rel_graphs The relation graphs to validate.
 */
void HeteroGraphSanityCheck(GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs) {
  const size_t num_rels = rel_graphs.size();
  CHECK_EQ(meta_graph->NumEdges(), rel_graphs.size());
  CHECK(!rel_graphs.empty()) << "Empty heterograph is not allowed.";

  // Every relation graph must carry exactly one edge type.
  for (size_t i = 0; i < num_rels; ++i) {
    CHECK_EQ(rel_graphs[i]->NumEdgeTypes(), 1)
      << "Each relation graph must have only one edge type.";
  }

  // Every relation graph must live on the context of the first one.
  const auto expected_ctx = rel_graphs.front()->Context();
  for (size_t i = 0; i < num_rels; ++i) {
    CHECK_EQ(rel_graphs[i]->Context(), expected_ctx)
      << "Each relation graph must have the same context.";
  }
}

/*!
 * \brief Derive the number of vertices of every vertex type from the relation graphs.
 *
 * Walks the metagraph edges; for each one, the source/destination vertex counts of
 * the matching relation graph are recorded for the corresponding vertex types. A
 * CHECK fails when two relation graphs disagree on the count of a vertex type.
 * Vertex types that no metagraph edge touches keep the value -1.
 *
 * \param meta_graph The metagraph.
 * \param rel_graphs The relation graphs, indexed by edge type.
 * \return The vertex count per vertex type.
 */
std::vector<int64_t>
InferNumVerticesPerType(GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs) {
  // -1 marks "not seen yet".
  std::vector<int64_t> num_verts_per_type(meta_graph->NumVertices(), -1);

  // Store the count the first time a vertex type is seen; afterwards require that
  // every further relation graph reports the same count. The count is kept signed
  // so that it is compared and stored without a signed/unsigned mix.
  const auto record = [&num_verts_per_type](dgl_type_t vtype, int64_t nv) {
    if (num_verts_per_type[vtype] < 0) {
      num_verts_per_type[vtype] = nv;
    } else {
      CHECK_EQ(num_verts_per_type[vtype], nv)
        << "Mismatch number of vertices for vertex type " << vtype;
    }
  };

  const EdgeArray etype_array = meta_graph->Edges();
  const dgl_type_t* srctypes = static_cast<const dgl_type_t*>(etype_array.src->data);
  const dgl_type_t* dsttypes = static_cast<const dgl_type_t*>(etype_array.dst->data);
  const dgl_type_t* etypes = static_cast<const dgl_type_t*>(etype_array.id->data);
  for (size_t i = 0; i < meta_graph->NumEdges(); ++i) {
    const auto& rg = rel_graphs[etypes[i]];
    // Inside a relation graph the source type is 0; the destination type is 1
    // unless the relation graph has a single vertex type.
    const auto sty = 0;
    const auto dty = rg->NumVertexTypes() == 1 ? 0 : 1;

    // # nodes of source type
    record(srctypes[i], static_cast<int64_t>(rg->NumVertices(sty)));
    // # nodes of destination type
    record(dsttypes[i], static_cast<int64_t>(rg->NumVertices(dty)));
  }
  return num_verts_per_type;
}
166

167
168
std::vector<UnitGraphPtr> CastToUnitGraphs(const std::vector<HeteroGraphPtr>& rel_graphs) {
  std::vector<UnitGraphPtr> relation_graphs(rel_graphs.size());
169
170
171
  for (size_t i = 0; i < rel_graphs.size(); ++i) {
    HeteroGraphPtr relg = rel_graphs[i];
    if (std::dynamic_pointer_cast<UnitGraph>(relg)) {
172
      relation_graphs[i] = std::dynamic_pointer_cast<UnitGraph>(relg);
173
    } else {
174
      relation_graphs[i] = CHECK_NOTNULL(
175
176
177
          std::dynamic_pointer_cast<UnitGraph>(relg->GetRelationGraph(0)));
    }
  }
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
  return relation_graphs;
}

}  // namespace

/*!
 * \brief Construct a heterograph from a metagraph and its relation graphs.
 *
 * \param meta_graph The metagraph.
 * \param rel_graphs The relation graphs; validated by HeteroGraphSanityCheck.
 * \param num_nodes_per_type Vertex count per vertex type. When empty, the counts are
 *        inferred from the relation graphs.
 */
HeteroGraph::HeteroGraph(
    GraphPtr meta_graph,
    const std::vector<HeteroGraphPtr>& rel_graphs,
    const std::vector<int64_t>& num_nodes_per_type) : BaseHeteroGraph(meta_graph) {
  // Validate before anything else: InferNumVerticesPerType indexes rel_graphs by
  // edge type, so the list size must already be known to match the metagraph.
  HeteroGraphSanityCheck(meta_graph, rel_graphs);
  if (num_nodes_per_type.empty())
    num_verts_per_type_ = InferNumVerticesPerType(meta_graph, rel_graphs);
  else
    num_verts_per_type_ = num_nodes_per_type;
  relation_graphs_ = CastToUnitGraphs(rel_graphs);
}

bool HeteroGraph::IsMultigraph() const {
196
197
198
199
200
201
  for (const auto &hg : relation_graphs_) {
    if (hg->IsMultigraph()) {
      return true;
    }
  }
  return false;
202
203
204
}

/*!
 * \brief Test which of the given IDs are below the vertex count of a vertex type.
 * \param vtype The vertex type.
 * \param vids The vertex IDs to test; must be a valid ID array.
 * \return The element-wise result of `vids < NumVertices(vtype)`.
 */
BoolArray HeteroGraph::HasVertices(dgl_type_t vtype, IdArray vids) const {
  CHECK(aten::IsValidIdArray(vids)) << "Invalid id array input";
  const auto num_nodes = NumVertices(vtype);
  return aten::LT(vids, num_nodes);
}

/*!
 * \brief Vertex-induced subgraph.
 *
 * Each relation graph is restricted to the selected vertices of its endpoint types;
 * the vertex count of each type in the result is the length of its ID array.
 *
 * \param vids The vertices to keep; one ID array per vertex type.
 * \return The subgraph plus the induced vertex and edge arrays.
 */
HeteroSubgraph HeteroGraph::VertexSubgraph(const std::vector<IdArray>& vids) const {
  const auto num_vtypes = NumVertexTypes();
  const auto num_etypes = NumEdgeTypes();
  CHECK_EQ(vids.size(), num_vtypes)
    << "Invalid input: the input list size must be the same as the number of vertex types.";

  HeteroSubgraph subg;
  subg.induced_vertices = vids;
  subg.induced_edges.resize(num_etypes);

  std::vector<int64_t> node_counts(num_vtypes);
  for (dgl_type_t vt = 0; vt < num_vtypes; ++vt)
    node_counts[vt] = vids[vt]->shape[0];

  std::vector<HeteroGraphPtr> rel_subgraphs(num_etypes);
  for (dgl_type_t et = 0; et < num_etypes; ++et) {
    const auto endpoints = meta_graph_->FindEdge(et);
    // A relation whose endpoints share a vertex type takes a single ID array;
    // otherwise it takes the source array followed by the destination array.
    std::vector<IdArray> rel_vids{vids[endpoints.first]};
    if (endpoints.first != endpoints.second)
      rel_vids.push_back(vids[endpoints.second]);

    const auto& rel_sub = GetRelationGraph(et)->VertexSubgraph(rel_vids);
    rel_subgraphs[et] = rel_sub.graph;
    subg.induced_edges[et] = rel_sub.induced_edges[0];
  }

  subg.graph = HeteroGraphPtr(new HeteroGraph(
      meta_graph_, rel_subgraphs, std::move(node_counts)));
  return subg;
}

/*!
 * \brief Edge-induced subgraph.
 * \param eids The edges to keep; one ID array per edge type.
 * \param preserve_nodes Selects EdgeSubgraphPreserveNodes when true and
 *        EdgeSubgraphNoPreserveNodes when false.
 * \return The subgraph plus the induced vertex and edge arrays.
 */
HeteroSubgraph HeteroGraph::EdgeSubgraph(
    const std::vector<IdArray>& eids, bool preserve_nodes) const {
  return preserve_nodes
    ? EdgeSubgraphPreserveNodes(this, eids)
    : EdgeSubgraphNoPreserveNodes(this, eids);
}

244
245
246
247
248
249
250
251
252
253
254
/*!
 * \brief Build a heterograph whose relation graphs are converted with
 *        UnitGraph::AsNumBits.
 * \param g The input graph; must be a HeteroGraph.
 * \param bits The bit width forwarded to UnitGraph::AsNumBits.
 * \return A new heterograph with the same metagraph and vertex counts.
 */
HeteroGraphPtr HeteroGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) {
  auto hgindex = std::dynamic_pointer_cast<HeteroGraph>(g);
  CHECK_NOTNULL(hgindex);

  std::vector<HeteroGraphPtr> converted;
  converted.reserve(hgindex->relation_graphs_.size());
  for (const auto& rel : hgindex->relation_graphs_)
    converted.push_back(UnitGraph::AsNumBits(rel, bits));

  return HeteroGraphPtr(new HeteroGraph(
      hgindex->meta_graph_, converted, hgindex->num_verts_per_type_));
}

255
256
/*!
 * \brief Copy a heterograph to another context.
 * \param g The input graph; must be a HeteroGraph unless it is already on \a ctx.
 * \param ctx The target context.
 * \param stream The stream forwarded to UnitGraph::CopyTo.
 * \return \a g itself when it already is on \a ctx; otherwise a new heterograph whose
 *         relation graphs were produced by UnitGraph::CopyTo.
 */
HeteroGraphPtr HeteroGraph::CopyTo(HeteroGraphPtr g, const DLContext &ctx,
                                   const DGLStreamHandle &stream) {
  if (ctx == g->Context()) {
    return g;
  }
  auto hgindex = std::dynamic_pointer_cast<HeteroGraph>(g);
  CHECK_NOTNULL(hgindex);

  std::vector<HeteroGraphPtr> copied;
  copied.reserve(hgindex->relation_graphs_.size());
  for (const auto& rel : hgindex->relation_graphs_)
    copied.push_back(UnitGraph::CopyTo(rel, ctx, stream));

  return HeteroGraphPtr(new HeteroGraph(
      hgindex->meta_graph_, copied, hgindex->num_verts_per_type_));
}

270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
/*!
 * \brief Name of the shared memory attached to this graph.
 * \return The name reported by shared_mem_, or an empty string when shared_mem_ is
 *         not set.
 */
std::string HeteroGraph::SharedMemName() const {
  if (!shared_mem_)
    return "";
  return shared_mem_->GetName();
}

/*!
 * \brief Copy a heterograph into shared memory under the given name.
 *
 * Values are written to the meta-info stream in this order: number of bits, the
 * coo/csr/csc flags, the metagraph, the per-type vertex counts, the requested sparse
 * matrices of each edge type, and finally the node and edge type names.
 *
 * \param g The input graph; must be a HeteroGraph.
 * \param name The shared memory name.
 * \param ntypes The node type names to store.
 * \param etypes The edge type names to store.
 * \param fmts The sparse formats to copy; "coo", "csr" and "csc" are looked up.
 * \return \a g itself when it already reports \a name as its shared memory name;
 *         otherwise a new heterograph built from the shared-memory matrices.
 */
HeteroGraphPtr HeteroGraph::CopyToSharedMem(
      HeteroGraphPtr g, const std::string& name, const std::vector<std::string>& ntypes,
      const std::vector<std::string>& etypes, const std::set<std::string>& fmts) {
  // TODO(JJ): Raise error when calling shared_memory if graph index is on gpu
  auto hg = std::dynamic_pointer_cast<HeteroGraph>(g);
  CHECK_NOTNULL(hg);
  if (hg->SharedMemName() == name) {
    return g;
  }

  // Copy buffer to share memory
  auto mem = std::make_shared<SharedMemory>(name);
  auto mem_buf = mem->CreateNew(SHARED_MEM_METAINFO_SIZE_MAX);
  dmlc::MemoryFixedSizeStream strm(mem_buf, SHARED_MEM_METAINFO_SIZE_MAX);
  SharedMemManager shm(name, &strm);

  bool has_coo = fmts.count("coo") > 0;
  bool has_csr = fmts.count("csr") > 0;
  bool has_csc = fmts.count("csc") > 0;

  // Header section of the meta-info stream.
  shm.Write(g->NumBits());
  shm.Write(has_coo);
  shm.Write(has_csr);
  shm.Write(has_csc);
  shm.Write(ImmutableGraph::ToImmutable(hg->meta_graph_));
  shm.Write(hg->num_verts_per_type_);

  // One set of matrices per edge type, each under its own "<name>_<etype>" prefix.
  std::vector<HeteroGraphPtr> shared_rels(g->NumEdgeTypes());
  for (dgl_type_t etype = 0; etype < g->NumEdgeTypes(); ++etype) {
    const std::string prefix = name + "_" + std::to_string(etype);
    aten::COOMatrix coo;
    aten::CSRMatrix csr, csc;
    if (has_coo)
      coo = shm.CopyToSharedMem(hg->GetCOOMatrix(etype), prefix + "_coo");
    if (has_csr)
      csr = shm.CopyToSharedMem(hg->GetCSRMatrix(etype), prefix + "_csr");
    if (has_csc)
      csc = shm.CopyToSharedMem(hg->GetCSCMatrix(etype), prefix + "_csc");
    shared_rels[etype] =
      UnitGraph::CreateHomographFrom(csc, csr, coo, has_csc, has_csr, has_coo);
  }

  auto shared_graph = std::shared_ptr<HeteroGraph>(
      new HeteroGraph(hg->meta_graph_, shared_rels, hg->num_verts_per_type_));
  shared_graph->shared_mem_ = mem;

  // Trailer section: type names.
  shm.Write(ntypes);
  shm.Write(etypes);
  return shared_graph;
}

std::tuple<HeteroGraphPtr, std::vector<std::string>, std::vector<std::string>>
    HeteroGraph::CreateFromSharedMem(const std::string &name) {
328
329
330
331
  bool exist = SharedMemory::Exist(name);
  if (!exist) {
    return std::make_tuple(nullptr, std::vector<std::string>(), std::vector<std::string>());
  }
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
  auto mem = std::make_shared<SharedMemory>(name);
  auto mem_buf = mem->Open(SHARED_MEM_METAINFO_SIZE_MAX);
  dmlc::MemoryFixedSizeStream strm(mem_buf, SHARED_MEM_METAINFO_SIZE_MAX);
  SharedMemManager shm(name, &strm);

  uint8_t nbits;
  CHECK(shm.Read(&nbits)) << "invalid nbits (unit8_t)";

  bool has_coo, has_csr, has_csc;
  CHECK(shm.Read(&has_coo)) << "invalid nbits (unit8_t)";
  CHECK(shm.Read(&has_csr)) << "invalid csr (unit8_t)";
  CHECK(shm.Read(&has_csc)) << "invalid csc (unit8_t)";

  auto meta_imgraph = Serializer::make_shared<ImmutableGraph>();
  CHECK(shm.Read(&meta_imgraph)) << "Invalid meta graph";
  GraphPtr metagraph = meta_imgraph;

  std::vector<int64_t> num_verts_per_type;
  CHECK(shm.Read(&num_verts_per_type)) << "Invalid number of vertices per type";

  std::vector<HeteroGraphPtr> relgraphs(metagraph->NumEdges());
  for (dgl_type_t etype = 0 ; etype < metagraph->NumEdges() ; ++etype) {
    aten::COOMatrix coo;
    aten::CSRMatrix csr, csc;
    std::string prefix = name + "_" + std::to_string(etype);
    if (has_coo) {
      shm.CreateFromSharedMem(&coo, prefix + "_coo");
    }
    if (has_csr) {
      shm.CreateFromSharedMem(&csr, prefix + "_csr");
    }
    if (has_csc) {
      shm.CreateFromSharedMem(&csc, prefix + "_csc");
    }

    relgraphs[etype] = UnitGraph::CreateHomographFrom(csc, csr, coo, has_csc, has_csr, has_coo);
  }

  auto ret = std::make_shared<HeteroGraph>(metagraph, relgraphs, num_verts_per_type);
  ret->shared_mem_ = mem;

  std::vector<std::string> ntypes;
  std::vector<std::string> etypes;
  CHECK(shm.Read(&ntypes)) << "invalid ntypes";
  CHECK(shm.Read(&etypes)) << "invalid etypes";
  return std::make_tuple(ret, ntypes, etypes);
}

380
/*!
 * \brief Build a heterograph whose relation graphs are obtained through
 *        UnitGraph::GetGraphInFormat.
 * \param formats The format code forwarded to every relation graph.
 * \return A new heterograph with the same metagraph and vertex counts.
 */
HeteroGraphPtr HeteroGraph::GetGraphInFormat(dgl_format_code_t formats) const {
  std::vector<HeteroGraphPtr> format_rels(NumEdgeTypes());
  for (dgl_type_t etype = 0; etype < NumEdgeTypes(); ++etype) {
    // Fail with a CHECK instead of dereferencing a null pointer if the relation
    // graph turns out not to be a UnitGraph.
    auto relgraph = CHECK_NOTNULL(
        std::dynamic_pointer_cast<UnitGraph>(GetRelationGraph(etype)));
    format_rels[etype] = relgraph->GetGraphInFormat(formats);
  }
  return HeteroGraphPtr(new HeteroGraph(
    meta_graph_, format_rels, NumVerticesPerType()));
}

390
391
392
393
394
/*!
 * \brief Flatten the given edge types into a single graph.
 *
 * Dispatches to FlattenImpl<int32_t> when NumBits() is 32 and to
 * FlattenImpl<int64_t> otherwise.
 *
 * \param etypes The edge types to flatten.
 * \return The flattened graph with its induced type/ID arrays.
 */
FlattenedHeteroGraphPtr HeteroGraph::Flatten(
    const std::vector<dgl_type_t>& etypes) const {
  return (NumBits() == 32)
    ? FlattenImpl<int32_t>(etypes)
    : FlattenImpl<int64_t>(etypes);
}

/*!
 * \brief Flatten the given edge types into one graph with a single edge type.
 *
 * Source node types and destination node types are collected separately. Within each
 * side the types are sorted and laid out back to back, so a node's flattened ID is its
 * type's offset plus its type-local ID. When both sides contain exactly the same node
 * types the result has one vertex type, otherwise two.
 *
 * \tparam IdType The element type used for the induced node ID vectors.
 * \param etypes The edge types to flatten.
 * \return The flattened graph together with the induced type and ID arrays of its
 *         source nodes, destination nodes and edges.
 */
template <class IdType>
FlattenedHeteroGraphPtr HeteroGraph::FlattenImpl(const std::vector<dgl_type_t>& etypes) const {
  std::unordered_map<dgl_type_t, size_t> srctype_offsets, dsttype_offsets;
  std::vector<dgl_type_t> induced_srctype, induced_dsttype;
  std::vector<IdType> induced_srcid, induced_dstid;
  std::vector<dgl_type_t> srctype_set, dsttype_set;

  // XXXtype_offsets contain the mapping from node type and number of nodes after this
  // loop.
  for (dgl_type_t etype : etypes) {
    const auto src_dsttype = meta_graph_->FindEdge(etype);
    const dgl_type_t srctype = src_dsttype.first;
    const dgl_type_t dsttype = src_dsttype.second;
    const size_t num_srctype_nodes = NumVertices(srctype);
    const size_t num_dsttype_nodes = NumVertices(dsttype);

    // emplace() only inserts when the type has not been seen on this side yet.
    if (srctype_offsets.emplace(srctype, num_srctype_nodes).second)
      srctype_set.push_back(srctype);
    if (dsttype_offsets.emplace(dsttype, num_dsttype_nodes).second)
      dsttype_set.push_back(dsttype);
  }

  // Sort the node types so that we can compare the sets and decide whether a homogeneous graph
  // should be returned.
  std::sort(srctype_set.begin(), srctype_set.end());
  std::sort(dsttype_set.begin(), dsttype_set.end());
  const bool homograph = (srctype_set == dsttype_set);

  // Replaces the per-type node counts stored in `offsets` by starting offsets (in
  // sorted type order), appends the original type and type-local ID of every
  // flattened node, and returns the total number of flattened nodes.
  const auto assign_offsets = [](
      const std::vector<dgl_type_t>& sorted_types,
      std::unordered_map<dgl_type_t, size_t>* offsets,
      std::vector<dgl_type_t>* induced_type,
      std::vector<IdType>* induced_id) {
    size_t total = 0;
    for (dgl_type_t ntype : sorted_types) {
      const size_t num_nodes = (*offsets)[ntype];
      (*offsets)[ntype] = total;
      total += num_nodes;
      for (size_t j = 0; j < num_nodes; ++j) {
        induced_type->push_back(ntype);
        induced_id->push_back(static_cast<IdType>(j));
      }
    }
    return total;
  };

  // XXXtype_offsets contain the mapping from node type to node ID offsets after these
  // two calls.
  const size_t src_nodes = assign_offsets(
      srctype_set, &srctype_offsets, &induced_srctype, &induced_srcid);
  const size_t dst_nodes = assign_offsets(
      dsttype_set, &dsttype_offsets, &induced_dsttype, &induced_dstid);

  // TODO(minjie): Using concat operations cause many fragmented memory.
  //   Need to optimize it in the future.
  std::vector<IdArray> src_arrs, dst_arrs, eid_arrs, induced_etypes;
  src_arrs.reserve(etypes.size());
  dst_arrs.reserve(etypes.size());
  eid_arrs.reserve(etypes.size());
  induced_etypes.reserve(etypes.size());
  for (dgl_type_t etype : etypes) {
    const auto src_dsttype = meta_graph_->FindEdge(etype);
    const size_t srctype_offset = srctype_offsets[src_dsttype.first];
    const size_t dsttype_offset = dsttype_offsets[src_dsttype.second];

    EdgeArray edges = Edges(etype);
    size_t num_edges = NumEdges(etype);
    // Shift the type-local endpoint IDs into the flattened ID space.
    src_arrs.push_back(edges.src + srctype_offset);
    dst_arrs.push_back(edges.dst + dsttype_offset);
    eid_arrs.push_back(edges.id);
    induced_etypes.push_back(aten::Full(etype, num_edges, NumBits(), Context()));
  }

  HeteroGraphPtr gptr = UnitGraph::CreateFromCOO(
      homograph ? 1 : 2,
      src_nodes,
      dst_nodes,
      aten::Concat(src_arrs),
      aten::Concat(dst_arrs));

  // Sanity check
  CHECK_EQ(gptr->Context(), Context());
  CHECK_EQ(gptr->NumBits(), NumBits());

  // The result is owned by the smart pointer from the start, so it is released
  // even if one of the calls below fails.
  FlattenedHeteroGraphPtr result(new FlattenedHeteroGraph);
  result->graph = HeteroGraphRef(HeteroGraphPtr(new HeteroGraph(gptr->meta_graph(), {gptr})));
  result->induced_srctype = aten::VecToIdArray(induced_srctype).CopyTo(Context());
  result->induced_srctype_set = aten::VecToIdArray(srctype_set).CopyTo(Context());
  result->induced_srcid = aten::VecToIdArray(induced_srcid).CopyTo(Context());
  result->induced_etype = aten::Concat(induced_etypes);
  result->induced_etype_set = aten::VecToIdArray(etypes).CopyTo(Context());
  result->induced_eid = aten::Concat(eid_arrs);
  result->induced_dsttype = aten::VecToIdArray(induced_dsttype).CopyTo(Context());
  result->induced_dsttype_set = aten::VecToIdArray(dsttype_set).CopyTo(Context());
  result->induced_dstid = aten::VecToIdArray(induced_dstid).CopyTo(Context());
  return result;
}

503
504
505
506
507
508
// Magic number that HeteroGraph::Save writes first and HeteroGraph::Load verifies
// before reading anything else from the stream.
constexpr uint64_t kDGLSerialize_HeteroGraph = 0xDD589FBE35224ABF;

/*!
 * \brief Read this heterograph from a stream.
 *
 * Expects, in order: the magic number, the metagraph (as an ImmutableGraph), the
 * relation graphs and the per-type vertex counts. Any failed read or a wrong magic
 * number fails a CHECK.
 *
 * \param fs The input stream.
 * \return Always true when the function returns.
 */
bool HeteroGraph::Load(dmlc::Stream* fs) {
  // Header.
  uint64_t magic = 0;
  CHECK(fs->Read(&magic)) << "Invalid Magic Number";
  CHECK_EQ(magic, kDGLSerialize_HeteroGraph) << "Invalid HeteroGraph Data";

  // Metagraph.
  auto loaded_meta = Serializer::make_shared<ImmutableGraph>();
  CHECK(fs->Read(&loaded_meta)) << "Invalid meta graph";
  meta_graph_ = loaded_meta;

  // Payload.
  CHECK(fs->Read(&relation_graphs_)) << "Invalid relation_graphs_";
  CHECK(fs->Read(&num_verts_per_type_)) << "Invalid num_verts_per_type_";
  return true;
}

/*!
 * \brief Write this heterograph to a stream.
 *
 * Writes, in order: the magic number, the metagraph converted with
 * ImmutableGraph::ToImmutable, the relation graphs and the per-type vertex counts.
 *
 * \param fs The output stream.
 */
void HeteroGraph::Save(dmlc::Stream* fs) const {
  // Header.
  fs->Write(kDGLSerialize_HeteroGraph);

  // Metagraph.
  auto immutable_meta = ImmutableGraph::ToImmutable(meta_graph());
  fs->Write(immutable_meta);

  // Payload.
  fs->Write(relation_graphs_);
  fs->Write(num_verts_per_type_);
}

525
526
527
528
529
530
531
532
/*!
 * \brief Convert a heterograph with one node type and one edge type to the result of
 *        UnitGraph::AsImmutableGraph on its only relation graph.
 * \return The converted graph.
 */
GraphPtr HeteroGraph::AsImmutableGraph() const {
  CHECK(NumVertexTypes() == 1) << "graph has more than one node types";
  CHECK(NumEdgeTypes() == 1) << "graph has more than one edge types";
  // The only relation graph must be a UnitGraph.
  auto only_rel = CHECK_NOTNULL(
      std::dynamic_pointer_cast<UnitGraph>(GetRelationGraph(0)));
  return only_rel->AsImmutableGraph();
}

533
534
535
536
537
538
539
540
541
542
543
544
/*!
 * \brief Line graph of a heterograph that has one node type and one edge type.
 * \param backtracking Forwarded to UnitGraph::LineGraph.
 * \return A heterograph with the same metagraph whose single relation graph is the
 *         line graph and whose vertex count is that line graph's NumVertices(0).
 */
HeteroGraphPtr HeteroGraph::LineGraph(bool backtracking) const {
  CHECK_EQ(1, meta_graph_->NumEdges()) << "Only support Homogeneous graph now (one edge type)";
  CHECK_EQ(1, meta_graph_->NumVertices()) << "Only support Homogeneous graph now (one node type)";
  CHECK_EQ(1, relation_graphs_.size()) << "Only support Homogeneous graph now";

  const auto& line_graph = relation_graphs_[0]->LineGraph(backtracking);
  std::vector<HeteroGraphPtr> line_rels = {line_graph};
  std::vector<int64_t> line_node_counts = {static_cast<int64_t>(line_graph->NumVertices(0))};
  return HeteroGraphPtr(new HeteroGraph(meta_graph_, line_rels, std::move(line_node_counts)));
}

545
}  // namespace dgl