spmm_cpu.cpp 4.53 KB
Newer Older
rusty1s's avatar
matmul  
rusty1s committed
1
#include "spmm_cpu.h"
2

rusty1s's avatar
matmul  
rusty1s committed
3
4
#include "reducer.h"
#include "utils.h"
5

rusty1s's avatar
rusty1s committed
6
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
rusty1s's avatar
matmul  
rusty1s committed
7
8
9
spmm_cpu(torch::Tensor rowptr, torch::Tensor col,
         torch::optional<torch::Tensor> optional_value, torch::Tensor mat,
         std::string reduce) {
10
11
  CHECK_CPU(rowptr);
  CHECK_CPU(col);
rusty1s's avatar
matmul  
rusty1s committed
12
13
  if (optional_value.has_value())
    CHECK_CPU(optional_value.value());
14
15
  CHECK_CPU(mat);

rusty1s's avatar
matmul  
rusty1s committed
16
17
18
19
20
21
22
  CHECK_INPUT(rowptr.dim() == 1);
  CHECK_INPUT(col.dim() == 1);
  if (optional_value.has_value()) {
    CHECK_INPUT(optional_value.value().dim() == 1);
    CHECK_INPUT(optional_value.value().size(0) == col.size(0));
  }
  CHECK_INPUT(mat.dim() >= 2);
23

rusty1s's avatar
rusty1s committed
24
25
  mat = mat.contiguous();

26
27
  auto sizes = mat.sizes().vec();
  sizes[mat.dim() - 2] = rowptr.numel() - 1;
rusty1s's avatar
rusty1s committed
28
  auto out = torch::empty(sizes, mat.options());
29

rusty1s's avatar
rusty1s committed
30
  torch::optional<torch::Tensor> arg_out = torch::nullopt;
31
32
  int64_t *arg_out_data = nullptr;
  if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
rusty1s's avatar
rusty1s committed
33
    arg_out = torch::full_like(out, col.numel(), rowptr.options());
rusty1s's avatar
matmul  
rusty1s committed
34
    arg_out_data = arg_out.value().data_ptr<int64_t>();
35
36
  }

rusty1s's avatar
matmul  
rusty1s committed
37
38
  auto rowptr_data = rowptr.data_ptr<int64_t>();
  auto col_data = col.data_ptr<int64_t>();
39

rusty1s's avatar
rusty1s committed
40
41
  auto M = rowptr.numel() - 1;
  auto N = mat.size(-2);
rusty1s's avatar
rusty1s committed
42
  auto K = mat.size(-1);
rusty1s's avatar
rusty1s committed
43
  auto B = mat.numel() / (N * K);
44
45
46

  AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm", [&] {
    scalar_t *value_data = nullptr;
rusty1s's avatar
matmul  
rusty1s committed
47
48
    auto mat_data = mat.data_ptr<scalar_t>();
    auto out_data = out.data_ptr<scalar_t>();
49
50
51

    scalar_t val;
    std::vector<scalar_t> vals(K);
rusty1s's avatar
rusty1s committed
52
    int64_t row_start, row_end, c;
53
54
55
    std::vector<int64_t> args(K);

    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
rusty1s's avatar
matmul  
rusty1s committed
56
57
58
      AT_DISPATCH_HAS_VALUE(optional_value, [&] {
        if (HAS_VALUE) {
          value_data = optional_value.value().data_ptr<scalar_t>();
59
60
61
        }

        for (int b = 0; b < B; b++) {
rusty1s's avatar
rusty1s committed
62
63
          for (int m = 0; m < M; m++) {
            row_start = rowptr_data[m], row_end = rowptr_data[m + 1];
64
65
66
67

            for (int k = 0; k < K; k++)
              vals[k] = Reducer<scalar_t, REDUCE>::init();

rusty1s's avatar
rusty1s committed
68
            int offset = b * N * K;
69
            for (int e = row_start; e < row_end; e++) {
rusty1s's avatar
rusty1s committed
70
              c = col_data[e];
rusty1s's avatar
matmul  
rusty1s committed
71
              if (HAS_VALUE)
72
73
                val = value_data[e];
              for (int k = 0; k < K; k++) {
rusty1s's avatar
matmul  
rusty1s committed
74
                if (HAS_VALUE)
75
                  Reducer<scalar_t, REDUCE>::update(
rusty1s's avatar
rusty1s committed
76
77
                      &vals[k], val * mat_data[offset + c * K + k], &args[k],
                      e);
78
79
                else
                  Reducer<scalar_t, REDUCE>::update(
rusty1s's avatar
rusty1s committed
80
                      &vals[k], mat_data[offset + c * K + k], &args[k], e);
81
82
              }
            }
rusty1s's avatar
rusty1s committed
83
            offset = b * M * K + m * K;
84
85
86
87
88
89
90
91
92
93
94
95
96
            for (int k = 0; k < K; k++)
              Reducer<scalar_t, REDUCE>::write(out_data + offset + k, vals[k],
                                               arg_out_data + offset + k,
                                               args[k], row_end - row_start);
          }
        }
      });
    });
  });

  return std::make_tuple(out, arg_out);
}

rusty1s's avatar
matmul  
rusty1s committed
97
98
99
// Computes, for every stored sparse entry `e`, the dot product between the
// `col[e]`-th row of `mat` and the `row[e]`-th row of `grad`, accumulated
// over all batches.
//
// `row` and `col` hold the row/column index of each entry and `rowptr` the
// compressed row boundaries; all three are accessed through
// `data_ptr<int64_t>()`. `mat` has trailing dimensions `[N, K]` and `grad`
// is indexed with trailing dimensions `[M, K]`, where `M = grad.size(-2)`.
// When `reduce` dispatches to MEAN, each entry's contribution is divided by
// the number of entries in its row (at least 1).
//
// Returns a 1-D tensor with `row.numel()` elements, created with the
// options of `grad`.
// NOTE(review): the kernel is dispatched on `mat.scalar_type()` while
// `grad` and the output are read/written as the same `scalar_t` -
// presumably callers pass matching dtypes; confirm.
torch::Tensor spmm_value_bw_cpu(torch::Tensor row, torch::Tensor rowptr,
                                torch::Tensor col, torch::Tensor mat,
                                torch::Tensor grad, std::string reduce) {
  CHECK_CPU(row);
  CHECK_CPU(rowptr);
  CHECK_CPU(col);
  CHECK_CPU(mat);
  CHECK_CPU(grad);

  // Raw-pointer indexing below relies on a dense row-major layout.
  mat = mat.contiguous();
  grad = grad.contiguous();

  const int64_t M = grad.size(-2);
  const int64_t N = mat.size(-2);
  const int64_t E = row.numel();
  const int64_t K = mat.size(-1);

  // Batch size is the product of the leading dimensions of `mat`. Computing
  // it this way (instead of `numel() / (N * K)`) avoids an integer division
  // by zero when `mat` has an empty trailing dimension.
  int64_t B = 1;
  for (int64_t d = 0; d < mat.dim() - 2; d++)
    B *= mat.size(d);

  auto out = torch::zeros(E, grad.options());

  const int64_t *row_data = row.data_ptr<int64_t>();
  const int64_t *rowptr_data = rowptr.data_ptr<int64_t>();
  const int64_t *col_data = col.data_ptr<int64_t>();

  AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm_value_bw", [&] {
    auto mat_data = mat.data_ptr<scalar_t>();
    auto grad_data = grad.data_ptr<scalar_t>();
    auto out_data = out.data_ptr<scalar_t>();

    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
      // All offsets are 64-bit: element counts and row pointers can exceed
      // the range of `int` for large inputs.
      for (int64_t b = 0; b < B; b++) {
        const scalar_t *mat_batch = mat_data + b * N * K;
        const scalar_t *grad_batch = grad_data + b * M * K;

        for (int64_t e = 0; e < E; e++) {
          // Local names are distinct from the `row` / `col` tensor
          // parameters so nothing is shadowed.
          const int64_t r = row_data[e];
          const int64_t c = col_data[e];

          scalar_t val = (scalar_t)0;
          for (int64_t k = 0; k < K; k++)
            val += mat_batch[c * K + k] * grad_batch[r * K + k];

          if (REDUCE == MEAN) {
            const int64_t row_start = rowptr_data[r];
            const int64_t row_end = rowptr_data[r + 1];
            // Clamp to 1 so a row without entries never divides by zero.
            val /= (scalar_t)std::max<int64_t>(row_end - row_start, 1);
          }

          out_data[e] += val;
        }
      }
    });
  });

  return out;
}