heterograph.h 7.81 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
/*!
 *  Copyright (c) 2019 by Contributors
 * \file graph/heterograph.h
 * \brief Heterograph
 */

#ifndef DGL_GRAPH_HETEROGRAPH_H_
#define DGL_GRAPH_HETEROGRAPH_H_

#include <dgl/base_heterograph.h>
#include <dgl/lazy.h>
#include <utility>
#include <string>
#include <vector>
15
#include "./unit_graph.h"
16
17
18
19
20
21

namespace dgl {

/*!
 * \brief Read-only heterograph: a metagraph plus one relation (unit) graph per edge type.
 */
class HeteroGraph : public BaseHeteroGraph {
 public:
22
23
24
25
  /*!
   * \brief Construct a heterograph from a metagraph and its relation graphs.
   *
   * The definition lives outside this header.
   *
   * \param meta_graph Metagraph; the accessors of this class validate edge types and
   *        vertex types against it.
   * \param rel_graphs Relation graphs; presumably one per metagraph edge, in edge-type
   *        order -- confirm against the definition.
   * \param num_nodes_per_type Number of nodes of each vertex type. Defaults to an empty
   *        vector; how an empty vector is treated is decided by the definition.
   */
  HeteroGraph(
      GraphPtr meta_graph,
      const std::vector<HeteroGraphPtr>& rel_graphs,
      const std::vector<int64_t>& num_nodes_per_type = {});
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47

  /*!
   * \brief Look up the relation graph stored for an edge type.
   * \param etype Edge type id; must be smaller than the number of metagraph edges,
   *        otherwise the check below fails.
   * \return The relation graph held at index \c etype.
   */
  HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const override {
    CHECK_LT(etype, meta_graph_->NumEdges()) << "Invalid edge type: " << etype;
    const auto& unit_graph = relation_graphs_[etype];
    return unit_graph;
  }

  /*!
   * \brief Unsupported: this class is read-only, so the call raises a fatal error.
   * \param vtype Vertex type (unused).
   * \param num_vertices Number of vertices to add (unused).
   */
  void AddVertices(dgl_type_t vtype, uint64_t num_vertices) override {
    LOG(FATAL) << "Heterograph is not mutable.";
  }

  /*!
   * \brief Unsupported: this class is read-only, so the call raises a fatal error.
   * \param etype Edge type (unused).
   * \param src Source node id (unused).
   * \param dst Destination node id (unused).
   */
  void AddEdge(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) override {
    LOG(FATAL) << "Heterograph is not mutable.";
  }

  /*!
   * \brief Unsupported: this class is read-only, so the call raises a fatal error.
   * \param etype Edge type (unused).
   * \param src_ids Source node ids (unused).
   * \param dst_ids Destination node ids (unused).
   */
  void AddEdges(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) override {
    LOG(FATAL) << "Heterograph is not mutable.";
  }

  /*! \brief Unsupported: this class is read-only, so the call raises a fatal error. */
  void Clear() override {
    LOG(FATAL) << "Heterograph is not mutable.";
  }

48
49
50
51
  /*!
   * \brief Data type reported by the first relation graph.
   * \note The first element is accessed without a check here; this assumes the graph
   *       holds at least one relation graph.
   */
  DLDataType DataType() const override {
    const auto& first_rel = relation_graphs_.front();
    return first_rel->DataType();
  }

52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
  /*!
   * \brief Device context reported by the first relation graph.
   * \note The first element is accessed without a check here; this assumes the graph
   *       holds at least one relation graph.
   */
  DLContext Context() const override {
    const auto& first_rel = relation_graphs_.front();
    return first_rel->Context();
  }

  /*!
   * \brief Number of bits reported by the first relation graph.
   * \note The first element is accessed without a check here; this assumes the graph
   *       holds at least one relation graph.
   */
  uint8_t NumBits() const override {
    const auto& first_rel = relation_graphs_.front();
    return first_rel->NumBits();
  }

  /*! \brief Whether the graph is a multigraph. Declared here; the definition is out of line. */
  bool IsMultigraph() const override;

  /*! \brief Always true: every mutating method of this class raises a fatal error. */
  bool IsReadonly() const override { return true; }

  /*!
   * \brief Number of vertices of the given vertex type.
   * \param vtype Vertex type id; must be a vertex of the metagraph.
   * \return The stored count for \c vtype, converted to uint64_t.
   */
  uint64_t NumVertices(dgl_type_t vtype) const override {
    CHECK(meta_graph_->HasVertex(vtype)) << "Invalid vertex type: " << vtype;
    const int64_t stored_count = num_verts_per_type_[vtype];
    return static_cast<uint64_t>(stored_count);
  }

71
72
73
74
  /*! \brief Copy of the per-vertex-type vertex counts (index = vertex type id). */
  std::vector<int64_t> NumVerticesPerType() const override {
    return num_verts_per_type_;
  }

75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
  /*!
   * \brief Edge count of the relation graph for \c etype.
   *
   * The relation graph itself is queried with edge type 0.
   */
  uint64_t NumEdges(dgl_type_t etype) const override {
    const auto rel_graph = GetRelationGraph(etype);
    return rel_graph->NumEdges(0);
  }

  /*!
   * \brief Whether \c vid is a valid id for vertex type \c vtype.
   * \return true iff \c vid is smaller than NumVertices(vtype).
   */
  bool HasVertex(dgl_type_t vtype, dgl_id_t vid) const override {
    const uint64_t type_size = NumVertices(vtype);
    return vid < type_size;
  }

  /*!
   * \brief Array form of the vertex-existence query; the definition is out of line.
   * \param vtype Vertex type id.
   * \param vids Vertex ids to test.
   * \return Presumably one flag per entry of \c vids -- confirm against the definition.
   */
  BoolArray HasVertices(dgl_type_t vtype, IdArray vids) const override;

  /*! \brief Forwards to HasEdgeBetween of the relation graph for \c etype (its edge type 0). */
  bool HasEdgeBetween(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {
    const auto rel_graph = GetRelationGraph(etype);
    return rel_graph->HasEdgeBetween(0, src, dst);
  }

  /*! \brief Forwards to HasEdgesBetween of the relation graph for \c etype (its edge type 0). */
  BoolArray HasEdgesBetween(
      dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const override {
    const auto rel_graph = GetRelationGraph(etype);
    return rel_graph->HasEdgesBetween(0, src_ids, dst_ids);
  }

  /*! \brief Forwards to Predecessors of the relation graph for \c etype (its edge type 0). */
  IdArray Predecessors(dgl_type_t etype, dgl_id_t dst) const override {
    const auto rel_graph = GetRelationGraph(etype);
    return rel_graph->Predecessors(0, dst);
  }

  /*! \brief Forwards to Successors of the relation graph for \c etype (its edge type 0). */
  IdArray Successors(dgl_type_t etype, dgl_id_t src) const override {
    const auto rel_graph = GetRelationGraph(etype);
    return rel_graph->Successors(0, src);
  }

  /*! \brief Forwards to EdgeId of the relation graph for \c etype (its edge type 0). */
  IdArray EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {
    const auto rel_graph = GetRelationGraph(etype);
    return rel_graph->EdgeId(0, src, dst);
  }

  /*! \brief Forwards to EdgeIds of the relation graph for \c etype (its edge type 0). */
  EdgeArray EdgeIds(dgl_type_t etype, IdArray src, IdArray dst) const override {
    const auto rel_graph = GetRelationGraph(etype);
    return rel_graph->EdgeIds(0, src, dst);
  }

  /*! \brief Forwards to FindEdge of the relation graph for \c etype (its edge type 0). */
  std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_type_t etype, dgl_id_t eid) const override {
    const auto rel_graph = GetRelationGraph(etype);
    return rel_graph->FindEdge(0, eid);
  }

  /*! \brief Forwards to FindEdges of the relation graph for \c etype (its edge type 0). */
  EdgeArray FindEdges(dgl_type_t etype, IdArray eids) const override {
    const auto rel_graph = GetRelationGraph(etype);
    return rel_graph->FindEdges(0, eids);
  }

  /*! \brief Single-vertex overload: forwards to InEdges of the relation graph for \c etype. */
  EdgeArray InEdges(dgl_type_t etype, dgl_id_t vid) const override {
    const auto rel_graph = GetRelationGraph(etype);
    return rel_graph->InEdges(0, vid);
  }

  /*! \brief Array overload: forwards to InEdges of the relation graph for \c etype. */
  EdgeArray InEdges(dgl_type_t etype, IdArray vids) const override {
    const auto rel_graph = GetRelationGraph(etype);
    return rel_graph->InEdges(0, vids);
  }

  /*! \brief Single-vertex overload: forwards to OutEdges of the relation graph for \c etype. */
  EdgeArray OutEdges(dgl_type_t etype, dgl_id_t vid) const override {
    const auto rel_graph = GetRelationGraph(etype);
    return rel_graph->OutEdges(0, vid);
  }

  /*! \brief Array overload: forwards to OutEdges of the relation graph for \c etype. */
  EdgeArray OutEdges(dgl_type_t etype, IdArray vids) const override {
    const auto rel_graph = GetRelationGraph(etype);
    return rel_graph->OutEdges(0, vids);
  }

  /*!
   * \brief Forwards to Edges of the relation graph for \c etype (its edge type 0).
   * \param etype Edge type id.
   * \param order Passed through unchanged; defaults to the empty string.
   */
  EdgeArray Edges(dgl_type_t etype, const std::string &order = "") const override {
    const auto rel_graph = GetRelationGraph(etype);
    return rel_graph->Edges(0, order);
  }

  /*! \brief Forwards to InDegree of the relation graph for \c etype (its edge type 0). */
  uint64_t InDegree(dgl_type_t etype, dgl_id_t vid) const override {
    const auto rel_graph = GetRelationGraph(etype);
    return rel_graph->InDegree(0, vid);
  }

  /*! \brief Forwards to InDegrees of the relation graph for \c etype (its edge type 0). */
  DegreeArray InDegrees(dgl_type_t etype, IdArray vids) const override {
    const auto rel_graph = GetRelationGraph(etype);
    return rel_graph->InDegrees(0, vids);
  }

  /*! \brief Forwards to OutDegree of the relation graph for \c etype (its edge type 0). */
  uint64_t OutDegree(dgl_type_t etype, dgl_id_t vid) const override {
    const auto rel_graph = GetRelationGraph(etype);
    return rel_graph->OutDegree(0, vid);
  }

  /*! \brief Forwards to OutDegrees of the relation graph for \c etype (its edge type 0). */
  DegreeArray OutDegrees(dgl_type_t etype, IdArray vids) const override {
    const auto rel_graph = GetRelationGraph(etype);
    return rel_graph->OutDegrees(0, vids);
  }

  /*! \brief Forwards to SuccVec of the relation graph for \c etype (its edge type 0). */
  DGLIdIters SuccVec(dgl_type_t etype, dgl_id_t vid) const override {
    const auto rel_graph = GetRelationGraph(etype);
    return rel_graph->SuccVec(0, vid);
  }

  /*! \brief Forwards to OutEdgeVec of the relation graph for \c etype (its edge type 0). */
  DGLIdIters OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const override {
    const auto rel_graph = GetRelationGraph(etype);
    return rel_graph->OutEdgeVec(0, vid);
  }

  /*! \brief Forwards to PredVec of the relation graph for \c etype (its edge type 0). */
  DGLIdIters PredVec(dgl_type_t etype, dgl_id_t vid) const override {
    const auto rel_graph = GetRelationGraph(etype);
    return rel_graph->PredVec(0, vid);
  }

  /*! \brief Forwards to InEdgeVec of the relation graph for \c etype (its edge type 0). */
  DGLIdIters InEdgeVec(dgl_type_t etype, dgl_id_t vid) const override {
    const auto rel_graph = GetRelationGraph(etype);
    return rel_graph->InEdgeVec(0, vid);
  }

  /*!
   * \brief Forwards to GetAdj of the relation graph for \c etype (its edge type 0).
   * \param etype Edge type id.
   * \param transpose Passed through unchanged.
   * \param fmt Passed through unchanged.
   */
  std::vector<IdArray> GetAdj(
      dgl_type_t etype, bool transpose, const std::string &fmt) const override {
    const auto rel_graph = GetRelationGraph(etype);
    return rel_graph->GetAdj(0, transpose, fmt);
  }

174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
  /*! \brief Forwards to GetCOOMatrix of the relation graph for \c etype (its edge type 0). */
  aten::COOMatrix GetCOOMatrix(dgl_type_t etype) const override {
    const auto rel_graph = GetRelationGraph(etype);
    return rel_graph->GetCOOMatrix(0);
  }

  /*!
   * \brief Forwards to GetCSCMatrix of the relation graph for \c etype (its edge type 0).
   * \note The declared return type is aten::CSRMatrix.
   */
  aten::CSRMatrix GetCSCMatrix(dgl_type_t etype) const override {
    const auto rel_graph = GetRelationGraph(etype);
    return rel_graph->GetCSCMatrix(0);
  }

  /*! \brief Forwards to GetCSRMatrix of the relation graph for \c etype (its edge type 0). */
  aten::CSRMatrix GetCSRMatrix(dgl_type_t etype) const override {
    const auto rel_graph = GetRelationGraph(etype);
    return rel_graph->GetCSRMatrix(0);
  }

  /*!
   * \brief Forwards to SelectFormat of the relation graph for \c etype (its edge type 0).
   * \param etype Edge type id.
   * \param preferred_format Passed through unchanged.
   */
  SparseFormat SelectFormat(
      dgl_type_t etype, SparseFormat preferred_format) const override {
    const auto rel_graph = GetRelationGraph(etype);
    return rel_graph->SelectFormat(0, preferred_format);
  }

190
191
192
193
194
195
196
197
198
199
  /*!
   * \brief Not supported on this class: raises a fatal error.
   * \return An empty string, only reached if LOG(FATAL) returns control.
   */
  std::string GetRestrictFormat() const override {
    LOG(FATAL) << "Not enabled for hetero graph (with multiple relations)";
    return "";
  }

  /*!
   * \brief Not supported on this class: raises a fatal error.
   * \return Zero, only reached if LOG(FATAL) returns control.
   */
  dgl_format_code_t GetFormatInUse() const override {
    LOG(FATAL) << "Not enabled for hetero graph (with multiple relations)";
    return static_cast<dgl_format_code_t>(0);
  }

200
201
202
203
204
  /*!
   * \brief Build a subgraph from vertex ids; the definition is out of line.
   * \param vids Vertex id arrays; presumably one array per vertex type -- confirm
   *        against the definition.
   */
  HeteroSubgraph VertexSubgraph(const std::vector<IdArray>& vids) const override;

  /*!
   * \brief Build a subgraph from edge ids; the definition is out of line.
   * \param eids Edge id arrays; presumably one array per edge type -- confirm against
   *        the definition.
   * \param preserve_nodes Defaults to false; presumably keeps all nodes of the parent
   *        graph when true -- confirm against the definition.
   */
  HeteroSubgraph EdgeSubgraph(
      const std::vector<IdArray>& eids, bool preserve_nodes = false) const override;

205
206
  /*!
   * \brief Return a graph for the given sparse format; the definition is out of line.
   * \param restrict_format Requested sparse format.
   */
  HeteroGraphPtr GetGraphInFormat(SparseFormat restrict_format) const override;

Minjie Wang's avatar
Minjie Wang committed
207
208
  /*!
   * \brief Flatten the given edge types; the definition is out of line.
   * \param etypes Edge types to flatten.
   * \return Pointer to the flattened heterograph.
   */
  FlattenedHeteroGraphPtr Flatten(const std::vector<dgl_type_t>& etypes) const override;

209
210
  /*! \brief Convert to a GraphPtr; the definition is out of line. */
  GraphPtr AsImmutableGraph() const override;

211
212
213
214
215
216
  /*!
   * \brief Load HeteroGraph from stream, using CSRMatrix
   * \param fs Stream to read from
   * \return Presumably true on success -- confirm against the definition
   */
  bool Load(dmlc::Stream* fs);

  /*!
   * \brief Save HeteroGraph to stream, using CSRMatrix
   * \param fs Stream to write to
   */
  void Save(dmlc::Stream* fs) const;

217
218
219
  /*!
   * \brief Convert the graph to use the given number of bits for storage
   * \param g Graph to convert
   * \param bits Number of bits for storage
   * \return The converted graph
   */
  static HeteroGraphPtr AsNumBits(HeteroGraphPtr g, uint8_t bits);

220
221
222
  /*!
   * \brief Copy the data to another context
   * \param g Graph whose data is copied
   * \param ctx Target context
   * \return Graph on the target context
   */
  static HeteroGraphPtr CopyTo(HeteroGraphPtr g, const DLContext& ctx);

223
224
225
226

  /*!
   * \brief Create a LineGraph of self
   * \param backtracking Forwarded to the out-of-line definition; see it for the exact
   *        meaning
   * \return The line graph
   */
  HeteroGraphPtr LineGraph(bool backtracking) const;

227
228
229
230
  /*! \brief Read-only reference to the stored relation graphs (no copy is made). */
  const std::vector<UnitGraphPtr>& relation_graphs() const { return relation_graphs_; }

231
 private:
232
233
234
235
  // Serializer is a friend so that it can use the private default constructor
  // to create an empty object.
  friend class Serializer;

  // Empty Constructor, only for serializer
236
  // Default-constructs the base; relation_graphs_ and num_verts_per_type_ start empty.
  HeteroGraph() : BaseHeteroGraph() {}
237

Minjie Wang's avatar
Minjie Wang committed
238
  /*! \brief A map from edge type to unit graph */
239
  // Indexed by edge type id (see the lookup in GetRelationGraph).
  std::vector<UnitGraphPtr> relation_graphs_;
240
241
242

  /*! \brief Number of vertices of each vertex type, indexed by vertex type id */
  std::vector<int64_t> num_verts_per_type_;
243
244
245
246
247
248
249
250
251

  /*! \brief Function template implementing the Flatten operation
  *
  * \tparam IdType Graph's index data type, can be int32_t or int64_t
  * \param etypes vector of etypes to be flattened
  * \return pointer to the FlattenedHeteroGraph
  */
  template <class IdType>
  FlattenedHeteroGraphPtr FlattenImpl(const std::vector<dgl_type_t>& etypes) const;
252
253
254
255
};

}  // namespace dgl

256
257
258
259
260
261

// Declares the dmlc has_saveload trait as true for dgl::HeteroGraph; presumably this
// lets dmlc streams serialize it through its Load/Save members -- confirm in dmlc-core.
namespace dmlc {
DMLC_DECLARE_TRAITS(has_saveload, dgl::HeteroGraph, true);
}  // namespace dmlc


262
#endif  // DGL_GRAPH_HETEROGRAPH_H_