"cuda/utils/timer.hh" did not exist on "4c90e6e8b81cd4b914c5b53590fbae0a5a4df835"
edge_coarsening_impl.cu 8.75 KB
Newer Older
1
2
3
4
5
6
7
8
/*!
 *  Copyright (c) 2019 by Contributors
 * \file geometry/cuda/edge_coarsening_impl.cu
 * \brief Edge coarsening CUDA implementation
 */
#include <dgl/array.h>
#include <dgl/random.h>
#include <dmlc/thread_local.h>
#include <curand_kernel.h>
#include <cstdint>
#include "../geometry_op.h"
#include "../../runtime/cuda/cuda_common.h"
#include "../../array/cuda/utils.h"

// Ceiling division: number of thread blocks of size T needed to cover N
// elements. Arguments and the full expansion are parenthesized so the macro
// stays correct when callers pass compound expressions (e.g. BLOCKS(n, t + 1)).
#define BLOCKS(N, T) (((N) + (T) - 1) / (T))

namespace dgl {
namespace geometry {
namespace impl {

// Probability that an unmatched node is colored BLUE in each colorize round.
// NOTE(review): 0.53406 presumably comes from the greedy-matching paper cited
// below — confirm against the reference before changing it.
constexpr float BLUE_P = 0.53406;
// Sentinel colors stored in `result` while a node is still unmatched; any
// non-negative entry means the node is matched (it holds the match id).
constexpr int BLUE = -1;
constexpr int RED = -2;
// Sentinel meaning "no proposal was made" in the proposal array.
constexpr int EMPTY_IDX = -1;

// Device-global convergence flag: set to true before each colorize round and
// cleared by colorize_kernel whenever it still finds an unmatched node.
__device__ bool done_d;
__global__ void init_done_kernel() { done_d = true; }

/*!
 * \brief Fill ret_values[0..num) with uniform random floats.
 * One thread per element; launch with enough blocks to cover num.
 * \param ret_values output buffer (device memory)
 * \param num number of values to generate
 * \param seed RNG seed; each thread uses its global index as the curand
 *        subsequence so values are independent across threads
 */
__global__ void generate_uniform_kernel(float *ret_values, size_t num, uint64_t seed) {
  // Promote to size_t before multiplying: blockIdx.x * blockDim.x is
  // otherwise evaluated in 32-bit unsigned arithmetic and can silently
  // overflow for very large launches.
  size_t id = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (id < num) {
    curandState state;
    curand_init(seed, id, 0, &state);
    // curand_uniform draws from (0, 1].
    ret_values[id] = curand_uniform(&state);
  }
}

/*!
 * \brief Randomly color every still-unmatched node BLUE or RED and clear the
 * global convergence flag if any such node exists.
 * \param prop per-node uniform random draws in (0, 1]
 * \param num_elem number of nodes
 * \param result per-node match state; negative entries are unmatched
 */
template <typename IdType>
__global__ void colorize_kernel(const float *prop, int64_t num_elem, IdType *result) {
  const IdType idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= num_elem) return;
  // Non-negative entries are already matched; leave them untouched.
  if (result[idx] >= 0) return;
  // Color by comparing this node's random draw against the BLUE probability.
  result[idx] = (prop[idx] > BLUE_P) ? RED : BLUE;
  // At least one node is still unmatched, so another round is needed.
  done_d = false;
}

/*!
 * \brief Propose phase: each BLUE node proposes to its heaviest RED neighbor.
 * A BLUE node with no unmatched neighbor matches itself (result[idx] = idx).
 * \param indptr, indices CSR structure of the graph (device memory)
 * \param weights per-edge weights, aligned with indices
 * \param num_elem number of nodes
 * \param proposal output: proposed partner per node, EMPTY_IDX if none
 * \param result per-node match state (negative = unmatched)
 */
template <typename FloatType, typename IdType>
__global__ void weighted_propose_kernel(const IdType *indptr, const IdType *indices,
                                        const FloatType *weights, int64_t num_elem,
                                        IdType *proposal, IdType *result) {
  const IdType idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= num_elem) return;
  if (result[idx] != BLUE) return;

  IdType best_neighbor = EMPTY_IDX;
  FloatType best_weight = 0.;
  bool found_unmatched = false;

  const IdType row_start = indptr[idx];
  const IdType row_end = indptr[idx + 1];
  for (IdType eid = row_start; eid < row_end; ++eid) {
    const IdType nbr = indices[eid];
    if (result[nbr] < 0)
      found_unmatched = true;
    // Track the heaviest RED neighbor; >= makes ties resolve to later edges.
    if (result[nbr] == RED && weights[eid] >= best_weight) {
      best_neighbor = nbr;
      best_weight = weights[eid];
    }
  }

  proposal[idx] = best_neighbor;
  // Every neighbor is already matched: match this node with itself.
  if (!found_unmatched)
    result[idx] = idx;
}

/*!
 * \brief Respond phase: each RED node accepts the heaviest BLUE neighbor that
 * proposed to it; the pair is matched under the smaller of the two ids. A RED
 * node with no unmatched neighbor matches itself (result[idx] = idx).
 * \param indptr, indices CSR structure of the graph (device memory)
 * \param weights per-edge weights, aligned with indices
 * \param num_elem number of nodes
 * \param proposal per-node proposed partner from the propose phase
 * \param result per-node match state (negative = unmatched)
 */
template <typename FloatType, typename IdType>
__global__ void weighted_respond_kernel(const IdType *indptr, const IdType *indices,
                                        const FloatType *weights, int64_t num_elem,
                                        IdType *proposal, IdType *result) {
  const IdType idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= num_elem) return;
  if (result[idx] != RED) return;

  bool found_unmatched = false;
  IdType best_proposer = -1;
  FloatType best_weight = 0.;

  const IdType row_start = indptr[idx];
  const IdType row_end = indptr[idx + 1];
  for (IdType eid = row_start; eid < row_end; ++eid) {
    const IdType nbr = indices[eid];
    if (result[nbr] < 0)
      found_unmatched = true;
    // Accept only BLUE neighbors that proposed to this node; keep the
    // heaviest edge (>= resolves ties toward later edges).
    if (result[nbr] == BLUE && proposal[nbr] == idx && weights[eid] >= best_weight) {
      best_proposer = nbr;
      best_weight = weights[eid];
    }
  }

  if (best_proposer >= 0) {
    const IdType match_id = min(idx, best_proposer);
    result[best_proposer] = match_id;
    result[idx] = match_id;
  }

  // Every neighbor is already matched: match this node with itself.
  if (!found_unmatched)
    result[idx] = idx;
}

/*! \brief The colorize procedure. This procedure randomly marks unmatched
 * nodes with BLUE(-1) and RED(-2) and checks whether the node matching
 * process has finished.
 * \param result_data per-node match state on device (negative = unmatched)
 * \param num_nodes number of nodes
 * \return true when no unmatched node remains (matching converged)
 */
template<typename IdType>
bool Colorize(IdType * result_data, int64_t num_nodes) {
  // initial done signal: assume converged; colorize_kernel clears the flag
  // if it still finds an unmatched node.
  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  CUDA_KERNEL_CALL(init_done_kernel, 1, 1, 0, thr_entry->stream);

  // generate one uniform random color draw per node
  float *prop = nullptr;
  uint64_t seed = dgl::RandomEngine::ThreadLocal()->RandInt(UINT64_MAX);
  auto num_threads = cuda::FindNumThreads(num_nodes);
  auto num_blocks = cuda::FindNumBlocks<'x'>(BLOCKS(num_nodes, num_threads));
  CUDA_CALL(cudaMalloc(reinterpret_cast<void **>(&prop), num_nodes * sizeof(float)));
  CUDA_KERNEL_CALL(generate_uniform_kernel, num_blocks, num_threads, 0, thr_entry->stream,
                   prop, num_nodes, seed);

  // color all still-unmatched nodes
  CUDA_KERNEL_CALL(colorize_kernel, num_blocks, num_threads, 0, thr_entry->stream,
                   prop, num_nodes, result_data);

  // Read done_d back on the same stream so the copy is ordered after the
  // kernels above, then block until the value reaches the host. The previous
  // synchronous cudaMemcpyFromSymbol ran on the legacy null stream and only
  // worked because of implicit null-stream synchronization, which breaks if
  // thr_entry->stream is created with the non-blocking flag.
  bool done_h = false;
  CUDA_CALL(cudaMemcpyFromSymbolAsync(&done_h, done_d, sizeof(done_h), 0,
                                      cudaMemcpyDeviceToHost, thr_entry->stream));
  CUDA_CALL(cudaStreamSynchronize(thr_entry->stream));
  CUDA_CALL(cudaFree(prop));
  return done_h;
}

/*! \brief Weighted neighbor matching procedure (GPU version).
 * This implementation is from `A GPU Algorithm for Greedy Graph Matching
 * <http://www.staff.science.uu.nl/~bisse101/Articles/match12.pdf>`__
 *
 * The algorithm repeats three phases until convergence:
 *  1. Colorize: every unmatched node is randomly tagged BLUE or RED; when
 *     no unmatched node remains, the procedure finishes.
 *  2. Propose: each BLUE node proposes to its RED neighbor of largest
 *     weight (weights are random when the graph is unweighted). A BLUE node
 *     whose neighbors are all matched marks itself with its own id.
 *  3. Respond: each RED node accepts the proposing BLUE neighbor of largest
 *     weight; the (BLUE, RED) pair is matched under the smaller of the two
 *     ids. A RED node whose neighbors are all matched marks itself with its
 *     own id.
 */
template <DLDeviceType XPU, typename FloatType, typename IdType>
void WeightedNeighborMatching(const aten::CSRMatrix &csr, const NDArray weight, IdArray result) {
  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  const auto& ctx = result->ctx;
  auto device = runtime::DeviceAPI::Get(ctx);
  device->SetDevice(ctx);

  // One proposal slot per node, initialized to "no proposal" (-1).
  const int64_t num_nodes = result->shape[0];
  IdArray proposal = aten::Full(-1, num_nodes, sizeof(IdType) * 8, ctx);

  // Raw device pointers consumed by the kernels below.
  IdType *indptr_data = static_cast<IdType*>(csr.indptr->data);
  IdType *indices_data = static_cast<IdType*>(csr.indices->data);
  IdType *result_data = static_cast<IdType*>(result->data);
  IdType *proposal_data = static_cast<IdType*>(proposal->data);
  FloatType *weight_data = static_cast<FloatType*>(weight->data);

  const auto nthreads = cuda::FindNumThreads(num_nodes);
  const auto nblocks = cuda::FindNumBlocks<'x'>(BLOCKS(num_nodes, nthreads));
  // Colorize returns true once every node is matched; otherwise run one
  // propose/respond round and try again.
  while (!Colorize<IdType>(result_data, num_nodes)) {
    CUDA_KERNEL_CALL(weighted_propose_kernel, nblocks, nthreads, 0, thr_entry->stream,
                     indptr_data, indices_data, weight_data, num_nodes, proposal_data, result_data);
    CUDA_KERNEL_CALL(weighted_respond_kernel, nblocks, nthreads, 0, thr_entry->stream,
                     indptr_data, indices_data, weight_data, num_nodes, proposal_data, result_data);
  }
}
template void WeightedNeighborMatching<kDLGPU, float, int32_t>(
  const aten::CSRMatrix &csr, const NDArray weight, IdArray result);
template void WeightedNeighborMatching<kDLGPU, float, int64_t>(
  const aten::CSRMatrix &csr, const NDArray weight, IdArray result);
template void WeightedNeighborMatching<kDLGPU, double, int32_t>(
  const aten::CSRMatrix &csr, const NDArray weight, IdArray result);
template void WeightedNeighborMatching<kDLGPU, double, int64_t>(
  const aten::CSRMatrix &csr, const NDArray weight, IdArray result);

/*! \brief Unweighted neighbor matching procedure (GPU version).
 * Instead of sampling neighbors directly, every edge is assigned a random
 * weight and the weighted routine is reused. Random weights are used for
 * two reasons:
 *  1. Per-node random sampling on GPU is expensive. A global group-wise
 *     permutation (neighborhood of each node as a group, as in the CPU
 *     version) would still cost more than drawing independent random
 *     weights.
 *  2. Graphs are sparse, so each neighborhood is small, which suits this
 *     GPU formulation well.
 */
template <DLDeviceType XPU, typename IdType>
void NeighborMatching(const aten::CSRMatrix &csr, IdArray result) {
  const auto& ctx = result->ctx;
  auto device = runtime::DeviceAPI::Get(ctx);
  device->SetDevice(ctx);

  // Allocate one float weight per edge and fill it with uniform randoms.
  const int64_t num_edges = csr.indices->shape[0];
  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  NDArray weight = NDArray::Empty(
    {num_edges}, DLDataType{kDLFloat, sizeof(float) * 8, 1}, ctx);
  float *weight_data = static_cast<float*>(weight->data);

  const uint64_t seed = dgl::RandomEngine::ThreadLocal()->RandInt(UINT64_MAX);
  const auto nthreads = cuda::FindNumThreads(num_edges);
  const auto nblocks = cuda::FindNumBlocks<'x'>(BLOCKS(num_edges, nthreads));
  CUDA_KERNEL_CALL(generate_uniform_kernel, nblocks, nthreads, 0, thr_entry->stream,
                   weight_data, num_edges, seed);

  // Delegate to the weighted matching using the synthetic weights.
  WeightedNeighborMatching<XPU, float, IdType>(csr, weight, result);
}
template void NeighborMatching<kDLGPU, int32_t>(const aten::CSRMatrix &csr, IdArray result);
template void NeighborMatching<kDLGPU, int64_t>(const aten::CSRMatrix &csr, IdArray result);

}  // namespace impl
}  // namespace geometry
}  // namespace dgl