graph.cc 22.2 KB
Newer Older
1
2
3
4
5
6
/*!
 *  Copyright (c) 2018 by Contributors
 * \file graph/graph.cc
 * \brief DGL graph index implementation
 */
#include <dgl/graph.h>
Da Zheng's avatar
Da Zheng committed
7
#include <dgl/sampler.h>
Minjie Wang's avatar
impl  
Minjie Wang committed
8
#include <algorithm>
Minjie Wang's avatar
Minjie Wang committed
9
#include <unordered_map>
10
11
#include <set>
#include <functional>
GaiYu0's avatar
GaiYu0 committed
12
13
#include <tuple>
#include "../c_api_common.h"
Minjie Wang's avatar
Minjie Wang committed
14

Minjie Wang's avatar
impl  
Minjie Wang committed
15
namespace dgl {
16

17
// Builds a graph with `num_nodes` vertices and one edge per COO entry:
// edge i connects src_ids[i] -> dst_ids[i] and receives edge id i.
// Both id arrays must pass aten::IsValidIdArray and have the same length, and
// every endpoint must be an existing vertex; otherwise a CHECK fails.
Graph::Graph(IdArray src_ids, IdArray dst_ids, size_t num_nodes) {
  CHECK(aten::IsValidIdArray(src_ids));
  CHECK(aten::IsValidIdArray(dst_ids));

  this->AddVertices(num_nodes);
  num_edges_ = src_ids->shape[0];
  CHECK(static_cast<int64_t>(num_edges_) == dst_ids->shape[0])
    << "vectors in COO must have the same length";

  const auto* const srcs = static_cast<dgl_id_t*>(src_ids->data);
  const auto* const dsts = static_cast<dgl_id_t*>(dst_ids->data);
  all_edges_src_.reserve(num_edges_);
  all_edges_dst_.reserve(num_edges_);

  for (uint64_t eid = 0; eid < num_edges_; ++eid) {
    const dgl_id_t u = srcs[eid];
    const dgl_id_t v = dsts[eid];
    CHECK(HasVertex(u) && HasVertex(v))
      << "Invalid vertices: src=" << u << " dst=" << v;

    // The forward list stores u -> v and the reverse list stores v <- u;
    // both record the same edge id.
    auto& fwd = adjlist_[u];
    fwd.succ.push_back(v);
    fwd.edge_id.push_back(eid);
    auto& bwd = reverse_adjlist_[v];
    bwd.succ.push_back(u);
    bwd.edge_id.push_back(eid);

    // Edge-id-indexed endpoint tables.
    all_edges_src_.push_back(u);
    all_edges_dst_.push_back(v);
  }
}

44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
// Returns true if at least two edges share the same (src, dst) endpoint pair,
// i.e. the graph contains parallel edges. Cost: O(E log E) for the sort.
bool Graph::IsMultigraph() const {
  // With zero or one edge there is nothing to collide with.
  if (num_edges_ <= 1) return false;

  std::vector<std::pair<int64_t, int64_t>> endpoints;
  endpoints.reserve(num_edges_);
  for (uint64_t e = 0; e < num_edges_; ++e) {
    endpoints.emplace_back(all_edges_src_[e], all_edges_dst_[e]);
  }

  // std::pair orders lexicographically (src first, then dst), so after
  // sorting any duplicate endpoint pairs sit next to each other.
  std::sort(endpoints.begin(), endpoints.end());
  return std::adjacent_find(endpoints.begin(), endpoints.end()) != endpoints.end();
}

Minjie Wang's avatar
impl  
Minjie Wang committed
71
72
73
void Graph::AddVertices(uint64_t num_vertices) {
  CHECK(!read_only_) << "Graph is read-only. Mutations are not allowed.";
  adjlist_.resize(adjlist_.size() + num_vertices);
Minjie Wang's avatar
Minjie Wang committed
74
  reverse_adjlist_.resize(reverse_adjlist_.size() + num_vertices);
Minjie Wang's avatar
impl  
Minjie Wang committed
75
76
77
78
79
}

// Appends the edge src -> dst. Its id is the edge count before the call.
// Fails a CHECK if the graph is read-only or either endpoint is not an
// existing vertex.
void Graph::AddEdge(dgl_id_t src, dgl_id_t dst) {
  CHECK(!read_only_) << "Graph is read-only. Mutations are not allowed.";
  CHECK(HasVertex(src) && HasVertex(dst))
    << "Invalid vertices: src=" << src << " dst=" << dst;

  const dgl_id_t new_eid = num_edges_;
  ++num_edges_;

  auto& out_entry = adjlist_[src];
  out_entry.succ.push_back(dst);
  out_entry.edge_id.push_back(new_eid);

  auto& in_entry = reverse_adjlist_[dst];
  in_entry.succ.push_back(src);
  in_entry.edge_id.push_back(new_eid);

  all_edges_src_.push_back(src);
  all_edges_dst_.push_back(dst);
}

// Adds edges described by two id arrays, with broadcasting:
//   * one-many  (src has 1 entry): src[0] -> dst[i] for every i
//   * many-one  (dst has 1 entry): src[i] -> dst[0] for every i
//   * many-many (equal lengths):   src[i] -> dst[i]
// Edges are appended in array order through AddEdge, so each one receives the
// next sequential edge id.
void Graph::AddEdges(IdArray src_ids, IdArray dst_ids) {
  CHECK(!read_only_) << "Graph is read-only. Mutations are not allowed.";
  CHECK(aten::IsValidIdArray(src_ids)) << "Invalid src id array.";
  CHECK(aten::IsValidIdArray(dst_ids)) << "Invalid dst id array.";
  const auto srclen = src_ids->shape[0];
  const auto dstlen = dst_ids->shape[0];
  const int64_t* src_data = static_cast<int64_t*>(src_ids->data);
  const int64_t* dst_data = static_cast<int64_t*>(dst_ids->data);

  // One-many takes precedence when both arrays hold a single entry.
  const bool one_src = (srclen == 1);
  const bool one_dst = !one_src && (dstlen == 1);
  if (!one_src && !one_dst) {
    CHECK(srclen == dstlen) << "Invalid src and dst id array.";
  }

  const int64_t count = one_src ? dstlen : srclen;
  for (int64_t k = 0; k < count; ++k) {
    AddEdge(src_data[one_src ? 0 : k], dst_data[one_dst ? 0 : k]);
  }
}

// Returns a 0/1 array (same dtype and context as `vids`) that marks which of
// the given ids name an existing vertex, i.e. lie in [0, NumVertices()).
BoolArray Graph::HasVertices(IdArray vids) const {
  CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
  const auto len = vids->shape[0];
  BoolArray rst = BoolArray::Empty({len}, vids->dtype, vids->ctx);
  const int64_t* vid_data = static_cast<int64_t*>(vids->data);
  int64_t* rst_data = static_cast<int64_t*>(rst->data);
  const int64_t nverts = NumVertices();
  std::transform(vid_data, vid_data + len, rst_data,
      [nverts] (int64_t v) -> int64_t { return (v >= 0 && v < nverts) ? 1 : 0; });
  return rst;
}

// O(E)
134
bool Graph::HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const {
Minjie Wang's avatar
impl  
Minjie Wang committed
135
136
137
138
139
  if (!HasVertex(src) || !HasVertex(dst)) return false;
  const auto& succ = adjlist_[src].succ;
  return std::find(succ.begin(), succ.end(), dst) != succ.end();
}

140
141
// Batched HasEdgeBetween with broadcasting:
//   * one-many  (src has 1 entry): result[i] = has_edge(src[0], dst[i])
//   * many-one  (dst has 1 entry): result[i] = has_edge(src[i], dst[0])
//   * many-many (equal lengths):   result[i] = has_edge(src[i], dst[i])
// The result has exactly one 0/1 entry per query pair and uses src_ids'
// dtype and context. Cost: O(sum of out-degrees of the queried sources).
BoolArray Graph::HasEdgesBetween(IdArray src_ids, IdArray dst_ids) const {
  CHECK(aten::IsValidIdArray(src_ids)) << "Invalid src id array.";
  CHECK(aten::IsValidIdArray(dst_ids)) << "Invalid dst id array.";
  const auto srclen = src_ids->shape[0];
  const auto dstlen = dst_ids->shape[0];

  // Size the result by the number of query pairs actually evaluated. Using
  // max(srclen, dstlen) here left a trailing uninitialized entry when one side
  // had length 1 and the other length 0.
  const bool one_src = (srclen == 1);
  const bool one_dst = !one_src && (dstlen == 1);
  if (!one_src && !one_dst) {
    CHECK(srclen == dstlen) << "Invalid src and dst id array.";
  }
  const int64_t rstlen = one_src ? dstlen : srclen;

  BoolArray rst = BoolArray::Empty({rstlen}, src_ids->dtype, src_ids->ctx);
  int64_t* rst_data = static_cast<int64_t*>(rst->data);
  const int64_t* src_data = static_cast<int64_t*>(src_ids->data);
  const int64_t* dst_data = static_cast<int64_t*>(dst_ids->data);
  for (int64_t i = 0; i < rstlen; ++i) {
    const int64_t u = src_data[one_src ? 0 : i];
    const int64_t v = dst_data[one_dst ? 0 : i];
    rst_data[i] = HasEdgeBetween(u, v) ? 1 : 0;
  }
  return rst;
}

// The data is copy-out; support zero-copy?
Minjie Wang's avatar
Minjie Wang committed
172
IdArray Graph::Predecessors(dgl_id_t vid, uint64_t radius) const {
Minjie Wang's avatar
impl  
Minjie Wang committed
173
  CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
Minjie Wang's avatar
Minjie Wang committed
174
  CHECK(radius >= 1) << "invalid radius: " << radius;
175
176
177
178
179
180
  std::set<dgl_id_t> vset;

  for (auto& it : reverse_adjlist_[vid].succ)
    vset.insert(it);

  const int64_t len = vset.size();
Minjie Wang's avatar
impl  
Minjie Wang committed
181
182
  IdArray rst = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
  int64_t* rst_data = static_cast<int64_t*>(rst->data);
183
184

  std::copy(vset.begin(), vset.end(), rst_data);
Minjie Wang's avatar
impl  
Minjie Wang committed
185
186
187
188
  return rst;
}

// The data is copy-out; support zero-copy?
Minjie Wang's avatar
Minjie Wang committed
189
IdArray Graph::Successors(dgl_id_t vid, uint64_t radius) const {
Minjie Wang's avatar
impl  
Minjie Wang committed
190
  CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
Minjie Wang's avatar
Minjie Wang committed
191
  CHECK(radius >= 1) << "invalid radius: " << radius;
192
193
194
195
196
197
  std::set<dgl_id_t> vset;

  for (auto& it : adjlist_[vid].succ)
    vset.insert(it);

  const int64_t len = vset.size();
Minjie Wang's avatar
impl  
Minjie Wang committed
198
199
  IdArray rst = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
  int64_t* rst_data = static_cast<int64_t*>(rst->data);
200
201

  std::copy(vset.begin(), vset.end(), rst_data);
Minjie Wang's avatar
impl  
Minjie Wang committed
202
203
204
205
  return rst;
}

// O(E)
206
207
208
IdArray Graph::EdgeId(dgl_id_t src, dgl_id_t dst) const {
  CHECK(HasVertex(src) && HasVertex(dst)) << "invalid edge: " << src << " -> " << dst;

Minjie Wang's avatar
impl  
Minjie Wang committed
209
  const auto& succ = adjlist_[src].succ;
210
211
  std::vector<dgl_id_t> edgelist;

Minjie Wang's avatar
impl  
Minjie Wang committed
212
  for (size_t i = 0; i < succ.size(); ++i) {
213
214
    if (succ[i] == dst)
      edgelist.push_back(adjlist_[src].edge_id[i]);
Minjie Wang's avatar
impl  
Minjie Wang committed
215
  }
216
217
218
219
220
221
222
223
224
225

  // FIXME: signed?  Also it seems that we are using int64_t everywhere...
  const int64_t len = edgelist.size();
  IdArray rst = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
  // FIXME: signed?
  int64_t* rst_data = static_cast<int64_t*>(rst->data);

  std::copy(edgelist.begin(), edgelist.end(), rst_data);

  return rst;
Minjie Wang's avatar
impl  
Minjie Wang committed
226
227
228
}

// O(E*k) pretty slow
229
EdgeArray Graph::EdgeIds(IdArray src_ids, IdArray dst_ids) const {
230
231
  CHECK(aten::IsValidIdArray(src_ids)) << "Invalid src id array.";
  CHECK(aten::IsValidIdArray(dst_ids)) << "Invalid dst id array.";
Minjie Wang's avatar
impl  
Minjie Wang committed
232
  const auto srclen = src_ids->shape[0];
Minjie Wang's avatar
Minjie Wang committed
233
  const auto dstlen = dst_ids->shape[0];
234
235
236
237
238
239
240
  int64_t i, j;

  CHECK((srclen == dstlen) || (srclen == 1) || (dstlen == 1))
    << "Invalid src and dst id array.";

  const int64_t src_stride = (srclen == 1 && dstlen != 1) ? 0 : 1;
  const int64_t dst_stride = (dstlen == 1 && srclen != 1) ? 0 : 1;
Minjie Wang's avatar
impl  
Minjie Wang committed
241
242
  const int64_t* src_data = static_cast<int64_t*>(src_ids->data);
  const int64_t* dst_data = static_cast<int64_t*>(dst_ids->data);
243
244
245
246
247

  std::vector<dgl_id_t> src, dst, eid;

  for (i = 0, j = 0; i < srclen && j < dstlen; i += src_stride, j += dst_stride) {
    const dgl_id_t src_id = src_data[i], dst_id = dst_data[j];
248
249
    CHECK(HasVertex(src_id) && HasVertex(dst_id)) <<
        "invalid edge: " << src_id << " -> " << dst_id;
250
251
252
    const auto& succ = adjlist_[src_id].succ;
    for (size_t k = 0; k < succ.size(); ++k) {
      if (succ[k] == dst_id) {
253
254
255
        src.push_back(src_id);
        dst.push_back(dst_id);
        eid.push_back(adjlist_[src_id].edge_id[k]);
256
      }
Minjie Wang's avatar
impl  
Minjie Wang committed
257
258
    }
  }
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274

  int64_t rstlen = src.size();
  IdArray rst_src = IdArray::Empty({rstlen}, src_ids->dtype, src_ids->ctx);
  IdArray rst_dst = IdArray::Empty({rstlen}, src_ids->dtype, src_ids->ctx);
  IdArray rst_eid = IdArray::Empty({rstlen}, src_ids->dtype, src_ids->ctx);
  int64_t* rst_src_data = static_cast<int64_t*>(rst_src->data);
  int64_t* rst_dst_data = static_cast<int64_t*>(rst_dst->data);
  int64_t* rst_eid_data = static_cast<int64_t*>(rst_eid->data);

  std::copy(src.begin(), src.end(), rst_src_data);
  std::copy(dst.begin(), dst.end(), rst_dst_data);
  std::copy(eid.begin(), eid.end(), rst_eid_data);

  return EdgeArray{rst_src, rst_dst, rst_eid};
}

275
// Returns the (src, dst, eid) triple for each requested edge id, in input
// order. The output arrays use `eids`' dtype and context. An id that is not
// below NumEdges() aborts via LOG(FATAL).
EdgeArray Graph::FindEdges(IdArray eids) const {
  CHECK(aten::IsValidIdArray(eids)) << "Invalid edge id array";
  const int64_t len = eids->shape[0];

  IdArray rst_src = IdArray::Empty({len}, eids->dtype, eids->ctx);
  IdArray rst_dst = IdArray::Empty({len}, eids->dtype, eids->ctx);
  IdArray rst_eid = IdArray::Empty({len}, eids->dtype, eids->ctx);
  const int64_t* in = static_cast<int64_t*>(eids->data);
  int64_t* out_src = static_cast<int64_t*>(rst_src->data);
  int64_t* out_dst = static_cast<int64_t*>(rst_dst->data);
  int64_t* out_eid = static_cast<int64_t*>(rst_eid->data);

  for (int64_t i = 0; i < len; ++i) {
    const dgl_id_t eid = in[i];
    if (eid >= num_edges_) {
      LOG(FATAL) << "invalid edge id:" << eid;
    }
    out_src[i] = all_edges_src_[eid];
    out_dst[i] = all_edges_dst_[eid];
    out_eid[i] = eid;
  }

  return EdgeArray{rst_src, rst_dst, rst_eid};
}

// O(E)
301
EdgeArray Graph::InEdges(dgl_id_t vid) const {
302
  CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
Minjie Wang's avatar
Minjie Wang committed
303
  const int64_t len = reverse_adjlist_[vid].succ.size();
304
305
306
307
  IdArray src = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
  IdArray dst = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
  IdArray eid = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
  int64_t* src_data = static_cast<int64_t*>(src->data);
Minjie Wang's avatar
impl  
Minjie Wang committed
308
  int64_t* dst_data = static_cast<int64_t*>(dst->data);
309
310
  int64_t* eid_data = static_cast<int64_t*>(eid->data);
  for (int64_t i = 0; i < len; ++i) {
Minjie Wang's avatar
Minjie Wang committed
311
312
    src_data[i] = reverse_adjlist_[vid].succ[i];
    eid_data[i] = reverse_adjlist_[vid].edge_id[i];
313
314
315
  }
  std::fill(dst_data, dst_data + len, vid);
  return EdgeArray{src, dst, eid};
Minjie Wang's avatar
impl  
Minjie Wang committed
316
317
318
}

// O(E)
319
EdgeArray Graph::InEdges(IdArray vids) const {
320
  CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
Minjie Wang's avatar
impl  
Minjie Wang committed
321
322
323
324
325
  const auto len = vids->shape[0];
  const int64_t* vid_data = static_cast<int64_t*>(vids->data);
  int64_t rstlen = 0;
  for (int64_t i = 0; i < len; ++i) {
    CHECK(HasVertex(vid_data[i])) << "Invalid vertex: " << vid_data[i];
Minjie Wang's avatar
Minjie Wang committed
326
    rstlen += reverse_adjlist_[vid_data[i]].succ.size();
Minjie Wang's avatar
impl  
Minjie Wang committed
327
328
329
  }
  IdArray src = IdArray::Empty({rstlen}, vids->dtype, vids->ctx);
  IdArray dst = IdArray::Empty({rstlen}, vids->dtype, vids->ctx);
330
  IdArray eid = IdArray::Empty({rstlen}, vids->dtype, vids->ctx);
Minjie Wang's avatar
impl  
Minjie Wang committed
331
332
  int64_t* src_ptr = static_cast<int64_t*>(src->data);
  int64_t* dst_ptr = static_cast<int64_t*>(dst->data);
333
  int64_t* eid_ptr = static_cast<int64_t*>(eid->data);
Minjie Wang's avatar
impl  
Minjie Wang committed
334
  for (int64_t i = 0; i < len; ++i) {
Minjie Wang's avatar
Minjie Wang committed
335
336
    const auto& pred = reverse_adjlist_[vid_data[i]].succ;
    const auto& eids = reverse_adjlist_[vid_data[i]].edge_id;
Minjie Wang's avatar
impl  
Minjie Wang committed
337
338
339
    for (size_t j = 0; j < pred.size(); ++j) {
      *(src_ptr++) = pred[j];
      *(dst_ptr++) = vid_data[i];
340
      *(eid_ptr++) = eids[j];
Minjie Wang's avatar
impl  
Minjie Wang committed
341
342
    }
  }
343
  return EdgeArray{src, dst, eid};
Minjie Wang's avatar
impl  
Minjie Wang committed
344
345
346
}

// O(E)
347
EdgeArray Graph::OutEdges(dgl_id_t vid) const {
348
349
350
351
352
  CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
  const int64_t len = adjlist_[vid].succ.size();
  IdArray src = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
  IdArray dst = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
  IdArray eid = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
Minjie Wang's avatar
impl  
Minjie Wang committed
353
  int64_t* src_data = static_cast<int64_t*>(src->data);
354
355
356
357
  int64_t* dst_data = static_cast<int64_t*>(dst->data);
  int64_t* eid_data = static_cast<int64_t*>(eid->data);
  for (int64_t i = 0; i < len; ++i) {
    dst_data[i] = adjlist_[vid].succ[i];
Minjie Wang's avatar
Minjie Wang committed
358
    eid_data[i] = adjlist_[vid].edge_id[i];
359
360
361
  }
  std::fill(src_data, src_data + len, vid);
  return EdgeArray{src, dst, eid};
Minjie Wang's avatar
impl  
Minjie Wang committed
362
363
364
}

// O(E)
365
EdgeArray Graph::OutEdges(IdArray vids) const {
366
  CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
Minjie Wang's avatar
impl  
Minjie Wang committed
367
368
369
370
371
372
373
374
375
  const auto len = vids->shape[0];
  const int64_t* vid_data = static_cast<int64_t*>(vids->data);
  int64_t rstlen = 0;
  for (int64_t i = 0; i < len; ++i) {
    CHECK(HasVertex(vid_data[i])) << "Invalid vertex: " << vid_data[i];
    rstlen += adjlist_[vid_data[i]].succ.size();
  }
  IdArray src = IdArray::Empty({rstlen}, vids->dtype, vids->ctx);
  IdArray dst = IdArray::Empty({rstlen}, vids->dtype, vids->ctx);
376
  IdArray eid = IdArray::Empty({rstlen}, vids->dtype, vids->ctx);
Minjie Wang's avatar
impl  
Minjie Wang committed
377
378
  int64_t* src_ptr = static_cast<int64_t*>(src->data);
  int64_t* dst_ptr = static_cast<int64_t*>(dst->data);
379
  int64_t* eid_ptr = static_cast<int64_t*>(eid->data);
Minjie Wang's avatar
impl  
Minjie Wang committed
380
381
  for (int64_t i = 0; i < len; ++i) {
    const auto& succ = adjlist_[vid_data[i]].succ;
Minjie Wang's avatar
Minjie Wang committed
382
    const auto& eids = adjlist_[vid_data[i]].edge_id;
Minjie Wang's avatar
impl  
Minjie Wang committed
383
384
385
    for (size_t j = 0; j < succ.size(); ++j) {
      *(src_ptr++) = vid_data[i];
      *(dst_ptr++) = succ[j];
386
      *(eid_ptr++) = eids[j];
Minjie Wang's avatar
impl  
Minjie Wang committed
387
388
    }
  }
389
  return EdgeArray{src, dst, eid};
Minjie Wang's avatar
impl  
Minjie Wang committed
390
391
}

Minjie Wang's avatar
Minjie Wang committed
392
// O(E*log(E)) if sort is required; otherwise, O(E)
393
EdgeArray Graph::Edges(const std::string &order) const {
Minjie Wang's avatar
impl  
Minjie Wang committed
394
395
396
  const int64_t len = num_edges_;
  IdArray src = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
  IdArray dst = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
397
  IdArray eid = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
Minjie Wang's avatar
Minjie Wang committed
398

399
  if (order == "srcdst") {
Minjie Wang's avatar
Minjie Wang committed
400
401
402
    typedef std::tuple<int64_t, int64_t, int64_t> Tuple;
    std::vector<Tuple> tuples;
    tuples.reserve(len);
Minjie Wang's avatar
Minjie Wang committed
403
404
    for (uint64_t eid = 0; eid < num_edges_; ++eid) {
      tuples.emplace_back(all_edges_src_[eid], all_edges_dst_[eid], eid);
Minjie Wang's avatar
Minjie Wang committed
405
    }
Minjie Wang's avatar
Minjie Wang committed
406
    // sort according to src and dst ids
Minjie Wang's avatar
Minjie Wang committed
407
408
    std::sort(tuples.begin(), tuples.end(),
        [] (const Tuple& t1, const Tuple& t2) {
Minjie Wang's avatar
Minjie Wang committed
409
410
          return std::get<0>(t1) < std::get<0>(t2)
            || (std::get<0>(t1) == std::get<0>(t2) && std::get<1>(t1) < std::get<1>(t2));
Minjie Wang's avatar
Minjie Wang committed
411
        });
412

Minjie Wang's avatar
Minjie Wang committed
413
414
415
    // make return arrays
    int64_t* src_ptr = static_cast<int64_t*>(src->data);
    int64_t* dst_ptr = static_cast<int64_t*>(dst->data);
416
    int64_t* eid_ptr = static_cast<int64_t*>(eid->data);
Minjie Wang's avatar
Minjie Wang committed
417
    for (size_t i = 0; i < tuples.size(); ++i) {
Minjie Wang's avatar
Minjie Wang committed
418
419
      src_ptr[i] = std::get<0>(tuples[i]);
      dst_ptr[i] = std::get<1>(tuples[i]);
420
      eid_ptr[i] = std::get<2>(tuples[i]);
Minjie Wang's avatar
Minjie Wang committed
421
422
423
424
    }
  } else {
    int64_t* src_ptr = static_cast<int64_t*>(src->data);
    int64_t* dst_ptr = static_cast<int64_t*>(dst->data);
425
    int64_t* eid_ptr = static_cast<int64_t*>(eid->data);
Minjie Wang's avatar
Minjie Wang committed
426
427
428
429
    std::copy(all_edges_src_.begin(), all_edges_src_.end(), src_ptr);
    std::copy(all_edges_dst_.begin(), all_edges_dst_.end(), dst_ptr);
    for (uint64_t eid = 0; eid < num_edges_; ++eid) {
      eid_ptr[eid] = eid;
Minjie Wang's avatar
Minjie Wang committed
430
    }
Minjie Wang's avatar
impl  
Minjie Wang committed
431
432
  }

433
  return EdgeArray{src, dst, eid};
Minjie Wang's avatar
impl  
Minjie Wang committed
434
435
436
437
}

// Returns the in-degree of each vertex in `vids`: one entry per id, in input
// order, with parallel edges counted individually. The output uses `vids`'
// dtype and context. Fails a CHECK on any id that is not a vertex.
// Cost: O(len(vids)).
DegreeArray Graph::InDegrees(IdArray vids) const {
  CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
  const auto len = vids->shape[0];
  DegreeArray rst = DegreeArray::Empty({len}, vids->dtype, vids->ctx);
  const int64_t* in = static_cast<int64_t*>(vids->data);
  int64_t* out = static_cast<int64_t*>(rst->data);
  for (const int64_t* it = in; it != in + len; ++it, ++out) {
    const int64_t vid = *it;
    CHECK(HasVertex(vid)) << "Invalid vertex: " << vid;
    *out = reverse_adjlist_[vid].succ.size();
  }
  return rst;
}

// Returns the out-degree of each vertex in `vids`: one entry per id, in input
// order, with parallel edges counted individually. The output uses `vids`'
// dtype and context. Fails a CHECK on any id that is not a vertex.
// Cost: O(len(vids)).
DegreeArray Graph::OutDegrees(IdArray vids) const {
  CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
  const auto len = vids->shape[0];
  DegreeArray rst = DegreeArray::Empty({len}, vids->dtype, vids->ctx);
  const int64_t* in = static_cast<int64_t*>(vids->data);
  int64_t* out = static_cast<int64_t*>(rst->data);
  for (const int64_t* it = in; it != in + len; ++it, ++out) {
    const int64_t vid = *it;
    CHECK(HasVertex(vid)) << "Invalid vertex: " << vid;
    *out = adjlist_[vid].succ.size();
  }
  return rst;
}

Minjie Wang's avatar
Minjie Wang committed
466
// Builds the subgraph induced by the vertex ids in `vids`.
// New vertex i corresponds to vids[i]. Every edge whose two endpoints are both
// in `vids` is copied, and its original id is recorded in induced_edges,
// ordered by new source vertex and then by the source's successor order.
// induced_edges uses `vids`' dtype and context. Fails a CHECK if an id is not
// an existing vertex or appears more than once.
Subgraph Graph::VertexSubgraph(IdArray vids) const {
  CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
  const auto len = vids->shape[0];
  const int64_t* vid_data = static_cast<int64_t*>(vids->data);

  // Map each original vertex id to its position in `vids` (the new id).
  // Validate here: the adjlist_ lookups below index by these ids directly, and
  // a duplicated id would make the old->new mapping ambiguous, producing
  // duplicated edges and orphaned vertices.
  std::unordered_map<dgl_id_t, dgl_id_t> oldv2newv;
  oldv2newv.reserve(len);
  for (int64_t i = 0; i < len; ++i) {
    const dgl_id_t oldvid = vid_data[i];
    CHECK(HasVertex(oldvid)) << "Invalid vertex: " << vid_data[i];
    const bool inserted =
        oldv2newv.insert(std::make_pair(oldvid, static_cast<dgl_id_t>(i))).second;
    CHECK(inserted) << "Duplicate vertex in vertex subgraph: " << vid_data[i];
  }

  Subgraph rst;
  rst.graph = std::make_shared<Graph>();
  rst.induced_vertices = vids;
  rst.graph->AddVertices(len);

  std::vector<dgl_id_t> edges;
  for (int64_t i = 0; i < len; ++i) {
    const dgl_id_t newvid = i;
    const auto& out = adjlist_[vid_data[i]];
    for (size_t j = 0; j < out.succ.size(); ++j) {
      const auto it = oldv2newv.find(out.succ[j]);
      if (it == oldv2newv.end()) continue;  // successor lies outside the subgraph
      edges.push_back(out.edge_id[j]);
      rst.graph->AddEdge(newvid, it->second);
    }
  }

  rst.induced_edges = IdArray::Empty({static_cast<int64_t>(edges.size())}, vids->dtype, vids->ctx);
  std::copy(edges.begin(), edges.end(), static_cast<int64_t*>(rst.induced_edges->data));
  return rst;
}
Minjie Wang's avatar
Minjie Wang committed
495

496
// Builds the subgraph formed by the edge ids in `eids`. New edge i
// corresponds to eids[i], and induced_edges is `eids` itself.
//   preserve_nodes == false: only endpoints of the selected edges are kept.
//       New vertex ids follow first appearance (src before dst, in edge
//       order), and induced_vertices lists the original ids.
//   preserve_nodes == true: all NumVertices() vertices are kept with their
//       original ids, and induced_vertices is 0..NumVertices()-1.
// induced_vertices uses `eids`' dtype and context. Fails a CHECK on any edge
// id that is negative or not below NumEdges().
Subgraph Graph::EdgeSubgraph(IdArray eids, bool preserve_nodes) const {
  CHECK(aten::IsValidIdArray(eids)) << "Invalid edge id array.";
  const auto len = eids->shape[0];
  const int64_t* eid_data = static_cast<int64_t*>(eids->data);

  // Validate up front: the lookups below index all_edges_src_/all_edges_dst_
  // directly and would otherwise read out of bounds.
  for (int64_t i = 0; i < len; ++i) {
    CHECK(eid_data[i] >= 0 && static_cast<uint64_t>(eid_data[i]) < num_edges_)
      << "Invalid edge id: " << eid_data[i];
  }

  Subgraph rst;
  rst.graph = std::make_shared<Graph>();
  rst.induced_edges = eids;
  std::vector<dgl_id_t> nodes;

  if (!preserve_nodes) {
    // Relabel vertices compactly, in order of first appearance.
    std::unordered_map<dgl_id_t, dgl_id_t> oldv2newv;
    for (int64_t i = 0; i < len; ++i) {
      const dgl_id_t src_id = all_edges_src_[eid_data[i]];
      const dgl_id_t dst_id = all_edges_dst_[eid_data[i]];
      if (oldv2newv.insert(std::make_pair(src_id, static_cast<dgl_id_t>(oldv2newv.size()))).second)
        nodes.push_back(src_id);
      if (oldv2newv.insert(std::make_pair(dst_id, static_cast<dgl_id_t>(oldv2newv.size()))).second)
        nodes.push_back(dst_id);
    }

    rst.graph->AddVertices(nodes.size());
    for (int64_t i = 0; i < len; ++i) {
      const dgl_id_t src_id = all_edges_src_[eid_data[i]];
      const dgl_id_t dst_id = all_edges_dst_[eid_data[i]];
      rst.graph->AddEdge(oldv2newv.at(src_id), oldv2newv.at(dst_id));
    }
  } else {
    // Keep every vertex under its original id.
    rst.graph->AddVertices(NumVertices());
    for (int64_t i = 0; i < len; ++i) {
      rst.graph->AddEdge(all_edges_src_[eid_data[i]], all_edges_dst_[eid_data[i]]);
    }
    for (uint64_t v = 0; v < NumVertices(); ++v) {
      nodes.push_back(v);
    }
  }

  rst.induced_vertices = IdArray::Empty(
      {static_cast<int64_t>(nodes.size())}, eids->dtype, eids->ctx);
  std::copy(nodes.begin(), nodes.end(), static_cast<int64_t*>(rst.induced_vertices->data));
  return rst;
}

550
std::vector<IdArray> Graph::GetAdj(bool transpose, const std::string &fmt) const {
551
552
  uint64_t num_edges = NumEdges();
  uint64_t num_nodes = NumVertices();
553
  if (fmt == "coo") {
554
555
556
557
    IdArray idx = IdArray::Empty(
        {2 * static_cast<int64_t>(num_edges)},
        DLDataType{kDLInt, 64, 1},
        DLContext{kDLCPU, 0});
558
559
560
561
562
563
564
565
    int64_t *idx_data = static_cast<int64_t*>(idx->data);
    if (transpose) {
      std::copy(all_edges_src_.begin(), all_edges_src_.end(), idx_data);
      std::copy(all_edges_dst_.begin(), all_edges_dst_.end(), idx_data + num_edges);
    } else {
      std::copy(all_edges_dst_.begin(), all_edges_dst_.end(), idx_data);
      std::copy(all_edges_src_.begin(), all_edges_src_.end(), idx_data + num_edges);
    }
566
567
568
569
    IdArray eid = IdArray::Empty(
        {static_cast<int64_t>(num_edges)},
        DLDataType{kDLInt, 64, 1},
        DLContext{kDLCPU, 0});
570
571
572
573
574
575
    int64_t *eid_data = static_cast<int64_t*>(eid->data);
    for (uint64_t eid = 0; eid < num_edges; ++eid) {
      eid_data[eid] = eid;
    }
    return std::vector<IdArray>{idx, eid};
  } else if (fmt == "csr") {
576
577
578
579
580
581
582
583
584
585
586
587
    IdArray indptr = IdArray::Empty(
        {static_cast<int64_t>(num_nodes) + 1},
        DLDataType{kDLInt, 64, 1},
        DLContext{kDLCPU, 0});
    IdArray indices = IdArray::Empty(
        {static_cast<int64_t>(num_edges)},
        DLDataType{kDLInt, 64, 1},
        DLContext{kDLCPU, 0});
    IdArray eid = IdArray::Empty(
        {static_cast<int64_t>(num_edges)},
        DLDataType{kDLInt, 64, 1},
        DLContext{kDLCPU, 0});
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
    int64_t *indptr_data = static_cast<int64_t*>(indptr->data);
    int64_t *indices_data = static_cast<int64_t*>(indices->data);
    int64_t *eid_data = static_cast<int64_t*>(eid->data);
    const AdjacencyList *adjlist;
    if (transpose) {
      // Out-edges.
      adjlist = &adjlist_;
    } else {
      // In-edges.
      adjlist = &reverse_adjlist_;
    }
    indptr_data[0] = 0;
    for (size_t i = 0; i < adjlist->size(); i++) {
      indptr_data[i + 1] = indptr_data[i] + adjlist->at(i).succ.size();
      std::copy(adjlist->at(i).succ.begin(), adjlist->at(i).succ.end(),
                indices_data + indptr_data[i]);
      std::copy(adjlist->at(i).edge_id.begin(), adjlist->at(i).edge_id.end(),
                eid_data + indptr_data[i]);
    }
    return std::vector<IdArray>{indptr, indices, eid};
  } else {
    LOG(FATAL) << "unsupported format";
    return std::vector<IdArray>();
  }
}

Minjie Wang's avatar
impl  
Minjie Wang committed
614
}  // namespace dgl