// THCResponse.cuh: CUDA kernels in which red nodes respond to proposals from blue neighbors.
#ifndef THC_RESPONSE_INC
#define THC_RESPONSE_INC

#include "common.cuh"

// Response step for an unweighted graph.
//
// Every node i with color[i] == -2 ("red") scans its neighbor list
// col[cumDegree[i] - degree[i] .. cumDegree[i]) and pairs itself with the first
// neighbor that has color -1 ("blue") and whose proposal prop[neighbor] equals i.
// Both endpoints are then colored with min(i, neighbor). If none of the scanned
// neighbors had a negative color, the node is marked dead by setting color[i] = i.
//
// `row` is part of the signature but is not read here. The iteration over i is
// provided by KERNEL_LOOP, which is defined outside this file section.
__global__ void responseKernel(int64_t *color, int64_t *prop, int64_t *row, int64_t *col,
                               int64_t *degree, int64_t *cumDegree, ptrdiff_t nNodes) {
  KERNEL_LOOP(i, nNodes) {
    // Anything that is not red takes no part in this step.
    if (color[i] != -2) { continue; }

    // Half-open range of edge slots that belong to node i.
    const ptrdiff_t edgeEnd = cumDegree[i];
    const ptrdiff_t edgeBegin = edgeEnd - degree[i];

    bool hasUnmatchedNeighbor = false;
    for (ptrdiff_t e = edgeBegin; e < edgeEnd; ++e) {
      const ptrdiff_t neighbor = col[e];

      // A negative color means the neighbor is still unmatched, so i is not dead.
      if (color[neighbor] < 0) { hasUnmatchedNeighbor = true; }

      // Skip neighbors that are not blue or that proposed to somebody else.
      if (color[neighbor] != -1 || prop[neighbor] != i) { continue; }

      // First blue neighbor that proposed to i: match the pair and stop scanning.
      const ptrdiff_t label = min(i, neighbor);
      color[i] = label;
      color[neighbor] = label;
      break;
    }

    // No unmatched neighbor was seen: mark the node as dead.
    if (!hasUnmatchedNeighbor) { color[i] = i; }
  }
}

rusty1s's avatar
rusty1s committed
24
25
26
27
28
// Response step for a weighted graph.
//
// Every node i with color[i] == -2 ("red") scans its neighbor list
// col[cumDegree[i] - degree[i] .. cumDegree[i]) together with the edge weights
// stored at the same positions in `weight`. Among the neighbors that have
// color -1 ("blue") and whose proposal prop[neighbor] equals i, the one with the
// largest weight is selected, and both endpoints are colored with
// min(i, neighbor). If none of the scanned neighbors had a negative color, the
// node is marked dead by setting color[i] = i.
//
// The running maximum starts at zero and candidates are accepted through
// THCNumerics<T>::gt, so an edge whose weight does not compare greater than zero
// is never selected.
//
// `row` is part of the signature but is not read here. The iteration over i is
// provided by KERNEL_LOOP, which is defined outside this file section.
template<typename T>
__global__ void weightedResponseKernel(int64_t *color, int64_t *prop, int64_t *row, int64_t *col,
                                       T *weight, int64_t *degree, int64_t *cumDegree,
                                       ptrdiff_t nNodes) {
  KERNEL_LOOP(i, nNodes) {
    if (color[i] != -2) { continue; }  // Only visit red nodes.

    bool isDead = true;
    T maxWeight = ScalarConvert<int, T>::to(0);
    ptrdiff_t matchedValue = -1;  // Best proposing blue neighbor so far, -1 if none.

    for (ptrdiff_t e = cumDegree[i] - degree[i]; e < cumDegree[i]; e++) {
      const ptrdiff_t c = col[e];
      const T w = weight[e];
      if (isDead && color[c] < 0) { isDead = false; }  // Unmatched neighbor found.
      // Track the maximum weighted blue neighbor who proposed to i.
      if (color[c] == -1 && prop[c] == i && THCNumerics<T>::gt(w, maxWeight)) {
        matchedValue = c;
        maxWeight = w;
      }
    }

    if (matchedValue >= 0) {
      // Color the selected partner itself. Indexing with the loop variable here
      // would update whichever neighbor happened to be visited last.
      const ptrdiff_t label = min(i, matchedValue);
      color[i] = label;
      color[matchedValue] = label;
    }
    if (isDead) { color[i] = i; }  // Mark node as dead.
  }
}

rusty1s's avatar
rusty1s committed
51
52
53
54
55
56
57
58
59
// Host entry point for the unweighted response step.
//
// Fetches the data pointers of the given long tensors and hands them to
// responseKernel through KERNEL_RUN, using the number of elements of `color`
// as the problem size. KERNEL_RUN is defined outside this file section.
void THCTensor_response(THCState *state, THCudaLongTensor *color, THCudaLongTensor *prop,
                        THCudaLongTensor *row, THCudaLongTensor *col, THCudaLongTensor *degree,
                        THCudaLongTensor *cumDegree) {
  int64_t *colorData = THCudaLongTensor_data(state, color);
  int64_t *propData = THCudaLongTensor_data(state, prop);
  int64_t *rowData = THCudaLongTensor_data(state, row);
  int64_t *colData = THCudaLongTensor_data(state, col);
  int64_t *degreeData = THCudaLongTensor_data(state, degree);
  int64_t *cumDegreeData = THCudaLongTensor_data(state, cumDegree);

  KERNEL_RUN(responseKernel, THCudaLongTensor_nElement(state, color),
             colorData, propData, rowData, colData, degreeData, cumDegreeData);
}

rusty1s's avatar
rusty1s committed
60
61
62
#include "generic/THCResponse.cuh"
#include "THC/THCGenerateAllTypes.h"

rusty1s's avatar
rusty1s committed
63
#endif  // THC_RESPONSE_INC