/**
 *  Copyright (c) 2021 by Contributors
 * @file array/cuda/negative_sampling.cu
 * @brief Global uniform negative sampling on CSR.
 */

#include <curand_kernel.h>
#include <dgl/array.h>
#include <dgl/array_iterator.h>
#include <dgl/random.h>

#include <algorithm>  // for std::min
#include <utility>    // for std::pair

#include <cub/cub.cuh>

#include "../../runtime/cuda/cuda_common.h"
#include "./utils.h"

using namespace dgl::runtime;

namespace dgl {
namespace aten {
namespace impl {

namespace {

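/**
 * @brief Kernel that fills \a row and \a col with uniformly sampled negative
 * edges, i.e. (u, v) pairs that are not edges of the given CSR matrix.
 *
 * Each thread walks over output slots in a grid-stride loop and makes up to
 * \a num_trials rejection-sampling attempts per slot.  A proposed pair is
 * rejected if it is a self-loop (when \a exclude_self_loops is set) or if v
 * appears among u's neighbors, which is checked by binary search and hence
 * assumes the indices within each row are sorted.  Slots where every trial
 * fails are left untouched; the caller initializes them to -1 and filters
 * them out afterwards.
 */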
template <typename IdType>
__global__ void _GlobalUniformNegativeSamplingKernel(
    const IdType* __restrict__ indptr, const IdType* __restrict__ indices,
    IdType* __restrict__ row, IdType* __restrict__ col, int64_t num_row,
    int64_t num_col, int64_t num_samples, int num_trials,
    bool exclude_self_loops, int32_t random_seed) {
  int64_t tx = blockIdx.x * blockDim.x + threadIdx.x;
  const int stride_x = gridDim.x * blockDim.x;

  curandStatePhilox4_32_10_t
      rng;  // this allows generating 4 32-bit ints at a time
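  // Seed the generator per block and use the thread index as the subsequence
  // so that every thread draws an independent random stream.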
  curand_init(random_seed * gridDim.x + blockIdx.x, threadIdx.x, 0, &rng);

  while (tx < num_samples) {
    for (int i = 0; i < num_trials; ++i) {
      uint4 result = curand4(&rng);
      // Turns out that result.x is always 0 with the above RNG.
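      // Concatenate the 16-bit halves of result.y with result.z and result.w
      // to form two 48-bit random integers, then map them into [0, num_row)
      // and [0, num_col) by modulo.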
      uint64_t y_hi = result.y >> 16;
      uint64_t y_lo = result.y & 0xFFFF;
      uint64_t z = static_cast<uint64_t>(result.z);
      uint64_t w = static_cast<uint64_t>(result.w);
      int64_t u = static_cast<int64_t>(((y_lo << 32L) | z) % num_row);
      int64_t v = static_cast<int64_t>(((y_hi << 32L) | w) % num_col);

      if (exclude_self_loops && (u == v)) continue;

      // Binary search for v among indices[indptr[u]:indptr[u+1]], i.e. the
      // sorted neighbor list of u.
      int64_t b = indptr[u], e = indptr[u + 1] - 1;
      bool found = false;
      while (b <= e) {
        int64_t m = (b + e) / 2;
        if (indices[m] == v) {
          found = true;
          break;
        } else if (indices[m] < v) {
          b = m + 1;
        } else {
          e = m - 1;
        }
      }

      if (!found) {
        row[tx] = u;
        col[tx] = v;
        break;
      }
    }

    tx += stride_x;
  }
}

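// Selection functor for cub::DeviceSelect::If: keeps (row, col) pairs whose
// first entry is not the sentinel -1 left by failed sampling attempts.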
template <typename DType>
struct IsNotMinusOne {
  __device__ __forceinline__ bool operator()(const std::pair<DType, DType>& a) {
    return a.first != -1;
  }
};

/**
 * @brief Sort ordered pairs in ascending order, using \a tmp_major and \a
 * tmp_minor as temporary buffers, each with \a n elements.
 */
template <typename IdType>
void SortOrderedPairs(
    runtime::DeviceAPI* device, DGLContext ctx, IdType* major, IdType* minor,
    IdType* tmp_major, IdType* tmp_minor, int64_t n, cudaStream_t stream) {
  // Sort ordered pairs in lexicographical order by two radix sorts since
  // cub's radix sorts are stable.
  // We need 2*n elements of auxiliary storage to hold the results from the
  // first radix sort.
  size_t s1 = 0, s2 = 0;
  void* tmp1 = nullptr;
  void* tmp2 = nullptr;

  // Radix sort by the minor key first, reordering the major key in the process.
  CUDA_CALL(cub::DeviceRadixSort::SortPairs(
      tmp1, s1, minor, tmp_minor, major, tmp_major, n, 0, sizeof(IdType) * 8,
      stream));
  tmp1 = device->AllocWorkspace(ctx, s1);
  CUDA_CALL(cub::DeviceRadixSort::SortPairs(
      tmp1, s1, minor, tmp_minor, major, tmp_major, n, 0, sizeof(IdType) * 8,
      stream));

  // Radix sort by major key next.
  CUDA_CALL(cub::DeviceRadixSort::SortPairs(
      tmp2, s2, tmp_major, major, tmp_minor, minor, n, 0, sizeof(IdType) * 8,
      stream));
  tmp2 = (s2 > s1) ? device->AllocWorkspace(ctx, s2)
                   : tmp1;  // reuse buffer if s2 <= s1
  CUDA_CALL(cub::DeviceRadixSort::SortPairs(
      tmp2, s2, tmp_major, major, tmp_minor, minor, n, 0, sizeof(IdType) * 8,
      stream));

  if (tmp1 != tmp2) device->FreeWorkspace(ctx, tmp2);
  device->FreeWorkspace(ctx, tmp1);
}

}  // namespace

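/**
 * @brief Uniformly sample edges that do not exist in \a csr.
 *
 * The implementation oversamples num_samples * (1 + redundancy) candidate
 * pairs, launches the rejection-sampling kernel above, compacts the slots
 * that succeeded (i.e. are not -1) with cub::DeviceSelect::If, and, when
 * \a replace is false, removes duplicate pairs by sorting them
 * lexicographically and running cub::DeviceSelect::Unique.  The result is
 * truncated to at most \a num_samples pairs.
 */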
template <DGLDeviceType XPU, typename IdType>
std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
    const CSRMatrix& csr, int64_t num_samples, int num_trials,
    bool exclude_self_loops, bool replace, double redundancy) {
  auto ctx = csr.indptr->ctx;
  auto dtype = csr.indptr->dtype;
  const int64_t num_row = csr.num_rows;
  const int64_t num_col = csr.num_cols;
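  // Oversample by a factor of (1 + redundancy) so that enough candidates
  // survive rejection (and deduplication when replace is false).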
  const int64_t num_actual_samples =
      static_cast<int64_t>(num_samples * (1 + redundancy));
  IdArray row = Full<IdType>(-1, num_actual_samples, ctx);
  IdArray col = Full<IdType>(-1, num_actual_samples, ctx);
  IdArray out_row = IdArray::Empty({num_actual_samples}, dtype, ctx);
  IdArray out_col = IdArray::Empty({num_actual_samples}, dtype, ctx);
  IdType* row_data = row.Ptr<IdType>();
  IdType* col_data = col.Ptr<IdType>();
  IdType* out_row_data = out_row.Ptr<IdType>();
  IdType* out_col_data = out_col.Ptr<IdType>();
  auto device = runtime::DeviceAPI::Get(ctx);
  cudaStream_t stream = runtime::getCurrentCUDAStream();
  const int nt = cuda::FindNumThreads(num_actual_samples);
  const int nb = (num_actual_samples + nt - 1) / nt;
  std::pair<IdArray, IdArray> result;
  int64_t num_out;

  CUDA_KERNEL_CALL(
      _GlobalUniformNegativeSamplingKernel, nb, nt, 0, stream,
      csr.indptr.Ptr<IdType>(), csr.indices.Ptr<IdType>(), row_data, col_data,
      num_row, num_col, num_actual_samples, num_trials, exclude_self_loops,
      RandomEngine::ThreadLocal()->RandInt32());

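  // Compact the successfully sampled pairs.  Following the usual two-pass cub
  // pattern, the first DeviceSelect::If call with a null workspace pointer
  // only queries the required temporary storage size; the second call does
  // the actual selection.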
  size_t tmp_size = 0;
  int64_t* num_out_cuda =
      static_cast<int64_t*>(device->AllocWorkspace(ctx, sizeof(int64_t)));
  IsNotMinusOne<IdType> op;
  PairIterator<IdType> begin(row_data, col_data);
  PairIterator<IdType> out_begin(out_row_data, out_col_data);
  CUDA_CALL(cub::DeviceSelect::If(
      nullptr, tmp_size, begin, out_begin, num_out_cuda, num_actual_samples, op,
      stream));
  void* tmp = device->AllocWorkspace(ctx, tmp_size);
  CUDA_CALL(cub::DeviceSelect::If(
      tmp, tmp_size, begin, out_begin, num_out_cuda, num_actual_samples, op,
      stream));
  num_out = cuda::GetCUDAScalar(device, ctx, num_out_cuda);

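  // Without replacement, duplicates must be removed: sort the compacted pairs
  // lexicographically, then keep only distinct consecutive pairs with
  // cub::DeviceSelect::Unique.  With replacement, the compacted result is
  // simply truncated to num_samples.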
  if (!replace) {
    IdArray unique_row = IdArray::Empty({num_out}, dtype, ctx);
    IdArray unique_col = IdArray::Empty({num_out}, dtype, ctx);
    IdType* unique_row_data = unique_row.Ptr<IdType>();
    IdType* unique_col_data = unique_col.Ptr<IdType>();
    PairIterator<IdType> unique_begin(unique_row_data, unique_col_data);

    SortOrderedPairs(
        device, ctx, out_row_data, out_col_data, unique_row_data,
        unique_col_data, num_out, stream);

    size_t tmp_size_unique = 0;
    void* tmp_unique = nullptr;
    CUDA_CALL(cub::DeviceSelect::Unique(
        nullptr, tmp_size_unique, out_begin, unique_begin, num_out_cuda,
        num_out, stream));
    tmp_unique = (tmp_size_unique > tmp_size)
                     ? device->AllocWorkspace(ctx, tmp_size_unique)
                     : tmp;  // reuse buffer
    CUDA_CALL(cub::DeviceSelect::Unique(
        tmp_unique, tmp_size_unique, out_begin, unique_begin, num_out_cuda,
        num_out, stream));
    num_out = cuda::GetCUDAScalar(device, ctx, num_out_cuda);

    num_out = std::min(num_samples, num_out);
    result = {
        unique_row.CreateView({num_out}, dtype),
        unique_col.CreateView({num_out}, dtype)};

    if (tmp_unique != tmp) device->FreeWorkspace(ctx, tmp_unique);
  } else {
    num_out = std::min(num_samples, num_out);
    result = {
        out_row.CreateView({num_out}, dtype),
        out_col.CreateView({num_out}, dtype)};
  }

  device->FreeWorkspace(ctx, tmp);
  device->FreeWorkspace(ctx, num_out_cuda);
  return result;
}

template std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling<
    kDGLCUDA, int32_t>(const CSRMatrix&, int64_t, int, bool, bool, double);
template std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling<
    kDGLCUDA, int64_t>(const CSRMatrix&, int64_t, int, bool, bool, double);

}  // namespace impl
}  // namespace aten
}  // namespace dgl