spspmm_kernel.cu 6.02 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
#include <ATen/cuda/CUDAContext.h>
rusty1s's avatar
rusty1s committed
2
#include <torch/extension.h>
rusty1s's avatar
rusty1s committed
3

rusty1s's avatar
to csr  
rusty1s committed
4
5
#include <cusparse.h>

rusty1s's avatar
rusty1s committed
6
7
#include "compat.cuh"

rusty1s's avatar
rusty1s committed
8
9
10
// Dispatches on a torch::ScalarType to the matching cuSPARSE
// `csrgemm2_bufferSizeExt` routine.
//
// TYPE is the runtime scalar type; the variadic argument is a callable
// (typically a lambda) that is invoked with the following names in scope:
//   - `scalar_t`: `float` for ScalarType::Float, `double` for
//     ScalarType::Double.
//   - `cusparsecsrgemm2_bufferSizeExt`: reference to
//     `cusparseScsrgemm2_bufferSizeExt` or `cusparseDcsrgemm2_bufferSizeExt`.
// The callable's result is returned from the enclosing immediately-invoked
// lambda. Any other scalar type raises via AT_ERROR.
//
// Every line of the macro body must end in a continuation backslash; do not
// place `//` comments or unrelated lines inside it.
#define AT_DISPATCH_CUSPARSE_CSR_GEMM2_BUFFER_SIZE_EXT_TYPES(TYPE, ...)        \
  [&] {                                                                        \
    switch (TYPE) {                                                            \
    case torch::ScalarType::Float: {                                           \
      using scalar_t = float;                                                  \
      const auto &cusparsecsrgemm2_bufferSizeExt =                             \
          cusparseScsrgemm2_bufferSizeExt;                                     \
      return __VA_ARGS__();                                                    \
    }                                                                          \
    case torch::ScalarType::Double: {                                          \
      using scalar_t = double;                                                 \
      const auto &cusparsecsrgemm2_bufferSizeExt =                             \
          cusparseDcsrgemm2_bufferSizeExt;                                     \
      return __VA_ARGS__();                                                    \
    }                                                                          \
    default:                                                                   \
      AT_ERROR("Not implemented for '", toString(TYPE), "'");                  \
    }                                                                          \
  }()
rusty1s's avatar
rusty1s committed
27

rusty1s's avatar
rusty1s committed
28
// Dispatches on a torch::ScalarType to the matching cuSPARSE `csrgemm2`
// routine.
//
// TYPE is the runtime scalar type; the variadic argument is a callable
// (typically a lambda) that is invoked with the following names in scope:
//   - `scalar_t`: `float` for ScalarType::Float, `double` for
//     ScalarType::Double.
//   - `cusparsecsrgemm2`: reference to `cusparseScsrgemm2` or
//     `cusparseDcsrgemm2`.
// The callable's result is returned from the enclosing immediately-invoked
// lambda. Any other scalar type raises via AT_ERROR.
//
// Every line of the macro body must end in a continuation backslash; do not
// place `//` comments or unrelated lines inside it.
#define AT_DISPATCH_CUSPARSE_CSR_GEMM2_TYPES(TYPE, ...)                        \
  [&] {                                                                        \
    switch (TYPE) {                                                            \
    case torch::ScalarType::Float: {                                           \
      using scalar_t = float;                                                  \
      const auto &cusparsecsrgemm2 = cusparseScsrgemm2;                        \
      return __VA_ARGS__();                                                    \
    }                                                                          \
    case torch::ScalarType::Double: {                                          \
      using scalar_t = double;                                                 \
      const auto &cusparsecsrgemm2 = cusparseDcsrgemm2;                        \
      return __VA_ARGS__();                                                    \
    }                                                                          \
    default:                                                                   \
      AT_ERROR("Not implemented for '", toString(TYPE), "'");                  \
    }                                                                          \
  }()

rusty1s's avatar
rusty1s committed
46
47
48
49
50
std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
spspmm_cuda(torch::Tensor rowptrA, torch::Tensor colA,
            torch::optional<torch::Tensor> valueA, torch::Tensor rowptrB,
            torch::Tensor colB, torch::optional<torch::Tensor> valueB,
            int64_t M, int64_t N, int64_t K) {
rusty1s's avatar
rusty1s committed
51
52
53

  cudaSetDevice(rowptrA.get_device());

rusty1s's avatar
rusty1s committed
54
55
56
  cusparseMatDescr_t descr = 0;
  cusparseCreateMatDescr(&descr);
  auto handle = at::cuda::getCurrentCUDASparseHandle();
rusty1s's avatar
to csr  
rusty1s committed
57

rusty1s's avatar
rusty1s committed
58
59
  rowptrA = rowptrA.toType(torch::kInt), colA = colA.toType(torch::kInt);
  rowptrB = rowptrB.toType(torch::kInt), colB = colB.toType(torch::kInt);
rusty1s's avatar
to csr  
rusty1s committed
60

rusty1s's avatar
rusty1s committed
61
62
  auto rowptrA_data = rowptrA.DATA_PTR<int>(), colA_data = colA.DATA_PTR<int>();
  auto rowptrB_data = rowptrB.DATA_PTR<int>(), colB_data = colB.DATA_PTR<int>();
rusty1s's avatar
rusty1s committed
63

rusty1s's avatar
rusty1s committed
64
65
66
  csrgemm2Info_t info = NULL;
  cusparseCreateCsrgemm2Info(&info);

rusty1s's avatar
rusty1s committed
67
  auto scalar_type = torch::ScalarType::Float;
rusty1s's avatar
rusty1s committed
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
  if (valueA.has_value())
    scalar_type = valueA.value().scalar_type();
  if (valueB.has_value())
    scalar_type = valueB.value().scalar_type();

  size_t bufferSize;
  AT_DISPATCH_CUSPARSE_CSR_GEMM2_BUFFER_SIZE_EXT_TYPES(scalar_type, [&] {
    scalar_t alpha = (scalar_t)1;
    cusparsecsrgemm2_bufferSizeExt(handle, M, N, K, &alpha, descr, colA.numel(),
                                   rowptrA_data, colA_data, descr, colB.numel(),
                                   rowptrB_data, colB_data, NULL, descr, 0,
                                   NULL, NULL, info, &bufferSize);
  });

  void *buffer = NULL;
  cudaMalloc(&buffer, bufferSize);
rusty1s's avatar
rusty1s committed
84
85

  int nnzC;
rusty1s's avatar
rusty1s committed
86
  auto rowptrC = torch::empty(M + 1, rowptrA.options());
rusty1s's avatar
rusty1s committed
87
88
89
90
91
  auto rowptrC_data = rowptrC.DATA_PTR<int>();
  cusparseXcsrgemm2Nnz(handle, M, N, K, descr, colA.numel(), rowptrA_data,
                       colA_data, descr, colB.numel(), rowptrB_data, colB_data,
                       descr, 0, NULL, NULL, descr, rowptrC_data, &nnzC, info,
                       buffer);
rusty1s's avatar
rusty1s committed
92

rusty1s's avatar
rusty1s committed
93
  auto colC = torch::empty(nnzC, colA.options());
rusty1s's avatar
rusty1s committed
94
  auto colC_data = colC.DATA_PTR<int>();
rusty1s's avatar
rusty1s committed
95

rusty1s's avatar
rusty1s committed
96
  if (!valueA.has_value() && valueB.has_value())
rusty1s's avatar
rusty1s committed
97
    valueA = torch::ones_like(valueB.value());
rusty1s's avatar
rusty1s committed
98

rusty1s's avatar
rusty1s committed
99
  if (!valueB.has_value() && valueA.has_value())
rusty1s's avatar
rusty1s committed
100
    valueB = torch::ones_like(valueA.value());
rusty1s's avatar
rusty1s committed
101

rusty1s's avatar
rusty1s committed
102
  torch::optional<torch::Tensor> valueC = torch::nullopt;
rusty1s's avatar
rusty1s committed
103
  if (valueA.has_value())
rusty1s's avatar
rusty1s committed
104
    valueC = torch::empty(nnzC, valueA.value().options());
rusty1s's avatar
rusty1s committed
105

rusty1s's avatar
rusty1s committed
106
107
  AT_DISPATCH_CUSPARSE_CSR_GEMM2_TYPES(scalar_type, [&] {
    scalar_t alpha = (scalar_t)1;
rusty1s's avatar
rusty1s committed
108

rusty1s's avatar
rusty1s committed
109
110
111
    scalar_t *valueA_data = NULL;
    if (valueA.has_value())
      valueA_data = valueA.value().DATA_PTR<scalar_t>();
rusty1s's avatar
rusty1s committed
112

rusty1s's avatar
rusty1s committed
113
114
115
    scalar_t *valueB_data = NULL;
    if (valueB.has_value())
      valueB_data = valueB.value().DATA_PTR<scalar_t>();
rusty1s's avatar
rusty1s committed
116

rusty1s's avatar
rusty1s committed
117
118
119
120
121
122
123
124
    scalar_t *valueC_data = NULL;
    if (valueC.has_value())
      valueC_data = valueC.value().DATA_PTR<scalar_t>();

    cusparsecsrgemm2(handle, M, N, K, &alpha, descr, colA.numel(), valueA_data,
                     rowptrA_data, colA_data, descr, colB.numel(), valueB_data,
                     rowptrB_data, colB_data, NULL, descr, 0, NULL, NULL, NULL,
                     descr, valueC_data, rowptrC_data, colC_data, info, buffer);
rusty1s's avatar
rusty1s committed
125
126
  });

rusty1s's avatar
rusty1s committed
127
128
  rowptrC = rowptrC.toType(torch::kLong);
  colC = colC.toType(torch::kLong);
rusty1s's avatar
rusty1s committed
129
130

  return std::make_tuple(rowptrC, colC, valueC);
rusty1s's avatar
rusty1s committed
131
}