// response.cuh
#pragma once

#include <ATen/ATen.h>

#include "compat.cuh"

// Launch configuration used by the kernel launches in this header:
// THREADS threads per block, and BLOCKS(N) = ceil(N / THREADS) blocks so
// that each of N items gets its own thread.
//
// NOTE(review): BLOCKS(N) is not fully parenthesized, e.g. `2 * BLOCKS(n)` or
// `BLOCKS(a ? b : c)` expand unexpectedly -- only pass a simple expression.
// It is kept token-for-token unchanged in case other headers define the same
// macro, because a differing macro redefinition is diagnosed by the
// preprocessor; confirm before changing.
#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS

// Response step of a red/blue proposal matching over a CSR graph
// (the neighbors of node u are col[row[u]] .. col[row[u + 1] - 1]).
//
// For every node u in [0, numel) with cluster[u] == -2 ("red"), the node
// matches with the FIRST neighbor v with cluster[v] == -1 ("blue") whose
// proposal[v] == u, writing min(u, v) into both cluster[u] and cluster[v].
// A red node none of whose neighbors has a negative cluster value is
// finalized as its own cluster (cluster[u] = u). Other nodes are untouched.
//
// Launch: 1-D grid of 1-D blocks; each thread strides over nodes by the total
// thread count, so any grid size covers all nodes. No shared memory.
//
// NOTE(review): cluster entries of neighbors are read while other threads may
// write cluster; a blue v can only be matched by the single red node
// proposal[v], so presumably this is tolerated by the algorithm -- confirm.
__global__ void respond_kernel(int64_t *__restrict__ cluster,
                               const int64_t *proposal,
                               const int64_t *__restrict__ row,
                               const int64_t *__restrict__ col, size_t numel) {
  // Widen before multiplying: blockIdx.x * blockDim.x (and blockDim.x *
  // gridDim.x) would otherwise be computed in 32-bit unsigned arithmetic and
  // wrap for very large grids.
  const size_t index =
      static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;

  for (int64_t u = index; u < numel; u += stride) {
    if (cluster[u] != -2)
      continue; // Only visit red nodes.

    bool has_unmatched_neighbor = false;

    for (int64_t i = row[u]; i < row[u + 1]; i++) {
      const int64_t v = col[i];

      if (cluster[v] < 0)
        has_unmatched_neighbor = true; // Unmatched neighbor found.

      if (cluster[v] == -1 && proposal[v] == u) {
        // Match the first blue neighbor v which proposed to u; both nodes
        // take the smaller id as their cluster id.
        const int64_t id = min(u, v);
        cluster[u] = id;
        cluster[v] = id;
        break;
      }
    }

    // No neighbor is still unmatched: u becomes its own cluster.
    if (!has_unmatched_neighbor)
      cluster[u] = u;
  }
}

// Host launcher for the unweighted response step: runs respond_kernel over
// all cluster.numel() nodes, updating `cluster` in place.
//
// NOTE(review): assumes all four tensors are int64 CUDA tensors laid out as
// the kernel expects (row/col a CSR graph over cluster.numel() nodes) --
// nothing is checked here. DATA_PTR presumably comes from compat.cuh.
// The kernel is launched on the default stream and the launch status is not
// checked.
void respond(at::Tensor cluster, at::Tensor proposal, at::Tensor row,
             at::Tensor col) {
  const auto numel = cluster.numel();

  // BLOCKS(0) == 0, and a launch with zero blocks is an invalid configuration
  // whose error would surface later at an unrelated call; nothing to do.
  if (numel == 0)
    return;

  respond_kernel<<<BLOCKS(numel), THREADS>>>(
      cluster.DATA_PTR<int64_t>(), proposal.DATA_PTR<int64_t>(),
      row.DATA_PTR<int64_t>(), col.DATA_PTR<int64_t>(), numel);
}

// Weighted response step of a red/blue proposal matching over a CSR graph
// (the neighbors of node u are col[row[u]] .. col[row[u + 1] - 1], and
// weight[i] is the weight of edge i).
//
// For every node u in [0, numel) with cluster[u] == -2 ("red"), among the
// neighbors v with cluster[v] == -1 ("blue") and proposal[v] == u, the one
// with the largest edge weight is chosen (on ties the later neighbor wins,
// since the comparison is >=). Both nodes then get min(u, v) as their cluster
// id. A red node none of whose neighbors has a negative cluster value is
// finalized as its own cluster (cluster[u] = u).
//
// NOTE(review): w_max starts at 0, so edges with weight < 0 are never chosen;
// presumably this mirrors the proposal step -- confirm.
//
// Launch: 1-D grid of 1-D blocks; each thread strides over nodes by the total
// thread count, so any grid size covers all nodes. No shared memory.
template <typename scalar_t>
__global__ void respond_kernel(int64_t *__restrict__ cluster,
                               const int64_t *proposal,
                               const int64_t *__restrict__ row,
                               const int64_t *__restrict__ col,
                               const scalar_t *__restrict__ weight,
                               size_t numel) {
  // Widen before multiplying to avoid 32-bit unsigned wrap-around.
  const size_t index =
      static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;

  for (int64_t u = index; u < numel; u += stride) {
    if (cluster[u] != -2)
      continue; // Only visit red nodes.

    bool has_unmatched_neighbor = false;
    int64_t v_max = -1;
    scalar_t w_max = 0;

    for (int64_t i = row[u]; i < row[u + 1]; i++) {
      const int64_t v = col[i];

      if (cluster[v] < 0)
        has_unmatched_neighbor = true; // Unmatched neighbor found.

      if (cluster[v] == -1 && proposal[v] == u && weight[i] >= w_max) {
        // Track the maximum weighted blue neighbor v which proposed to u.
        v_max = v;
        w_max = weight[i];
      }
    }

    if (v_max >= 0) {
      // Match u with the selected neighbor; both take the smaller id.
      const int64_t id = min(u, v_max);
      cluster[u] = id;
      cluster[v_max] = id;
    }

    // No neighbor is still unmatched: u becomes its own cluster.
    if (!has_unmatched_neighbor)
      cluster[u] = u;
  }
}

// Host launcher for the weighted response step: dispatches on the dtype of
// `weight` and runs respond_kernel<scalar_t> over all cluster.numel() nodes,
// updating `cluster` in place.
//
// NOTE(review): assumes cluster/proposal/row/col are int64 CUDA tensors and
// weight holds one entry per edge in col (CSR layout) -- nothing is checked
// here. DATA_PTR presumably comes from compat.cuh. The kernel is launched on
// the default stream and the launch status is not checked.
void respond(at::Tensor cluster, at::Tensor proposal, at::Tensor row,
             at::Tensor col, at::Tensor weight) {
  const auto numel = cluster.numel();

  AT_DISPATCH_ALL_TYPES(weight.scalar_type(), "respond_kernel", [&] {
    // BLOCKS(0) == 0, and a zero-block launch is an invalid configuration.
    // The check sits inside the dispatch so unsupported weight dtypes are
    // still rejected for empty inputs, as before.
    if (numel == 0)
      return;

    respond_kernel<scalar_t><<<BLOCKS(numel), THREADS>>>(
        cluster.DATA_PTR<int64_t>(), proposal.DATA_PTR<int64_t>(),
        row.DATA_PTR<int64_t>(), col.DATA_PTR<int64_t>(),
        weight.DATA_PTR<scalar_t>(), numel);
  });
}