union_partition.cc 6.83 KB
Newer Older
1
2
3
4
5
6
/*!
 *  Copyright (c) 2020 by Contributors
 * \file array/cpu/coo_union_partition.cc
 * \brief COO and CSR union and partition
 */
#include <dgl/array.h>
7

8
9
10
11
12
13
#include <vector>

namespace dgl {
namespace aten {
///////////////////////// COO Based Operations/////////////////////////
std::vector<COOMatrix> DisjointPartitionCooBySizes(
14
15
16
17
    const COOMatrix &coo, const uint64_t batch_size,
    const std::vector<uint64_t> &edge_cumsum,
    const std::vector<uint64_t> &src_vertex_cumsum,
    const std::vector<uint64_t> &dst_vertex_cumsum) {
18
19
20
21
22
23
24
  CHECK_EQ(edge_cumsum.size(), batch_size + 1);
  CHECK_EQ(src_vertex_cumsum.size(), batch_size + 1);
  CHECK_EQ(dst_vertex_cumsum.size(), batch_size + 1);
  std::vector<COOMatrix> ret;
  ret.resize(batch_size);

  for (size_t g = 0; g < batch_size; ++g) {
25
26
27
28
29
30
    IdArray result_src =
        IndexSelect(coo.row, edge_cumsum[g], edge_cumsum[g + 1]) -
        src_vertex_cumsum[g];
    IdArray result_dst =
        IndexSelect(coo.col, edge_cumsum[g], edge_cumsum[g + 1]) -
        dst_vertex_cumsum[g];
31
32
33
    IdArray result_data = NullArray();
    // has data index array
    if (COOHasData(coo)) {
34
35
      result_data = IndexSelect(coo.data, edge_cumsum[g], edge_cumsum[g + 1]) -
                    edge_cumsum[g];
36
37
38
    }

    COOMatrix sub_coo = COOMatrix(
39
40
41
        src_vertex_cumsum[g + 1] - src_vertex_cumsum[g],
        dst_vertex_cumsum[g + 1] - dst_vertex_cumsum[g], result_src, result_dst,
        result_data, coo.row_sorted, coo.col_sorted);
42
43
44
45
46
47
    ret[g] = sub_coo;
  }

  return ret;
}

48
/*!
 * \brief Extract one contiguous chunk of a COO matrix.
 *
 * The chunk covers edges [edge_range[0], edge_range[1]). Row ids are shifted
 * down by src_vertex_range[0], column ids by dst_vertex_range[0], and data ids
 * (when the input has a data array) by edge_range[0]. When the edge range is
 * empty the row and column arrays are NullArray instances created with the
 * dtype and context of coo.row.
 *
 * \param coo The matrix to slice.
 * \param edge_range Begin and end of the edge range (elements 0 and 1).
 * \param src_vertex_range Begin and end of the row range (elements 0 and 1).
 * \param dst_vertex_range Begin and end of the column range (elements 0 and 1).
 * \return The sliced COOMatrix, with the sorted flags of the input.
 */
COOMatrix COOSliceContiguousChunk(
    const COOMatrix &coo, const std::vector<uint64_t> &edge_range,
    const std::vector<uint64_t> &src_vertex_range,
    const std::vector<uint64_t> &dst_vertex_range) {
  const uint64_t edge_begin = edge_range[0];
  const uint64_t edge_end = edge_range[1];

  // Start from empty endpoint arrays; they are kept as-is for an edgeless
  // chunk.
  IdArray chunk_src = NullArray(coo.row->dtype, coo.row->ctx);
  IdArray chunk_dst = NullArray(coo.row->dtype, coo.row->ctx);
  const bool has_edges = edge_end != edge_begin;
  if (has_edges) {
    chunk_src =
        IndexSelect(coo.row, edge_begin, edge_end) - src_vertex_range[0];
    chunk_dst =
        IndexSelect(coo.col, edge_begin, edge_end) - dst_vertex_range[0];
  }

  // The data array is sliced whenever the input has one, even if the edge
  // range is empty.
  IdArray chunk_data = NullArray();
  if (COOHasData(coo)) {
    chunk_data = IndexSelect(coo.data, edge_begin, edge_end) - edge_begin;
  }

  return COOMatrix(
      src_vertex_range[1] - src_vertex_range[0],
      dst_vertex_range[1] - dst_vertex_range[0], chunk_src, chunk_dst,
      chunk_data, coo.row_sorted, coo.col_sorted);
}

77
///////////////////////// CSR Based Operations/////////////////////////
78
/*!
 * \brief Combine several CSR matrices into a single CSR matrix whose rows and
 *        columns are the concatenation of the inputs' rows and columns.
 *
 * For input i, column indices are shifted by the total number of columns of
 * inputs 0..i-1, and indptr values are shifted by the total number of stored
 * entries of inputs 0..i-1. The leading indptr entry of every input except the
 * first is dropped before concatenation. If at least one input has a data
 * array, the result has one too: inputs without data get a consecutive id
 * range starting at their edge offset, inputs with data get their ids shifted
 * by that offset.
 *
 * \param csrs The matrices to combine. All indptr arrays must share a dtype
 *        and all indices arrays must share a context.
 * \return The combined matrix; it is flagged sorted only if every input is.
 */
CSRMatrix DisjointUnionCsr(const std::vector<CSRMatrix> &csrs) {
  uint64_t src_offset = 0, dst_offset = 0;
  int64_t indices_offset = 0;
  bool has_data = false;
  bool sorted = true;

  // Validate the inputs and find out whether any of them carries a data index
  // array.
  for (size_t i = 0; i < csrs.size(); ++i) {
    CHECK_SAME_DTYPE(csrs[0].indptr, csrs[i].indptr);
    CHECK_SAME_CONTEXT(csrs[0].indices, csrs[i].indices);
    has_data |= CSRHasData(csrs[i]);
  }

  std::vector<IdArray> res_indptr;
  std::vector<IdArray> res_indices;
  std::vector<IdArray> res_data;
  res_indptr.resize(csrs.size());
  res_indices.resize(csrs.size());
  if (has_data) res_data.reserve(csrs.size());

  for (size_t i = 0; i < csrs.size(); ++i) {
    const aten::CSRMatrix &csr = csrs[i];
    sorted &= csr.sorted;
    IdArray indptr = csr.indptr + indices_offset;
    IdArray indices = csr.indices + dst_offset;
    // Skip the first indptr entry of every matrix but the first so that the
    // shared boundary value is not duplicated in the concatenation.
    if (i > 0) indptr = IndexSelect(indptr, 1, indptr->shape[0]);
    res_indptr[i] = indptr;
    res_indices[i] = indices;
    src_offset += csr.num_rows;
    dst_offset += csr.num_cols;

    // any one of input csr has data index array
    if (has_data) {
      IdArray edges_data;
      if (CSRHasData(csr) == false) {
        edges_data = Range(
            indices_offset, indices_offset + csr.indices->shape[0],
            csr.indices->dtype.bits, csr.indices->ctx);
      } else {
        edges_data = csr.data + indices_offset;
      }
      res_data.push_back(edges_data);
    }

    // Advance the edge offset for every input, not only when a data array is
    // present: the indptr shift above depends on it in both cases.
    indices_offset += csr.indices->shape[0];
  }

  IdArray result_indptr = Concat(res_indptr);
  IdArray result_indices = Concat(res_indices);
  IdArray result_data = has_data ? Concat(res_data) : NullArray();

  return CSRMatrix(
      src_offset, dst_offset, result_indptr, result_indices, result_data,
      sorted);
}

std::vector<CSRMatrix> DisjointPartitionCsrBySizes(
133
134
135
136
    const CSRMatrix &csr, const uint64_t batch_size,
    const std::vector<uint64_t> &edge_cumsum,
    const std::vector<uint64_t> &src_vertex_cumsum,
    const std::vector<uint64_t> &dst_vertex_cumsum) {
137
138
139
140
141
142
143
  CHECK_EQ(edge_cumsum.size(), batch_size + 1);
  CHECK_EQ(src_vertex_cumsum.size(), batch_size + 1);
  CHECK_EQ(dst_vertex_cumsum.size(), batch_size + 1);
  std::vector<CSRMatrix> ret;
  ret.resize(batch_size);

  for (size_t g = 0; g < batch_size; ++g) {
144
    uint64_t num_src = src_vertex_cumsum[g + 1] - src_vertex_cumsum[g];
145
146
    IdArray result_indptr;
    if (g == 0) {
147
148
      result_indptr =
          IndexSelect(csr.indptr, 0, src_vertex_cumsum[1] + 1) - edge_cumsum[0];
149
    } else {
150
151
152
153
      result_indptr =
          IndexSelect(
              csr.indptr, src_vertex_cumsum[g], src_vertex_cumsum[g + 1] + 1) -
          edge_cumsum[g];
154
155
    }

156
157
158
    IdArray result_indices =
        IndexSelect(csr.indices, edge_cumsum[g], edge_cumsum[g + 1]) -
        dst_vertex_cumsum[g];
159
160
161
162

    IdArray result_data = NullArray();
    // has data index array
    if (CSRHasData(csr)) {
163
164
      result_data = IndexSelect(csr.data, edge_cumsum[g], edge_cumsum[g + 1]) -
                    edge_cumsum[g];
165
166
167
    }

    CSRMatrix sub_csr = CSRMatrix(
168
169
        num_src, dst_vertex_cumsum[g + 1] - dst_vertex_cumsum[g], result_indptr,
        result_indices, result_data, csr.sorted);
170
171
172
173
174
175
    ret[g] = sub_csr;
  }

  return ret;
}

176
/*!
 * \brief Extract one contiguous chunk of a CSR matrix.
 *
 * When the edge range is non-empty, the chunk takes indptr entries
 * [src_vertex_range[0], src_vertex_range[1] + 1) shifted down by
 * edge_range[0], and entries [edge_range[0], edge_range[1]) of the indices
 * and (when present) data arrays, shifted down by dst_vertex_range[0] and
 * edge_range[0] respectively. When the edge range is empty, indptr is built
 * with Full(0, ...) for the chunk's row count plus one, indices is an empty
 * NullArray, and no data array is attached.
 *
 * \param csr The matrix to slice.
 * \param edge_range Begin and end of the edge range (elements 0 and 1).
 * \param src_vertex_range Begin and end of the row range (elements 0 and 1).
 * \param dst_vertex_range Begin and end of the column range (elements 0 and 1).
 * \return The sliced CSRMatrix, with the sorted flag of the input.
 */
CSRMatrix CSRSliceContiguousChunk(
    const CSRMatrix &csr, const std::vector<uint64_t> &edge_range,
    const std::vector<uint64_t> &src_vertex_range,
    const std::vector<uint64_t> &dst_vertex_range) {
  // Defaults describe a chunk without edges; they are overwritten below when
  // the edge range is non-empty.
  int64_t indptr_len = src_vertex_range[1] - src_vertex_range[0] + 1;
  IdArray chunk_indptr =
      Full(0, indptr_len, csr.indptr->dtype.bits, csr.indptr->ctx);
  IdArray chunk_indices = NullArray(csr.indptr->dtype, csr.indptr->ctx);
  IdArray chunk_data = NullArray();

  const uint64_t edge_begin = edge_range[0];
  const uint64_t edge_end = edge_range[1];
  const bool has_edges = edge_end != edge_begin;
  if (has_edges) {
    chunk_indptr =
        IndexSelect(csr.indptr, src_vertex_range[0], src_vertex_range[1] + 1) -
        edge_begin;
    chunk_indices = IndexSelect(csr.indices, edge_begin, edge_end) -
                    dst_vertex_range[0];
  }
  // The data array is only sliced for a non-empty chunk of an input that has
  // one.
  if (has_edges && CSRHasData(csr)) {
    chunk_data = IndexSelect(csr.data, edge_begin, edge_end) - edge_begin;
  }

  return CSRMatrix(
      src_vertex_range[1] - src_vertex_range[0],
      dst_vertex_range[1] - dst_vertex_range[0], chunk_indptr, chunk_indices,
      chunk_data, csr.sorted);
}

206
207
}  // namespace aten
}  // namespace dgl