// graclus_cuda.cu
#include "graclus_cuda.h"

#include <ATen/cuda/CUDAContext.h>

#include "utils.cuh"

// Threads per block used by every kernel launch in this file.
#define THREADS 1024
// Number of blocks needed to cover N elements with THREADS threads each
// (ceiling division). Both the argument and the whole expansion are
// parenthesized so the macro composes safely inside larger expressions.
#define BLOCKS(N) (((N) + THREADS - 1) / THREADS)
// Probability passed to the Bernoulli draw that colors unmatched nodes.
#define BLUE_P 0.53406

// Device-resident flag: true while no unmatched node has been seen in the
// current coloring pass. colorize_kernel clears it, colorize() reads it back.
__device__ bool done_d;

// Resets the flag to true. Launched with a single thread before each
// coloring pass.
__global__ void init_done_kernel() {
  done_d = true;
}
// Re-colors every node that is still unmatched and flags that work remains.
//
//   out       - per-node state; negative entries are unmatched nodes.
//   bernoulli - per-node random draw (the host passes the result of a
//               Bernoulli sample, so values are expected to be 0 or 1).
//   numel     - number of entries in `out` / `bernoulli`.
//
// An unmatched entry becomes `bernoulli[i] - 2`, i.e. -2 for a 0 draw and -1
// for a 1 draw (the propose/respond kernels treat -1 as blue and -2 as red).
// Any thread that finds an unmatched node clears `done_d`.
//
// Expects a 1-D launch with at least `numel` threads in total.
__global__ void colorize_kernel(int64_t *out, const float *bernoulli,
                                int64_t numel) {
  // Widen before multiplying: blockIdx.x * blockDim.x would otherwise be
  // evaluated in 32-bit unsigned arithmetic and wrap for very large inputs.
  const int64_t thread_idx =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (thread_idx >= numel)
    return;

  if (out[thread_idx] < 0) {
    out[thread_idx] = static_cast<int64_t>(bernoulli[thread_idx]) - 2;
    done_d = false; // Every writer stores the same value.
  }
}

// Assigns a random color to every unmatched node of `out` and reports whether
// the matching is complete.
//
//   out - 1-D int64 CUDA tensor of node states, modified in place.
//
// Returns true when no entry of `out` was negative (nothing left to match),
// false otherwise. Blocks the host until the current stream has finished the
// coloring pass. Throws if a CUDA call or kernel launch fails.
bool colorize(torch::Tensor out) {
  auto stream = at::cuda::getCurrentCUDAStream();
  init_done_kernel<<<1, 1, 0, stream>>>();
  AT_CUDA_CHECK(cudaGetLastError());

  auto numel = out.size(0);
  auto props = torch::full(numel, BLUE_P, out.options().dtype(torch::kFloat));
  auto bernoulli = props.bernoulli();

  // A zero-sized grid is an invalid launch configuration, so skip the kernel
  // for empty inputs; the flag then keeps its initial value of true.
  if (numel > 0) {
    colorize_kernel<<<BLOCKS(numel), THREADS, 0, stream>>>(
        out.data_ptr<int64_t>(), bernoulli.data_ptr<float>(), numel);
    AT_CUDA_CHECK(cudaGetLastError());
  }

  // Read the flag back on the same stream the kernels ran on. A plain
  // cudaMemcpyFromSymbol runs on the legacy default stream, which does not
  // wait for work queued on non-blocking streams and could therefore observe
  // the flag before colorize_kernel has finished.
  bool done_h = false;
  AT_CUDA_CHECK(cudaMemcpyFromSymbolAsync(&done_h, done_d, sizeof(done_h), 0,
                                          cudaMemcpyDeviceToHost,
                                          stream.stream()));
  AT_CUDA_CHECK(cudaStreamSynchronize(stream.stream()));
  return done_h;
}

// Unweighted proposal step: each blue node (state -1) proposes to the first
// red neighbor (state -2) in its adjacency list.
//
//   out      - per-node state; read for all neighbors, written for this node.
//   proposal - proposal[u] receives the chosen red neighbor; left untouched
//              when u has no red neighbor.
//   rowptr   - CSR row offsets; neighbors of u are col[rowptr[u] ..
//              rowptr[u + 1]).
//   col      - CSR column indices (neighbor node ids).
//   numel    - number of nodes.
//
// A blue node whose neighbors are all already matched (no negative state
// seen) becomes its own cluster: out[u] = u.
//
// Expects a 1-D launch with at least `numel` threads in total.
__global__ void propose_kernel(int64_t *out, int64_t *proposal,
                               const int64_t *rowptr, const int64_t *col,
                               int64_t numel) {
  // Widen before multiplying: blockIdx.x * blockDim.x would otherwise be
  // evaluated in 32-bit unsigned arithmetic and wrap for very large inputs.
  const int64_t u =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (u >= numel || out[u] != -1)
    return; // Only visit blue nodes.

  bool has_unmatched_neighbor = false;

  const int64_t row_end = rowptr[u + 1];
  for (int64_t i = rowptr[u]; i < row_end; i++) {
    const int64_t v = col[i];
    const int64_t state = out[v];

    if (state < 0)
      has_unmatched_neighbor = true; // Unmatched neighbor found.

    if (state == -2) {
      proposal[u] = v; // Propose to first red neighbor.
      break;
    }
  }

  if (!has_unmatched_neighbor)
    out[u] = u;
}

// Weighted proposal step: each blue node (state -1) proposes to its red
// neighbor (state -2) with the largest edge weight.
//
//   out      - per-node state; read for all neighbors, written for this node.
//   proposal - proposal[u] receives the chosen red neighbor, or -1 when u
//              has no red neighbor.
//   rowptr   - CSR row offsets; neighbors of u are col[rowptr[u] ..
//              rowptr[u + 1]).
//   col      - CSR column indices (neighbor node ids).
//   weight   - edge weights, indexed like `col`.
//   numel    - number of nodes.
//
// On equal weights the later neighbor in the adjacency list wins. A blue
// node whose neighbors are all already matched becomes its own cluster:
// out[u] = u.
//
// Expects a 1-D launch with at least `numel` threads in total.
template <typename scalar_t>
__global__ void weighted_propose_kernel(int64_t *out, int64_t *proposal,
                                        const int64_t *rowptr,
                                        const int64_t *col,
                                        const scalar_t *weight, int64_t numel) {
  // Widen before multiplying: blockIdx.x * blockDim.x would otherwise be
  // evaluated in 32-bit unsigned arithmetic and wrap for very large inputs.
  const int64_t u =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (u >= numel || out[u] != -1)
    return; // Only visit blue nodes.

  bool has_unmatched_neighbor = false;
  int64_t v_max = -1;
  scalar_t w_max = 0;

  const int64_t row_end = rowptr[u + 1];
  for (int64_t i = rowptr[u]; i < row_end; i++) {
    const int64_t v = col[i];
    const int64_t state = out[v];

    if (state < 0)
      has_unmatched_neighbor = true; // Unmatched neighbor found.

    // Find maximum weighted red neighbor. The first red neighbor is always
    // accepted (v_max < 0) so that negative or NaN weights still yield a
    // candidate; comparing against an initial maximum of 0 would skip such
    // edges forever and the node could never be matched.
    if (state == -2 && (v_max < 0 || weight[i] >= w_max)) {
      v_max = v;
      w_max = weight[i];
    }
  }

  proposal[u] = v_max; // Propose.

  if (!has_unmatched_neighbor)
    out[u] = u;
}

// Launches the proposal step on the current CUDA stream.
//
//   out             - 1-D int64 node states (modified in place).
//   proposal        - 1-D int64 buffer receiving each blue node's proposal.
//   rowptr, col     - CSR adjacency (int64).
//   optional_weight - optional per-edge weights; when present the weighted
//                     kernel is dispatched on the weight dtype (all standard
//                     types plus Half), otherwise the unweighted kernel runs.
//
// Does nothing for an empty `out`. Throws if the kernel launch fails.
void propose(torch::Tensor out, torch::Tensor proposal, torch::Tensor rowptr,
             torch::Tensor col,
             torch::optional<torch::Tensor> optional_weight) {

  const int64_t numel = out.numel();
  if (numel == 0)
    return; // A zero-sized grid is an invalid launch configuration.

  auto stream = at::cuda::getCurrentCUDAStream();

  if (!optional_weight.has_value()) {
    propose_kernel<<<BLOCKS(numel), THREADS, 0, stream>>>(
        out.data_ptr<int64_t>(), proposal.data_ptr<int64_t>(),
        rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(), numel);
  } else {
    auto weight = optional_weight.value();
    auto scalar_type = weight.scalar_type();
    AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, scalar_type, "_", [&] {
      weighted_propose_kernel<scalar_t>
          <<<BLOCKS(numel), THREADS, 0, stream>>>(
              out.data_ptr<int64_t>(), proposal.data_ptr<int64_t>(),
              rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(),
              weight.data_ptr<scalar_t>(), numel);
    });
  }

  // Kernel launches do not report errors themselves; surface them here.
  AT_CUDA_CHECK(cudaGetLastError());
}

// Unweighted response step: each red node (state -2) accepts the first blue
// neighbor (state -1) that proposed to it.
//
//   out      - per-node state; on a match both endpoints receive
//              min(u, v) as their cluster id.
//   proposal - proposal[v] is the node that blue node v proposed to.
//   rowptr   - CSR row offsets; neighbors of u are col[rowptr[u] ..
//              rowptr[u + 1]).
//   col      - CSR column indices (neighbor node ids).
//   numel    - number of nodes.
//
// A red node whose neighbors are all already matched (no negative state
// seen) becomes its own cluster: out[u] = u.
//
// Expects a 1-D launch with at least `numel` threads in total.
__global__ void respond_kernel(int64_t *out, const int64_t *proposal,
                               const int64_t *rowptr, const int64_t *col,
                               int64_t numel) {
  // Widen before multiplying: blockIdx.x * blockDim.x would otherwise be
  // evaluated in 32-bit unsigned arithmetic and wrap for very large inputs.
  const int64_t u =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (u >= numel || out[u] != -2)
    return; // Only visit red nodes.

  bool has_unmatched_neighbor = false;

  const int64_t row_end = rowptr[u + 1];
  for (int64_t i = rowptr[u]; i < row_end; i++) {
    const int64_t v = col[i];
    const int64_t state = out[v];

    if (state < 0)
      has_unmatched_neighbor = true; // Unmatched neighbor found.

    if (state == -1 && proposal[v] == u) {
      // Match first blue neighbor v which proposed to u.
      const int64_t cluster = min(u, v);
      out[u] = cluster;
      out[v] = cluster;
      break;
    }
  }

  if (!has_unmatched_neighbor)
    out[u] = u;
}

// Weighted response step: each red node (state -2) accepts, among the blue
// neighbors (state -1) that proposed to it, the one with the largest edge
// weight.
//
//   out      - per-node state; on a match both endpoints receive
//              min(u, v) as their cluster id.
//   proposal - proposal[v] is the node that blue node v proposed to.
//   rowptr   - CSR row offsets; neighbors of u are col[rowptr[u] ..
//              rowptr[u + 1]).
//   col      - CSR column indices (neighbor node ids).
//   weight   - edge weights, indexed like `col`.
//   numel    - number of nodes.
//
// On equal weights the later neighbor in the adjacency list wins. A red
// node whose neighbors are all already matched becomes its own cluster:
// out[u] = u.
//
// Expects a 1-D launch with at least `numel` threads in total.
template <typename scalar_t>
__global__ void weighted_respond_kernel(int64_t *out, const int64_t *proposal,
                                        const int64_t *rowptr,
                                        const int64_t *col,
                                        const scalar_t *weight, int64_t numel) {
  // Widen before multiplying: blockIdx.x * blockDim.x would otherwise be
  // evaluated in 32-bit unsigned arithmetic and wrap for very large inputs.
  const int64_t u =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (u >= numel || out[u] != -2)
    return; // Only visit red nodes.

  bool has_unmatched_neighbor = false;
  int64_t v_max = -1;
  scalar_t w_max = 0;

  const int64_t row_end = rowptr[u + 1];
  for (int64_t i = rowptr[u]; i < row_end; i++) {
    const int64_t v = col[i];
    const int64_t state = out[v];

    if (state < 0)
      has_unmatched_neighbor = true; // Unmatched neighbor found.

    // Find maximum weighted blue neighbor v which proposed to u. The first
    // proposer is always accepted (v_max < 0) so that negative or NaN weights
    // still yield a candidate; comparing against an initial maximum of 0
    // would reject such proposals forever.
    if (state == -1 && proposal[v] == u &&
        (v_max < 0 || weight[i] >= w_max)) {
      v_max = v;
      w_max = weight[i];
    }
  }

  if (v_max >= 0) {
    const int64_t cluster = min(u, v_max); // Match neighbors.
    out[u] = cluster;
    out[v_max] = cluster;
  }

  if (!has_unmatched_neighbor)
    out[u] = u;
}

// Launches the response step on the current CUDA stream.
//
//   out             - 1-D int64 node states (modified in place).
//   proposal        - 1-D int64 proposals written by the proposal step.
//   rowptr, col     - CSR adjacency (int64).
//   optional_weight - optional per-edge weights; when present the weighted
//                     kernel is dispatched on the weight dtype (all standard
//                     types plus Half), otherwise the unweighted kernel runs.
//
// Does nothing for an empty `out`. Throws if the kernel launch fails.
void respond(torch::Tensor out, torch::Tensor proposal, torch::Tensor rowptr,
             torch::Tensor col,
             torch::optional<torch::Tensor> optional_weight) {

  const int64_t numel = out.numel();
  if (numel == 0)
    return; // A zero-sized grid is an invalid launch configuration.

  auto stream = at::cuda::getCurrentCUDAStream();

  if (!optional_weight.has_value()) {
    respond_kernel<<<BLOCKS(numel), THREADS, 0, stream>>>(
        out.data_ptr<int64_t>(), proposal.data_ptr<int64_t>(),
        rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(), numel);
  } else {
    auto weight = optional_weight.value();
    auto scalar_type = weight.scalar_type();
    AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, scalar_type, "_", [&] {
      weighted_respond_kernel<scalar_t>
          <<<BLOCKS(numel), THREADS, 0, stream>>>(
              out.data_ptr<int64_t>(), proposal.data_ptr<int64_t>(),
              rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(),
              weight.data_ptr<scalar_t>(), numel);
    });
  }

  // Kernel launches do not report errors themselves; surface them here.
  AT_CUDA_CHECK(cudaGetLastError());
}

// Greedy graph clustering on the GPU: repeatedly colors the unmatched nodes,
// lets blue nodes propose to red neighbors and red nodes accept, until every
// node is matched.
//
//   rowptr          - 1-D int64 CUDA tensor of CSR row offsets
//                     (num_nodes + 1 entries).
//   col             - 1-D int64 CUDA tensor of CSR column indices.
//   optional_weight - optional 1-D CUDA tensor with one weight per entry of
//                     `col`.
//
// Returns a 1-D int64 tensor with one cluster id per node: matched pairs
// share min(u, v), unmatched nodes keep their own index.
torch::Tensor graclus_cuda(torch::Tensor rowptr, torch::Tensor col,
                           torch::optional<torch::Tensor> optional_weight) {
  CHECK_CUDA(rowptr);
  CHECK_CUDA(col);
  CHECK_INPUT(rowptr.dim() == 1 && col.dim() == 1);
  // rowptr holds num_nodes + 1 offsets; an empty tensor would give -1 nodes.
  CHECK_INPUT(rowptr.numel() > 0);
  if (optional_weight.has_value()) {
    CHECK_CUDA(optional_weight.value());
    CHECK_INPUT(optional_weight.value().dim() == 1);
    CHECK_INPUT(optional_weight.value().numel() == col.numel());
    // The kernels index raw pointers and ignore strides.
    optional_weight = optional_weight.value().contiguous();
  }
  AT_CUDA_CHECK(cudaSetDevice(rowptr.get_device()));

  // The kernels index raw pointers and ignore strides.
  rowptr = rowptr.contiguous();
  col = col.contiguous();

  int64_t num_nodes = rowptr.numel() - 1;
  auto out = torch::full(num_nodes, -1, rowptr.options());
  auto proposal = torch::full(num_nodes, -1, rowptr.options());

  // colorize() returns true once no node is left unmatched.
  while (!colorize(out)) {
    propose(out, proposal, rowptr, col, optional_weight);
    respond(out, proposal, rowptr, col, optional_weight);
  }

  return out;
}