"...en/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "0744378dc0c34d53ec4b1155f2cf87364a0754b1"
rowwise_sampling.cc 11.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
/*!
 *  Copyright (c) 2020 by Contributors
 * \file array/cpu/rowwise_sampling.cc
 * \brief rowwise sampling
 */
#include <dgl/random.h>
#include <numeric>
#include "./rowwise_pick.h"

namespace dgl {
namespace aten {
namespace impl {
namespace {
// Gathers `len` entries of `array` into a freshly allocated FloatArray that
// shares the dtype and context of `array`.
//  - idx_data != nullptr: equivalent to the numpy expression
//    array[idx[off:off + len]] (indirect gather through idx_data).
//  - idx_data == nullptr: equivalent to array[off:off + len].
template <typename IdxType, typename FloatType>
inline FloatArray DoubleSlice(FloatArray array, const IdxType* idx_data,
                              IdxType off, IdxType len) {
  const FloatType* src = static_cast<FloatType*>(array->data);
  FloatArray ret = FloatArray::Empty({len}, array->dtype, array->ctx);
  FloatType* dst = static_cast<FloatType*>(ret->data);
  if (idx_data) {
    // Indirect path: positions are looked up through the index array.
    for (int64_t k = 0; k < len; ++k)
      dst[k] = src[idx_data[off + k]];
  } else {
    // Direct path: the slice is contiguous in `array`.
    for (int64_t k = 0; k < len; ++k)
      dst[k] = src[off + k];
  }
  return ret;
}

// Builds the per-row pick function used for weighted sampling.
//
// The returned callable receives a row's slice [off, off + len) and:
//  1. gathers the weights of that slice from `prob` via DoubleSlice
//     (through `data` when it is non-null, directly otherwise);
//  2. asks the thread-local RandomEngine to choose `num_samples` entries
//     using those weights (`replace` is forwarded unchanged);
//  3. adds `off` to every one of the `num_samples` results in `out_idx`,
//     presumably turning slice-relative picks into absolute positions.
//
// `rowid` and `col` are part of the PickFn signature but are not read here.
template <typename IdxType, typename FloatType>
inline PickFn<IdxType> GetSamplingPickFn(
    int64_t num_samples, FloatArray prob, bool replace) {
  PickFn<IdxType> pick_fn = [prob, num_samples, replace]
    (IdxType rowid, IdxType off, IdxType len,
     const IdxType* col, const IdxType* data,
     IdxType* out_idx) {
      FloatArray prob_selected = DoubleSlice<IdxType, FloatType>(prob, data, off, len);
      RandomEngine::ThreadLocal()->Choice<IdxType, FloatType>(
          num_samples, prob_selected, out_idx, replace);
      // Shift the chosen indices by the start of the row's slice.
      for (int64_t j = 0; j < num_samples; ++j) {
        out_idx[j] += off;
      }
    };
  return pick_fn;
}

47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
// Builds the range pick function used for weighted per-edge-type sampling.
//
// For the `et_len` entries starting at `et_offset` in `et_idx`, the callable
// collects one weight per entry from `prob` -- looked up at
// data[off + et_idx[...]] when `data` is non-null and at off + et_idx[...]
// otherwise -- and then lets the thread-local RandomEngine choose
// `num_samples` of them into `out_idx`. Unlike the row-wise pick function,
// the results are not shifted afterwards.
template <typename IdxType, typename FloatType>
inline RangePickFn<IdxType> GetSamplingRangePickFn(
    int64_t num_samples, FloatArray prob, bool replace) {
  RangePickFn<IdxType> range_picker = [prob, num_samples, replace]
    (IdxType off, IdxType et_offset, IdxType et_len,
     const std::vector<IdxType> &et_idx,
     const IdxType* data, IdxType* out_idx) {
      const FloatType* weights = static_cast<FloatType*>(prob->data);
      FloatArray gathered = FloatArray::Empty({et_len}, prob->dtype, prob->ctx);
      FloatType* gathered_data = static_cast<FloatType*>(gathered->data);
      for (int64_t k = 0; k < et_len; ++k) {
        // Position of the k-th entry of this edge type within the row.
        const auto pos = off + et_idx[et_offset + k];
        gathered_data[k] = data ? weights[data[pos]] : weights[pos];
      }

      RandomEngine::ThreadLocal()->Choice<IdxType, FloatType>(
          num_samples, gathered, out_idx, replace);
    };
  return range_picker;
}

70
71
72
73
74
75
76
// Builds the per-row pick function used for uniform sampling.
//
// The returned callable asks the thread-local RandomEngine for
// `num_samples` uniform choices over a population of size `len`
// (`replace` is forwarded unchanged) and then adds `off` to each of the
// `num_samples` results in `out_idx`, presumably converting slice-relative
// picks into absolute positions.
//
// `rowid`, `col` and `data` are part of the PickFn signature but are not
// read here.
template <typename IdxType>
inline PickFn<IdxType> GetSamplingUniformPickFn(
    int64_t num_samples, bool replace) {
  PickFn<IdxType> pick_fn = [num_samples, replace]
    (IdxType rowid, IdxType off, IdxType len,
     const IdxType* col, const IdxType* data,
     IdxType* out_idx) {
      RandomEngine::ThreadLocal()->UniformChoice<IdxType>(
          num_samples, len, out_idx, replace);
      // Shift the chosen indices by the start of the row's slice.
      for (int64_t j = 0; j < num_samples; ++j) {
        out_idx[j] += off;
      }
    };
  return pick_fn;
}
85

86
87
88
89
90
91
92
93
94
95
96
97
98
// Builds the range pick function used for uniform per-edge-type sampling.
//
// The callable only needs the population size `et_len`: it forwards
// `num_samples`, `et_len`, `out_idx` and `replace` to the thread-local
// RandomEngine's UniformChoice. `off`, `et_offset`, `et_idx` and `data`
// belong to the RangePickFn signature but are not read here.
template <typename IdxType>
inline RangePickFn<IdxType> GetSamplingUniformRangePickFn(
    int64_t num_samples, bool replace) {
  return [num_samples, replace]
    (IdxType off, IdxType et_offset, IdxType et_len,
     const std::vector<IdxType> &et_idx,
     const IdxType* data, IdxType* out_idx) {
      auto* rng = RandomEngine::ThreadLocal();
      rng->template UniformChoice<IdxType>(num_samples, et_len, out_idx, replace);
    };
}

99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
// Builds the per-row pick function used for tag-biased sampling.
//
// `split` is read as a row-major 2-D table of IdxType: the callable takes
// the `rowid`-th row (stride split->shape[1]) as the tag offsets of that
// row and hands it, together with `bias`, to the thread-local
// RandomEngine's BiasedChoice. Each of the `num_samples` results in
// `out_idx` is then increased by `off`.
//
// `len`, `col` and `data` belong to the PickFn signature but are not read.
template <typename IdxType, typename FloatType>
inline PickFn<IdxType> GetSamplingBiasedPickFn(
    int64_t num_samples, IdArray split, FloatArray bias, bool replace) {
  PickFn<IdxType> biased_picker = [num_samples, split, bias, replace]
    (IdxType rowid, IdxType off, IdxType len,
     const IdxType* col, const IdxType* data,
     IdxType* out_idx) {
    const IdxType* split_base = static_cast<IdxType*>(split->data);
    const IdxType* row_tag_offset = split_base + rowid * split->shape[1];
    RandomEngine::ThreadLocal()->BiasedChoice<IdxType, FloatType>(
        num_samples, row_tag_offset, bias, out_idx, replace);
    // Shift the chosen indices by the start of the row's slice.
    for (int64_t k = 0; k < num_samples; ++k)
      out_idx[k] += off;
  };
  return biased_picker;
}

116
117
118
119
120
121
122
}  // namespace

/////////////////////////////// CSR ///////////////////////////////

// Weighted row-wise sampling on a CSR matrix.
//
// Requires `prob` to be defined (enforced with CHECK), builds the weighted
// pick function from `num_samples`, `prob` and `replace`, and delegates the
// actual traversal to CSRRowWisePick, whose COOMatrix result is returned.
// The XPU template parameter is not used in the body; it only selects the
// instantiation.
template <DLDeviceType XPU, typename IdxType, typename FloatType>
COOMatrix CSRRowWiseSampling(CSRMatrix mat, IdArray rows, int64_t num_samples,
                             FloatArray prob, bool replace) {
  CHECK(prob.defined());
  auto pick_fn = GetSamplingPickFn<IdxType, FloatType>(num_samples, prob, replace);
  return CSRRowWisePick(mat, rows, num_samples, replace, pick_fn);
}

// Explicit instantiations of CSRRowWiseSampling for the CPU device:
// every combination of {int32_t, int64_t} index type and {float, double}
// weight type.
template COOMatrix CSRRowWiseSampling<kDLCPU, int32_t, float>(
    CSRMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix CSRRowWiseSampling<kDLCPU, int64_t, float>(
    CSRMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix CSRRowWiseSampling<kDLCPU, int32_t, double>(
    CSRMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix CSRRowWiseSampling<kDLCPU, int64_t, double>(
    CSRMatrix, IdArray, int64_t, FloatArray, bool);

137
138
// Weighted row-wise sampling on a CSR matrix, performed per edge type.
//
// Requires `prob` to be defined (enforced with CHECK), builds the weighted
// range pick function from `num_samples`, `prob` and `replace`, and
// delegates to CSRRowWisePerEtypePick, forwarding `etypes` and
// `etype_sorted` unchanged. The XPU template parameter is not used in the
// body; it only selects the instantiation.
template <DLDeviceType XPU, typename IdxType, typename FloatType>
COOMatrix CSRRowWisePerEtypeSampling(CSRMatrix mat, IdArray rows, IdArray etypes,
                                     int64_t num_samples, FloatArray prob,
                                     bool replace, bool etype_sorted) {
  CHECK(prob.defined());
  auto pick_fn = GetSamplingRangePickFn<IdxType, FloatType>(num_samples, prob, replace);
  return CSRRowWisePerEtypePick(mat, rows, etypes, num_samples, replace, etype_sorted, pick_fn);
}

// Explicit instantiations of CSRRowWisePerEtypeSampling for the CPU device:
// every combination of {int32_t, int64_t} index type and {float, double}
// weight type.
template COOMatrix CSRRowWisePerEtypeSampling<kDLCPU, int32_t, float>(
    CSRMatrix, IdArray, IdArray, int64_t, FloatArray, bool, bool);
template COOMatrix CSRRowWisePerEtypeSampling<kDLCPU, int64_t, float>(
    CSRMatrix, IdArray, IdArray, int64_t, FloatArray, bool, bool);
template COOMatrix CSRRowWisePerEtypeSampling<kDLCPU, int32_t, double>(
    CSRMatrix, IdArray, IdArray, int64_t, FloatArray, bool, bool);
template COOMatrix CSRRowWisePerEtypeSampling<kDLCPU, int64_t, double>(
    CSRMatrix, IdArray, IdArray, int64_t, FloatArray, bool, bool);
154

155
156
157
158
159
160
161
162
163
164
165
166
// Uniform row-wise sampling on a CSR matrix.
//
// Creates the uniform pick function for `num_samples` / `replace` and hands
// it to CSRRowWisePick, returning that call's COOMatrix result. The XPU
// template parameter is not used in the body.
template <DLDeviceType XPU, typename IdxType>
COOMatrix CSRRowWiseSamplingUniform(CSRMatrix mat, IdArray rows,
                                    int64_t num_samples, bool replace) {
  PickFn<IdxType> uniform_picker =
      GetSamplingUniformPickFn<IdxType>(num_samples, replace);
  return CSRRowWisePick(mat, rows, num_samples, replace, uniform_picker);
}

// Explicit instantiations of CSRRowWiseSamplingUniform for the CPU device
// with 32-bit and 64-bit index types.
template COOMatrix CSRRowWiseSamplingUniform<kDLCPU, int32_t>(
    CSRMatrix, IdArray, int64_t, bool);
template COOMatrix CSRRowWiseSamplingUniform<kDLCPU, int64_t>(
    CSRMatrix, IdArray, int64_t, bool);

167
168
// Uniform row-wise sampling on a CSR matrix, performed per edge type.
//
// Creates the uniform range pick function for `num_samples` / `replace`
// and delegates to CSRRowWisePerEtypePick, forwarding `etypes` and
// `etype_sorted` unchanged. The XPU template parameter is not used in the
// body; it only selects the instantiation.
template <DLDeviceType XPU, typename IdxType>
COOMatrix CSRRowWisePerEtypeSamplingUniform(CSRMatrix mat, IdArray rows, IdArray etypes,
                                            int64_t num_samples,
                                            bool replace, bool etype_sorted) {
  auto pick_fn = GetSamplingUniformRangePickFn<IdxType>(num_samples, replace);
  return CSRRowWisePerEtypePick(mat, rows, etypes, num_samples, replace, etype_sorted, pick_fn);
}

// Explicit instantiations of CSRRowWisePerEtypeSamplingUniform for the CPU
// device with 32-bit and 64-bit index types.
template COOMatrix CSRRowWisePerEtypeSamplingUniform<kDLCPU, int32_t>(
    CSRMatrix, IdArray, IdArray, int64_t, bool, bool);
template COOMatrix CSRRowWisePerEtypeSamplingUniform<kDLCPU, int64_t>(
    CSRMatrix, IdArray, IdArray, int64_t, bool, bool);
179

180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
// Tag-biased row-wise sampling on a CSR matrix.
//
// Creates the biased pick function from `num_samples`, `tag_offset`,
// `bias` and `replace`, and hands it to CSRRowWisePick, returning that
// call's COOMatrix result. The XPU template parameter is not used in the
// body.
template <DLDeviceType XPU, typename IdxType, typename FloatType>
COOMatrix CSRRowWiseSamplingBiased(CSRMatrix mat, IdArray rows,
                                   int64_t num_samples, NDArray tag_offset,
                                   FloatArray bias, bool replace) {
  PickFn<IdxType> biased_picker = GetSamplingBiasedPickFn<IdxType, FloatType>(
      num_samples, tag_offset, bias, replace);
  return CSRRowWisePick(mat, rows, num_samples, replace, biased_picker);
}

// Explicit instantiations of CSRRowWiseSamplingBiased for the CPU device:
// every combination of {int32_t, int64_t} index type and {float, double}
// bias type.
template COOMatrix CSRRowWiseSamplingBiased<kDLCPU, int32_t, float>(
  CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);

template COOMatrix CSRRowWiseSamplingBiased<kDLCPU, int64_t, float>(
  CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);

template COOMatrix CSRRowWiseSamplingBiased<kDLCPU, int32_t, double>(
  CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);

template COOMatrix CSRRowWiseSamplingBiased<kDLCPU, int64_t, double>(
  CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);


207
208
209
210
211
/////////////////////////////// COO ///////////////////////////////

// Weighted row-wise sampling on a COO matrix.
//
// Requires `prob` to be defined (enforced with CHECK), builds the weighted
// pick function from `num_samples`, `prob` and `replace`, and delegates the
// actual traversal to COORowWisePick, whose COOMatrix result is returned.
// The XPU template parameter is not used in the body; it only selects the
// instantiation.
template <DLDeviceType XPU, typename IdxType, typename FloatType>
COOMatrix COORowWiseSampling(COOMatrix mat, IdArray rows, int64_t num_samples,
                             FloatArray prob, bool replace) {
  CHECK(prob.defined());
  auto pick_fn = GetSamplingPickFn<IdxType, FloatType>(num_samples, prob, replace);
  return COORowWisePick(mat, rows, num_samples, replace, pick_fn);
}

// Explicit instantiations of COORowWiseSampling for the CPU device:
// every combination of {int32_t, int64_t} index type and {float, double}
// weight type.
template COOMatrix COORowWiseSampling<kDLCPU, int32_t, float>(
    COOMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix COORowWiseSampling<kDLCPU, int64_t, float>(
    COOMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix COORowWiseSampling<kDLCPU, int32_t, double>(
    COOMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix COORowWiseSampling<kDLCPU, int64_t, double>(
    COOMatrix, IdArray, int64_t, FloatArray, bool);

226
227
// Weighted row-wise sampling on a COO matrix, performed per edge type.
//
// Requires `prob` to be defined (enforced with CHECK), builds the weighted
// range pick function from `num_samples`, `prob` and `replace`, and
// delegates to COORowWisePerEtypePick, forwarding `etypes` and
// `etype_sorted` unchanged. The XPU template parameter is not used in the
// body; it only selects the instantiation.
template <DLDeviceType XPU, typename IdxType, typename FloatType>
COOMatrix COORowWisePerEtypeSampling(COOMatrix mat, IdArray rows, IdArray etypes,
                                     int64_t num_samples, FloatArray prob,
                                     bool replace, bool etype_sorted) {
  CHECK(prob.defined());
  auto pick_fn = GetSamplingRangePickFn<IdxType, FloatType>(num_samples, prob, replace);
  return COORowWisePerEtypePick(mat, rows, etypes, num_samples, replace, etype_sorted, pick_fn);
}

// Explicit instantiations of COORowWisePerEtypeSampling for the CPU device:
// every combination of {int32_t, int64_t} index type and {float, double}
// weight type.
template COOMatrix COORowWisePerEtypeSampling<kDLCPU, int32_t, float>(
    COOMatrix, IdArray, IdArray, int64_t, FloatArray, bool, bool);
template COOMatrix COORowWisePerEtypeSampling<kDLCPU, int64_t, float>(
    COOMatrix, IdArray, IdArray, int64_t, FloatArray, bool, bool);
template COOMatrix COORowWisePerEtypeSampling<kDLCPU, int32_t, double>(
    COOMatrix, IdArray, IdArray, int64_t, FloatArray, bool, bool);
template COOMatrix COORowWisePerEtypeSampling<kDLCPU, int64_t, double>(
    COOMatrix, IdArray, IdArray, int64_t, FloatArray, bool, bool);
243

244
245
246
247
248
249
250
251
252
253
254
255
// Uniform row-wise sampling on a COO matrix.
//
// Creates the uniform pick function for `num_samples` / `replace` and hands
// it to COORowWisePick, returning that call's COOMatrix result. The XPU
// template parameter is not used in the body.
template <DLDeviceType XPU, typename IdxType>
COOMatrix COORowWiseSamplingUniform(COOMatrix mat, IdArray rows,
                                    int64_t num_samples, bool replace) {
  PickFn<IdxType> uniform_picker =
      GetSamplingUniformPickFn<IdxType>(num_samples, replace);
  return COORowWisePick(mat, rows, num_samples, replace, uniform_picker);
}

// Explicit instantiations of COORowWiseSamplingUniform for the CPU device
// with 32-bit and 64-bit index types.
template COOMatrix COORowWiseSamplingUniform<kDLCPU, int32_t>(
    COOMatrix, IdArray, int64_t, bool);
template COOMatrix COORowWiseSamplingUniform<kDLCPU, int64_t>(
    COOMatrix, IdArray, int64_t, bool);

256
257
// Uniform row-wise sampling on a COO matrix, performed per edge type.
//
// Creates the uniform range pick function for `num_samples` / `replace`
// and delegates to COORowWisePerEtypePick, forwarding `etypes` and
// `etype_sorted` unchanged. The XPU template parameter is not used in the
// body; it only selects the instantiation.
template <DLDeviceType XPU, typename IdxType>
COOMatrix COORowWisePerEtypeSamplingUniform(COOMatrix mat, IdArray rows, IdArray etypes,
                                            int64_t num_samples,
                                            bool replace, bool etype_sorted) {
  auto pick_fn = GetSamplingUniformRangePickFn<IdxType>(num_samples, replace);
  return COORowWisePerEtypePick(mat, rows, etypes, num_samples, replace, etype_sorted, pick_fn);
}

// Explicit instantiations of COORowWisePerEtypeSamplingUniform for the CPU
// device with 32-bit and 64-bit index types.
template COOMatrix COORowWisePerEtypeSamplingUniform<kDLCPU, int32_t>(
    COOMatrix, IdArray, IdArray, int64_t, bool, bool);
template COOMatrix COORowWisePerEtypeSamplingUniform<kDLCPU, int64_t>(
    COOMatrix, IdArray, IdArray, int64_t, bool, bool);
267

268
269
270
}  // namespace impl
}  // namespace aten
}  // namespace dgl