"research/object_detection/g3doc/evaluation_protocols.md" did not exist on "676a4f70c20020ed41b533e0c331f115eeffe9a3"
matmul_kernel.cu 3.72 KB
Newer Older
rusty1s's avatar
to csr  
rusty1s committed
1
2
3
4
#include <ATen/ATen.h>

#include <cusparse.h>

rusty1s's avatar
rusty1s committed
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
// Dispatches to the single- or double-precision cuSPARSE csrgemm routine
// according to the scalar type of TYPE and evaluates to that routine's return
// value. The remaining arguments are forwarded unchanged; inside them the alias
// `scalar_t` names the selected element type (float or double). Any other
// scalar type raises through AT_ERROR.
#define CSRGEMM(TYPE, ...)                                                     \
  [&] {                                                                        \
    const at::Type &csrgemm_type = TYPE;                                       \
    const auto csrgemm_scalar = csrgemm_type.scalarType();                     \
    if (csrgemm_scalar == at::ScalarType::Float) {                             \
      using scalar_t = float;                                                  \
      return cusparseScsrgemm(__VA_ARGS__);                                    \
    }                                                                          \
    if (csrgemm_scalar == at::ScalarType::Double) {                            \
      using scalar_t = double;                                                 \
      return cusparseDcsrgemm(__VA_ARGS__);                                    \
    }                                                                          \
    AT_ERROR("Not implemented for '%s'", csrgemm_type.toString());             \
  }()

rusty1s's avatar
to csr  
rusty1s committed
22
23
24
25
26
27
28
29
// File-local cuSPARSE handle shared by all calls in this translation unit.
// It stays 0 until init_cusparse() has created it successfully.
static cusparseHandle_t cusparse_handle = 0;

// Creates the shared cuSPARSE handle on first use; later calls are no-ops.
// Raises through AT_ERROR if cusparseCreate does not report success, instead
// of silently leaving an unusable handle behind. No locking is performed here.
static void init_cusparse() {
  if (cusparse_handle != 0)
    return;

  cusparseStatus_t status = cusparseCreate(&cusparse_handle);
  if (status != CUSPARSE_STATUS_SUCCESS) {
    // Reset so that a subsequent call attempts the creation again.
    cusparse_handle = 0;
    AT_ERROR("cusparseCreate failed with cuSPARSE status %d",
             static_cast<int>(status));
  }
}

rusty1s's avatar
rusty1s committed
30
31
32
// Sparse-sparse matrix product C = A * B computed with cuSPARSE csrgemm.
//
// A is m x k and B is k x n, both passed in coordinate form: `indexA[0]` /
// `indexB[0]` are read as row indices, `indexA[1]` / `indexB[1]` as column
// indices, and `valueA` / `valueB` hold the matching nonzero values (float or
// double; other types raise inside CSRGEMM).
// NOTE(review): cusparseXcoo2csr expects the row indices to be sorted --
// confirm that callers coalesce/sort their inputs.
//
// Returns (indexC, valueC): indexC stacks the row and column indices of C and
// is converted to Long, valueC has the value type of `valueA`.
// Raises through AT_ERROR when any cuSPARSE call reports a failure.
std::tuple<at::Tensor, at::Tensor>
spspmm_cuda(at::Tensor indexA, at::Tensor valueA, at::Tensor indexB,
            at::Tensor valueB, int m, int k, int n) {
  init_cusparse();

  // cuSPARSE reports failures only through return codes; turn them into
  // errors instead of continuing with undefined device buffers.
  auto check = [](cusparseStatus_t status, const char *what) {
    if (status != CUSPARSE_STATUS_SUCCESS)
      AT_ERROR("%s failed with cuSPARSE status %d", what,
               static_cast<int>(status));
  };

  // Owns the matrix descriptor so it is released on every exit path,
  // including exceptions thrown by `check` or by tensor allocation.
  struct DescrGuard {
    cusparseMatDescr_t descr = 0;
    ~DescrGuard() {
      if (descr != 0)
        cusparseDestroyMatDescr(descr);
    }
  } guard;

  // The cuSPARSE entry points used here take 32-bit counts.
  const int nnzA = static_cast<int>(valueA.size(0));
  const int nnzB = static_cast<int>(valueB.size(0));

  indexA = indexA.toType(at::kInt);
  indexB = indexB.toType(at::kInt);

  // Convert A to CSR format. A has m rows, so coo2csr must be told `m`; it
  // then fills all m + 1 row pointers (including the trailing nnz entry), so
  // no manual patch of the last element is needed.
  auto row_ptrA = at::empty(m + 1, indexA.type());
  check(cusparseXcoo2csr(cusparse_handle, indexA[0].data<int>(), nnzA, m,
                         row_ptrA.data<int>(), CUSPARSE_INDEX_BASE_ZERO),
        "cusparseXcoo2csr (A)");
  auto colA = indexA[1];

  // Convert B to CSR format. B has k rows.
  auto row_ptrB = at::empty(k + 1, indexB.type());
  check(cusparseXcoo2csr(cusparse_handle, indexB[0].data<int>(), nnzB, k,
                         row_ptrB.data<int>(), CUSPARSE_INDEX_BASE_ZERO),
        "cusparseXcoo2csr (B)");
  auto colB = indexB[1];

  // One general, zero-based descriptor is shared by A, B and C.
  check(cusparseCreateMatDescr(&guard.descr), "cusparseCreateMatDescr");
  cusparseMatDescr_t descr = guard.descr;
  check(cusparseSetMatType(descr, CUSPARSE_MATRIX_TYPE_GENERAL),
        "cusparseSetMatType");
  check(cusparseSetMatIndexBase(descr, CUSPARSE_INDEX_BASE_ZERO),
        "cusparseSetMatIndexBase");

  // First pass: compute the row pointers of C and its number of nonzeros so
  // that the output buffers can be sized.
  int nnzC = 0;
  auto row_ptrC = at::empty(m + 1, indexB.type());
  check(cusparseXcsrgemmNnz(
            cusparse_handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
            CUSPARSE_OPERATION_NON_TRANSPOSE, m, n, k, descr, nnzA,
            row_ptrA.data<int>(), colA.data<int>(), descr, nnzB,
            row_ptrB.data<int>(), colB.data<int>(), descr,
            row_ptrC.data<int>(), &nnzC),
        "cusparseXcsrgemmNnz");

  auto colC = at::empty(nnzC, indexA.type());
  auto valueC = at::empty(nnzC, valueA.type());

  // Second pass: compute the column indices and values of C.
  check(CSRGEMM(valueC.type(), cusparse_handle,
                CUSPARSE_OPERATION_NON_TRANSPOSE,
                CUSPARSE_OPERATION_NON_TRANSPOSE, m, n, k, descr, nnzA,
                valueA.data<scalar_t>(), row_ptrA.data<int>(),
                colA.data<int>(), descr, nnzB, valueB.data<scalar_t>(),
                row_ptrB.data<int>(), colB.data<int>(), descr,
                valueC.data<scalar_t>(), row_ptrC.data<int>(),
                colC.data<int>()),
        "cusparse csrgemm");

  // Expand the CSR row pointers of C back into per-entry row indices.
  auto rowC = at::empty(nnzC, indexA.type());
  check(cusparseXcsr2coo(cusparse_handle, row_ptrC.data<int>(), nnzC, m,
                         rowC.data<int>(), CUSPARSE_INDEX_BASE_ZERO),
        "cusparseXcsr2coo");

  auto indexC = at::stack({rowC, colC}, 0).toType(at::kLong);

  return std::make_tuple(indexC, valueC);
}