// !!! This is a file automatically generated by hipify!!!
/**
 *  Copyright (c) 2019 by Contributors
 * @file graph/heterograph.h
 * @brief Heterograph: a read-only graph with multiple node and edge types,
 *        stored as one unit graph per edge type.
 */

#ifndef DGL_GRAPH_HETEROGRAPH_H_
#define DGL_GRAPH_HETEROGRAPH_H_

#include <dgl/base_heterograph.h>
#include <dgl/lazy.h>
#include <dgl/runtime/shared_mem.h>

#include <memory>
#include <set>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "unit_graph.h"
#include "shared_mem_manager.h"

namespace dgl {

27
/** @brief Heterograph */
28
29
class HeteroGraph : public BaseHeteroGraph {
 public:
30
  HeteroGraph(
31
      GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs,
32
      const std::vector<int64_t>& num_nodes_per_type = {});
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50

  HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const override {
    CHECK_LT(etype, meta_graph_->NumEdges()) << "Invalid edge type: " << etype;
    return relation_graphs_[etype];
  }

  void AddVertices(dgl_type_t vtype, uint64_t num_vertices) override {
    LOG(FATAL) << "Bipartite graph is not mutable.";
  }

  void AddEdge(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) override {
    LOG(FATAL) << "Bipartite graph is not mutable.";
  }

  void AddEdges(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) override {
    LOG(FATAL) << "Bipartite graph is not mutable.";
  }

51
  void Clear() override { LOG(FATAL) << "Bipartite graph is not mutable."; }
52

53
  DGLDataType DataType() const override {
54
55
56
    return relation_graphs_[0]->DataType();
  }

57
  DGLContext Context() const override { return relation_graphs_[0]->Context(); }
58

59
  bool IsPinned() const override { return relation_graphs_[0]->IsPinned(); }
60

61
  uint8_t NumBits() const override { return relation_graphs_[0]->NumBits(); }
62
63
64

  bool IsMultigraph() const override;

65
  bool IsReadonly() const override { return true; }
66
67
68
69
70
71

  uint64_t NumVertices(dgl_type_t vtype) const override {
    CHECK(meta_graph_->HasVertex(vtype)) << "Invalid vertex type: " << vtype;
    return num_verts_per_type_[vtype];
  }

72
73
74
75
  inline std::vector<int64_t> NumVerticesPerType() const override {
    return num_verts_per_type_;
  }

76
77
78
79
80
81
82
83
84
85
  uint64_t NumEdges(dgl_type_t etype) const override {
    return GetRelationGraph(etype)->NumEdges(0);
  }

  bool HasVertex(dgl_type_t vtype, dgl_id_t vid) const override {
    return vid < NumVertices(vtype);
  }

  BoolArray HasVertices(dgl_type_t vtype, IdArray vids) const override;

86
87
  bool HasEdgeBetween(
      dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {
88
89
90
    return GetRelationGraph(etype)->HasEdgeBetween(0, src, dst);
  }

91
92
  BoolArray HasEdgesBetween(
      dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const override {
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
    return GetRelationGraph(etype)->HasEdgesBetween(0, src_ids, dst_ids);
  }

  IdArray Predecessors(dgl_type_t etype, dgl_id_t dst) const override {
    return GetRelationGraph(etype)->Predecessors(0, dst);
  }

  IdArray Successors(dgl_type_t etype, dgl_id_t src) const override {
    return GetRelationGraph(etype)->Successors(0, src);
  }

  IdArray EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {
    return GetRelationGraph(etype)->EdgeId(0, src, dst);
  }

108
109
  EdgeArray EdgeIdsAll(
      dgl_type_t etype, IdArray src, IdArray dst) const override {
110
111
112
    return GetRelationGraph(etype)->EdgeIdsAll(0, src, dst);
  }

113
114
  IdArray EdgeIdsOne(
      dgl_type_t etype, IdArray src, IdArray dst) const override {
115
    return GetRelationGraph(etype)->EdgeIdsOne(0, src, dst);
116
117
  }

118
119
  std::pair<dgl_id_t, dgl_id_t> FindEdge(
      dgl_type_t etype, dgl_id_t eid) const override {
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
    return GetRelationGraph(etype)->FindEdge(0, eid);
  }

  EdgeArray FindEdges(dgl_type_t etype, IdArray eids) const override {
    return GetRelationGraph(etype)->FindEdges(0, eids);
  }

  EdgeArray InEdges(dgl_type_t etype, dgl_id_t vid) const override {
    return GetRelationGraph(etype)->InEdges(0, vid);
  }

  EdgeArray InEdges(dgl_type_t etype, IdArray vids) const override {
    return GetRelationGraph(etype)->InEdges(0, vids);
  }

  EdgeArray OutEdges(dgl_type_t etype, dgl_id_t vid) const override {
    return GetRelationGraph(etype)->OutEdges(0, vid);
  }

  EdgeArray OutEdges(dgl_type_t etype, IdArray vids) const override {
    return GetRelationGraph(etype)->OutEdges(0, vids);
  }

143
144
  EdgeArray Edges(
      dgl_type_t etype, const std::string& order = "") const override {
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
    return GetRelationGraph(etype)->Edges(0, order);
  }

  uint64_t InDegree(dgl_type_t etype, dgl_id_t vid) const override {
    return GetRelationGraph(etype)->InDegree(0, vid);
  }

  DegreeArray InDegrees(dgl_type_t etype, IdArray vids) const override {
    return GetRelationGraph(etype)->InDegrees(0, vids);
  }

  uint64_t OutDegree(dgl_type_t etype, dgl_id_t vid) const override {
    return GetRelationGraph(etype)->OutDegree(0, vid);
  }

  DegreeArray OutDegrees(dgl_type_t etype, IdArray vids) const override {
    return GetRelationGraph(etype)->OutDegrees(0, vids);
  }

  DGLIdIters SuccVec(dgl_type_t etype, dgl_id_t vid) const override {
    return GetRelationGraph(etype)->SuccVec(0, vid);
  }

  DGLIdIters OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const override {
    return GetRelationGraph(etype)->OutEdgeVec(0, vid);
  }

  DGLIdIters PredVec(dgl_type_t etype, dgl_id_t vid) const override {
    return GetRelationGraph(etype)->PredVec(0, vid);
  }

  DGLIdIters InEdgeVec(dgl_type_t etype, dgl_id_t vid) const override {
    return GetRelationGraph(etype)->InEdgeVec(0, vid);
  }

  std::vector<IdArray> GetAdj(
181
      dgl_type_t etype, bool transpose, const std::string& fmt) const override {
182
183
184
    return GetRelationGraph(etype)->GetAdj(0, transpose, fmt);
  }

185
186
187
188
189
190
191
192
193
194
195
196
  aten::COOMatrix GetCOOMatrix(dgl_type_t etype) const override {
    return GetRelationGraph(etype)->GetCOOMatrix(0);
  }

  aten::CSRMatrix GetCSCMatrix(dgl_type_t etype) const override {
    return GetRelationGraph(etype)->GetCSCMatrix(0);
  }

  aten::CSRMatrix GetCSRMatrix(dgl_type_t etype) const override {
    return GetRelationGraph(etype)->GetCSRMatrix(0);
  }

197
198
  SparseFormat SelectFormat(
      dgl_type_t etype, dgl_format_code_t preferred_formats) const override {
199
    return GetRelationGraph(etype)->SelectFormat(0, preferred_formats);
200
201
  }

202
203
  dgl_format_code_t GetAllowedFormats() const override {
    return GetRelationGraph(0)->GetAllowedFormats();
204
205
  }

206
207
  dgl_format_code_t GetCreatedFormats() const override {
    return GetRelationGraph(0)->GetCreatedFormats();
208
209
  }

210
211
  HeteroSubgraph VertexSubgraph(
      const std::vector<IdArray>& vids) const override;
212
213

  HeteroSubgraph EdgeSubgraph(
214
215
      const std::vector<IdArray>& eids,
      bool preserve_nodes = false) const override;
216

217
  HeteroGraphPtr GetGraphInFormat(dgl_format_code_t formats) const override;
218

219
220
  FlattenedHeteroGraphPtr Flatten(
      const std::vector<dgl_type_t>& etypes) const override;
Minjie Wang's avatar
Minjie Wang committed
221

222
223
  GraphPtr AsImmutableGraph() const override;

224
  /** @return Load HeteroGraph from stream, using CSRMatrix*/
225
226
  bool Load(dmlc::Stream* fs);

227
  /** @return Save HeteroGraph to stream, using CSRMatrix */
228
229
  void Save(dmlc::Stream* fs) const;

230
  /** @brief Convert the graph to use the given number of bits for storage */
231
232
  static HeteroGraphPtr AsNumBits(HeteroGraphPtr g, uint8_t bits);

233
  /** @brief Copy the data to another context */
234
  static HeteroGraphPtr CopyTo(HeteroGraphPtr g, const DGLContext& ctx);
235

236
  /**
237
238
239
240
241
242
   * @brief Pin all relation graphs of the current graph.
   * @note The graph will be pinned inplace. Behavior depends on the current
   * context, kDGLCPU: will be pinned; IsPinned: directly return; kDGLCUDA:
   * invalid, will throw an error. The context check is deferred to pinning the
   * NDArray.
   */
243
  void PinMemory_() override;
244

245
  /**
246
247
248
249
250
   * @brief Unpin all relation graphs of the current graph.
   * @note The graph will be unpinned inplace. Behavior depends on the current
   * context, IsPinned: will be unpinned; others: directly return. The context
   * check is deferred to unpinning the NDArray.
   */
251
252
  void UnpinMemory_();

253
254
255
256
257
258
259
260
261
262
  /**
   * @brief Copy the current graph to pinned memory managed by
   *     PyTorch CachingHostAllocator for each relation graph.
   * @note If any of the underlying relation graphs are already pinned, the
   *     function will utilize their existing copies. If all of them are
   *     pinned, the function will return the original input hetero-graph
   *     directly.
   */
  static HeteroGraphPtr PinMemory(HeteroGraphPtr g);

263
  /**
264
265
   * @brief Record stream for this graph.
   * @param stream The stream that is using the graph
266
267
268
   */
  void RecordStream(DGLStreamHandle stream) override;

269
270
  /**
   * @brief Copy the data to shared memory.
271
272
273
274
   *
   * Also save names of node types and edge types of the HeteroGraph object to
   * shared memory
   */
275
  static HeteroGraphPtr CopyToSharedMem(
276
277
278
279
      HeteroGraphPtr g, const std::string& name,
      const std::vector<std::string>& ntypes,
      const std::vector<std::string>& etypes,
      const std::set<std::string>& fmts);
280

281
282
283
284
  /**
   * @brief Create a heterograph from
   *
   * @return the HeteroGraphPtr, names of node types, names of edge types
285
286
287
288
   */
  static std::tuple<
      HeteroGraphPtr, std::vector<std::string>, std::vector<std::string>>
  CreateFromSharedMem(const std::string& name);
289

290
  /** @brief Creat a LineGraph of self */
291
292
  HeteroGraphPtr LineGraph(bool backtracking) const;

293
294
295
296
  const std::vector<UnitGraphPtr>& relation_graphs() const {
    return relation_graphs_;
  }

297
 private:
298
299
300
301
  // To create empty class
  friend class Serializer;

  // Empty Constructor, only for serializer
302
  HeteroGraph() : BaseHeteroGraph() {}
303

304
  /** @brief A map from edge type to unit graph */
305
  std::vector<UnitGraphPtr> relation_graphs_;
306

307
  /** @brief A map from vert type to the number of verts in the type */
308
  std::vector<int64_t> num_verts_per_type_;
309

310
  /** @brief The shared memory object for meta info*/
311
312
  std::shared_ptr<runtime::SharedMemory> shared_mem_;

313
314
315
  /**
   * @brief The name of the shared memory. Return empty string if it is not in
   * shared memory.
316
   */
317
318
  std::string SharedMemName() const;

319
320
  /**
   * @brief template class for Flatten operation
321
322
323
324
325
   *
   * @tparam IdType Graph's index data type, can be int32_t or int64_t
   * @param etypes vector of etypes to be falttened
   * @return pointer of FlattenedHeteroGraphh
   */
326
  template <class IdType>
327
328
  FlattenedHeteroGraphPtr FlattenImpl(
      const std::vector<dgl_type_t>& etypes) const;
329
330
331
332
};

}  // namespace dgl

333
334
335
336
namespace dmlc {
DMLC_DECLARE_TRAITS(has_saveload, dgl::HeteroGraph, true);
}  // namespace dmlc

337
#endif  // DGL_GRAPH_HETEROGRAPH_H_