matmul_cuda.cu 1.55 KB
Newer Older
rusty1s's avatar
to csr  
rusty1s committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
#include <limits>
#include <stdexcept>
#include <string>

#include <ATen/ATen.h>

#include <cusparse.h>

static cusparseHandle_t cusparse_handle = 0;

static void init_cusparse() {
  if (cusparse_handle == 0) {
    cusparseStatus_t status = cusparseCreate(&cusparse_handle);
  }
}

// Work-in-progress sparse x sparse matrix product.
//
// So far this only converts matrix1's COO row indices into a CSR row-pointer
// array with cusparseXcoo2csr, prints two debug values, and returns matrix1
// unchanged. matrix2 is not used yet and the computed row_ptrs is discarded.
//
// matrix1: sparse COO tensor whose indices live on the GPU (row.data<int>() is
//          handed to cuSPARSE as a device pointer). cusparseXcoo2csr requires
//          the row indices to be sorted — assumes matrix1 is 2-D and coalesced;
//          TODO confirm callers guarantee this.
// matrix2: currently unused.
// Returns: matrix1 itself.
// Throws std::runtime_error if nnz or the row count does not fit cuSPARSE's
// 32-bit int arguments, or if the COO -> CSR conversion fails.
at::Tensor spspmm_cuda(at::Tensor matrix1, at::Tensor matrix2) {
  init_cusparse();

  auto nnz = matrix1._nnz();
  // coo2csr compresses *row* indices, so its `m` argument and the length of the
  // row-pointer array (m + 1) depend on the number of rows, i.e. size(0).
  // Using size(1) (the column count) breaks non-square matrices.
  auto rows = matrix1.size(0);

  // cuSPARSE takes nnz and m as 32-bit int; reject sizes that would narrow.
  // rows + 1 must also be representable for the row-pointer array.
  if (nnz > std::numeric_limits<int>::max() ||
      rows >= std::numeric_limits<int>::max()) {
    throw std::runtime_error(
        "spspmm_cuda: matrix too large for cuSPARSE 32-bit indices");
  }

  // ATen stores indices as int64; cuSPARSE expects int32.
  auto row = matrix1._indices()[0].toType(at::kInt);
  auto row_ptrs = at::empty(row.type(), {rows + 1});

  cusparseStatus_t status = cusparseXcoo2csr(
      cusparse_handle, row.data<int>(), static_cast<int>(nnz),
      static_cast<int>(rows), row_ptrs.data<int>(), CUSPARSE_INDEX_BASE_ZERO);
  if (status != CUSPARSE_STATUS_SUCCESS) {
    throw std::runtime_error("cusparseXcoo2csr failed with status " +
                             std::to_string(static_cast<int>(status)));
  }

  // Debug output. The casts make %lld match the argument type portably
  // (int64_t is not necessarily long long).
  printf("%lld\n", static_cast<long long>(nnz));
  printf("%lld\n", static_cast<long long>(rows));

  return matrix1;
}
/* #include <ATen/SparseTensorImpl.h> */

/* namespace at { */
/* namespace native { */
/* using SparseTensor = Tensor; */

/* namespace { */
/* at::SparseTensor spspmm_cuda(at::SparseTensor matrix1, */
/*                              at::SparseTensor matrix2) { */

/*   return matrix1; */
/* } */
/* } // namespace */
/* } // namespace native */
/* } // namespace at */

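// For reference, aten/src/THCUNN/SparseLinear.cu performs a similar COO -> CSR
// conversion, using one-based indices (CUSPARSE_INDEX_BASE_ONE) instead of the
// zero-based indices used above:

/* cusparseXcoo2csr(cusparse_handle, THCudaIntTensor_data(state, colbuf), nnz,
 *                  inDim, THCudaIntTensor_data(state, colPtrs),
 *                  CUSPARSE_INDEX_BASE_ONE); */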