spmm_cpu.cpp 4.76 KB
Newer Older
rusty1s's avatar
matmul  
rusty1s committed
1
#include "spmm_cpu.h"
2

rusty1s's avatar
rusty1s committed
3
4
#include <ATen/Parallel.h>

rusty1s's avatar
matmul  
rusty1s committed
5
6
#include "reducer.h"
#include "utils.h"
7

rusty1s's avatar
rusty1s committed
8
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
rusty1s's avatar
matmul  
rusty1s committed
9
10
11
spmm_cpu(torch::Tensor rowptr, torch::Tensor col,
         torch::optional<torch::Tensor> optional_value, torch::Tensor mat,
         std::string reduce) {
12
13
  CHECK_CPU(rowptr);
  CHECK_CPU(col);
rusty1s's avatar
matmul  
rusty1s committed
14
15
  if (optional_value.has_value())
    CHECK_CPU(optional_value.value());
16
17
  CHECK_CPU(mat);

rusty1s's avatar
matmul  
rusty1s committed
18
19
20
21
22
23
24
  CHECK_INPUT(rowptr.dim() == 1);
  CHECK_INPUT(col.dim() == 1);
  if (optional_value.has_value()) {
    CHECK_INPUT(optional_value.value().dim() == 1);
    CHECK_INPUT(optional_value.value().size(0) == col.size(0));
  }
  CHECK_INPUT(mat.dim() >= 2);
25

rusty1s's avatar
rusty1s committed
26
27
  mat = mat.contiguous();

28
29
  auto sizes = mat.sizes().vec();
  sizes[mat.dim() - 2] = rowptr.numel() - 1;
rusty1s's avatar
rusty1s committed
30
  auto out = torch::empty(sizes, mat.options());
31

rusty1s's avatar
rusty1s committed
32
  torch::optional<torch::Tensor> arg_out = torch::nullopt;
33
34
  int64_t *arg_out_data = nullptr;
  if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
rusty1s's avatar
rusty1s committed
35
    arg_out = torch::full_like(out, col.numel(), rowptr.options());
rusty1s's avatar
matmul  
rusty1s committed
36
    arg_out_data = arg_out.value().data_ptr<int64_t>();
37
38
  }

rusty1s's avatar
matmul  
rusty1s committed
39
40
  auto rowptr_data = rowptr.data_ptr<int64_t>();
  auto col_data = col.data_ptr<int64_t>();
41

rusty1s's avatar
rusty1s committed
42
43
  auto M = rowptr.numel() - 1;
  auto N = mat.size(-2);
rusty1s's avatar
rusty1s committed
44
  auto K = mat.size(-1);
rusty1s's avatar
rusty1s committed
45
  auto B = mat.numel() / (N * K);
46
47
48

  AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm", [&] {
    scalar_t *value_data = nullptr;
rusty1s's avatar
matmul  
rusty1s committed
49
50
    auto mat_data = mat.data_ptr<scalar_t>();
    auto out_data = out.data_ptr<scalar_t>();
51
52

    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
rusty1s's avatar
matmul  
rusty1s committed
53
54
55
      AT_DISPATCH_HAS_VALUE(optional_value, [&] {
        if (HAS_VALUE) {
          value_data = optional_value.value().data_ptr<scalar_t>();
56
57
        }

rusty1s's avatar
rusty1s committed
58
59
60
61
62
63
64
65
66
67
        int64_t grain_size = at::internal::GRAIN_SIZE / (K * (col.numel() / M));
        at::parallel_for(0, B * M, grain_size, [&](int64_t begin, int64_t end) {
          scalar_t val;
          std::vector<scalar_t> vals(K);
          int64_t row_start, row_end, b, m, c;
          std::vector<int64_t> args(K);

          for (auto i = begin; i < end; i++) {
            b = i / M, m = i % M;

rusty1s's avatar
rusty1s committed
68
            row_start = rowptr_data[m], row_end = rowptr_data[m + 1];
69

rusty1s's avatar
rusty1s committed
70
            for (auto k = 0; k < K; k++)
rusty1s's avatar
rusty1s committed
71
              vals[k] = Reducer<scalar_t, REDUCE>::init();
72

rusty1s's avatar
rusty1s committed
73
74
            auto offset = b * N * K;
            for (auto e = row_start; e < row_end; e++) {
rusty1s's avatar
rusty1s committed
75
              c = col_data[e];
rusty1s's avatar
matmul  
rusty1s committed
76
              if (HAS_VALUE)
77
                val = value_data[e];
rusty1s's avatar
rusty1s committed
78
              for (auto k = 0; k < K; k++) {
rusty1s's avatar
matmul  
rusty1s committed
79
                if (HAS_VALUE)
rusty1s's avatar
rusty1s committed
80
81
82
                  Reducer<scalar_t, REDUCE>::update(
                      &vals[k], val * mat_data[offset + c * K + k], &args[k],
                      e);
83
                else
rusty1s's avatar
rusty1s committed
84
85
                  Reducer<scalar_t, REDUCE>::update(
                      &vals[k], mat_data[offset + c * K + k], &args[k], e);
86
87
              }
            }
rusty1s's avatar
rusty1s committed
88
            offset = b * M * K + m * K;
rusty1s's avatar
rusty1s committed
89
            for (auto k = 0; k < K; k++)
rusty1s's avatar
rusty1s committed
90
91
92
              Reducer<scalar_t, REDUCE>::write(out_data + offset + k, vals[k],
                                               arg_out_data + offset + k,
                                               args[k], row_end - row_start);
93
          }
rusty1s's avatar
rusty1s committed
94
        });
95
96
97
98
99
100
101
      });
    });
  });

  return std::make_tuple(out, arg_out);
}

rusty1s's avatar
matmul  
rusty1s committed
102
103
104
// Gradient of spmm with respect to the per-edge values:
//   grad_value[e] = <mat[:, col[e], :], grad[:, row[e], :]>  (summed over
//   batch and feature dims), divided by the row degree for "mean".
//
// Args:
//   row:    COO row index per edge, shape [E], int64.
//   rowptr: CSR row pointers (used only to compute degrees for "mean").
//   col:    column index per edge, shape [E], int64.
//   mat:    dense forward input, shape [..., N, K].
//   grad:   upstream gradient, shape [..., M, K].
//   reduce: reduction used in the forward pass.
//
// Returns: gradient tensor of shape [E].
torch::Tensor spmm_value_bw_cpu(torch::Tensor row, torch::Tensor rowptr,
                                torch::Tensor col, torch::Tensor mat,
                                torch::Tensor grad, std::string reduce) {
  CHECK_CPU(row);
  CHECK_CPU(rowptr);
  CHECK_CPU(col);
  CHECK_CPU(mat);
  CHECK_CPU(grad);

  // Raw-pointer indexing below assumes contiguous layouts.
  mat = mat.contiguous();
  grad = grad.contiguous();

  auto M = grad.size(-2);
  auto N = mat.size(-2);
  auto E = row.numel();
  auto K = mat.size(-1);
  auto B = mat.numel() / (N * K); // flattened batch size

  auto out = torch::zeros(row.numel(), grad.options());

  auto row_data = row.data_ptr<int64_t>();
  auto rowptr_data = rowptr.data_ptr<int64_t>();
  auto col_data = col.data_ptr<int64_t>();
  AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm_value_bw", [&] {
    auto mat_data = mat.data_ptr<scalar_t>();
    auto grad_data = grad.data_ptr<scalar_t>();
    auto out_data = out.data_ptr<scalar_t>();

    scalar_t val;
    // Renamed from `row`/`col` to avoid shadowing the tensor parameters.
    int64_t row_idx, col_idx;
    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
      for (int64_t b = 0; b < B; b++) {
        for (int64_t e = 0; e < E; e++) {
          row_idx = row_data[e], col_idx = col_data[e], val = (scalar_t)0;
          for (int64_t k = 0; k < K; k++) {
            val += mat_data[b * N * K + col_idx * K + k] *
                   grad_data[b * M * K + row_idx * K + k];
          }
          if (REDUCE == MEAN) {
            // Forward "mean" divided by the degree, so the gradient must
            // be scaled the same way; int64_t matches rowptr's dtype.
            int64_t row_start = rowptr_data[row_idx];
            int64_t row_end = rowptr_data[row_idx + 1];
            val /= (scalar_t)std::max<int64_t>(row_end - row_start, 1);
          }
          // Accumulate across the batch dimension.
          out_data[e] += val;
        }
      }
    });
  });

  return out;
}