"docs/en/notes/compatibility.md" did not exist on "cbf194fa4b9b4b9d9fc3a25f4872062addcd8c9c"
proposal.cuh 2.7 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
#pragma once

#include <ATen/ATen.h>

rusty1s's avatar
rusty1s committed
5
6
#include "compat.cuh"

rusty1s's avatar
rusty1s committed
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
// Number of threads per block used by every launch in this file.
#define THREADS 1024
// Number of blocks needed to cover N elements with THREADS threads each
// (ceiling division). Both the argument and the whole expansion are
// parenthesised so the macro composes safely with surrounding operators,
// e.g. `x / BLOCKS(n)` or `BLOCKS(a | b)`. Note that BLOCKS(0) evaluates to 0.
#define BLOCKS(N) (((N) + THREADS - 1) / THREADS)

// Proposal step for unweighted graphs.
//
// Node state is read from `cluster`: -1 marks a "blue" node, -2 a "red" node,
// and any negative value counts as unmatched. Each blue node u scans its
// neighbors col[row[u]] .. col[row[u + 1] - 1] and
//   * writes the first red neighbor it encounters into proposal[u]
//     (proposal[u] is left untouched if there is none), and
//   * sets cluster[u] = u if none of the scanned neighbors is unmatched.
//
// `row` is indexed up to row[numel], so it must hold numel + 1 offsets.
// The kernel is a grid-stride loop and is therefore correct for any 1D launch
// configuration.
__global__ void propose_kernel(int64_t *__restrict__ cluster, int64_t *proposal,
                               int64_t *__restrict__ row,
                               int64_t *__restrict__ col, size_t numel) {
  // blockIdx/blockDim/gridDim are 32-bit; widen *before* multiplying so the
  // products cannot wrap (a stride that wraps to 0 would loop forever).
  const size_t index =
      static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;

  for (size_t u = index; u < numel; u += stride) {
    if (cluster[u] != -1)
      continue; // Only visit blue nodes.

    bool has_unmatched_neighbor = false;

    const int64_t row_end = row[u + 1];
    for (int64_t i = row[u]; i < row_end; i++) {
      const int64_t v = col[i];
      const int64_t c = cluster[v];

      if (c < 0)
        has_unmatched_neighbor = true; // Unmatched neighbor found.

      if (c == -2) {
        proposal[u] = v; // Propose to first red neighbor.
        break;
      }
    }

    // No unmatched neighbor among those scanned: u becomes its own cluster.
    if (!has_unmatched_neighbor)
      cluster[u] = static_cast<int64_t>(u);
  }
}

// Host wrapper: runs the unweighted proposal kernel over every element of
// `cluster`, using BLOCKS(numel) blocks of THREADS threads.
//
// All four tensors are accessed as raw int64_t device pointers; nothing here
// validates dtype, device or contiguity, so callers must guarantee them.
// NOTE(review): DATA_PTR presumably comes from "compat.cuh" and maps to the
// tensor data accessor of the installed ATen version - confirm there.
void propose(at::Tensor cluster, at::Tensor proposal, at::Tensor row,
             at::Tensor col) {
  const auto numel = cluster.numel();

  // BLOCKS(0) is 0, and a zero-sized grid is an invalid launch configuration
  // whose error would only surface at some later, unrelated CUDA call.
  if (numel == 0)
    return;

  propose_kernel<<<BLOCKS(numel), THREADS>>>(
      cluster.DATA_PTR<int64_t>(), proposal.DATA_PTR<int64_t>(),
      row.DATA_PTR<int64_t>(), col.DATA_PTR<int64_t>(), numel);
}

// Proposal step for weighted graphs.
//
// Node state is read from `cluster`: -1 marks a "blue" node, -2 a "red" node,
// and any negative value counts as unmatched. Each blue node u scans all of
// its neighbors col[row[u]] .. col[row[u + 1] - 1] (with edge weights
// weight[i] at the same positions) and
//   * writes the red neighbor with the largest edge weight into proposal[u],
//     or -1 if no red neighbor qualifies, and
//   * sets cluster[u] = u if none of its neighbors is unmatched.
//
// The running maximum starts at 0 and is compared with `>=`, so edges with a
// negative weight are never selected and ties go to the neighbor seen last.
//
// `row` is indexed up to row[numel], so it must hold numel + 1 offsets.
// The kernel is a grid-stride loop and is therefore correct for any 1D launch
// configuration.
template <typename scalar_t>
__global__ void propose_kernel(int64_t *__restrict__ cluster, int64_t *proposal,
                               int64_t *__restrict__ row,
                               int64_t *__restrict__ col,
                               scalar_t *__restrict__ weight, size_t numel) {
  // blockIdx/blockDim/gridDim are 32-bit; widen *before* multiplying so the
  // products cannot wrap (a stride that wraps to 0 would loop forever).
  const size_t index =
      static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;

  for (size_t u = index; u < numel; u += stride) {
    if (cluster[u] != -1)
      continue; // Only visit blue nodes.

    bool has_unmatched_neighbor = false;
    int64_t v_max = -1;
    scalar_t w_max = 0;

    const int64_t row_end = row[u + 1];
    for (int64_t i = row[u]; i < row_end; i++) {
      const int64_t v = col[i];
      const int64_t c = cluster[v];

      if (c < 0)
        has_unmatched_neighbor = true; // Unmatched neighbor found.

      // Find maximum weighted red neighbor.
      if (c == -2) {
        const scalar_t w = weight[i];
        if (w >= w_max) {
          v_max = v;
          w_max = w;
        }
      }
    }

    proposal[u] = v_max; // Propose (-1 when there is no candidate).

    // No unmatched neighbor at all: u becomes its own cluster.
    if (!has_unmatched_neighbor)
      cluster[u] = static_cast<int64_t>(u);
  }
}

// Host wrapper: runs the weighted proposal kernel over every element of
// `cluster`, using BLOCKS(numel) blocks of THREADS threads. The kernel is
// instantiated for the scalar type of `weight` via AT_DISPATCH_ALL_TYPES.
//
// `cluster`, `proposal`, `row` and `col` are accessed as raw int64_t device
// pointers; nothing here validates dtype, device or contiguity, so callers
// must guarantee them.
// NOTE(review): DATA_PTR presumably comes from "compat.cuh" and maps to the
// tensor data accessor of the installed ATen version - confirm there.
void propose(at::Tensor cluster, at::Tensor proposal, at::Tensor row,
             at::Tensor col, at::Tensor weight) {
  const auto numel = cluster.numel();

  // BLOCKS(0) is 0, and a zero-sized grid is an invalid launch configuration
  // whose error would only surface at some later, unrelated CUDA call.
  if (numel == 0)
    return;

  AT_DISPATCH_ALL_TYPES(weight.scalar_type(), "propose_kernel", [&] {
    propose_kernel<scalar_t><<<BLOCKS(numel), THREADS>>>(
        cluster.DATA_PTR<int64_t>(), proposal.DATA_PTR<int64_t>(),
        row.DATA_PTR<int64_t>(), col.DATA_PTR<int64_t>(),
        weight.DATA_PTR<scalar_t>(), numel);
  });
}