// spmm_cpu.cpp
#include "spmm_cpu.h"

#include <ATen/Parallel.h>

#include "reducer.h"
#include "utils.h"

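// Sparse-dense matrix multiplication on CPU. The sparse matrix is given in CSR
// form (`rowptr`, `col`, optional per-edge `optional_value`); each output row m
// combines `value[e] * mat[col[e], :]` over the edges e of row m using the
// reduction named by `reduce` (see "reducer.h"). For min/max reductions, the
// index of the winning edge is returned alongside the result.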
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
spmm_cpu(torch::Tensor rowptr, torch::Tensor col,
         torch::optional<torch::Tensor> optional_value, torch::Tensor mat,
         std::string reduce) {
  CHECK_CPU(rowptr);
  CHECK_CPU(col);
  if (optional_value.has_value())
    CHECK_CPU(optional_value.value());
  CHECK_CPU(mat);

  CHECK_INPUT(rowptr.dim() == 1);
  CHECK_INPUT(col.dim() == 1);
  if (optional_value.has_value()) {
    CHECK_INPUT(optional_value.value().dim() == 1);
    CHECK_INPUT(optional_value.value().size(0) == col.size(0));
  }
  CHECK_INPUT(mat.dim() >= 2);

  mat = mat.contiguous();

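  // The output keeps the shape of `mat`, except that the second-to-last
  // (node) dimension is replaced by the number of sparse rows.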
  auto sizes = mat.sizes().vec();
  sizes[mat.dim() - 2] = rowptr.numel() - 1;
  auto out = torch::empty(sizes, mat.options());

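  // For min/max reductions, also track which edge produced each output entry.
  // Entries that receive no contribution keep the sentinel value col.numel().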
  torch::optional<torch::Tensor> arg_out = torch::nullopt;
  int64_t *arg_out_data = nullptr;
  if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
    arg_out = torch::full_like(out, col.numel(), rowptr.options());
    arg_out_data = arg_out.value().data_ptr<int64_t>();
  }

  auto rowptr_data = rowptr.data_ptr<int64_t>();
  auto col_data = col.data_ptr<int64_t>();

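  // M: number of sparse rows, N: number of sparse columns (= rows of `mat`),
  // K: feature dimension, B: batch size (product of the leading dimensions of `mat`).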
  auto M = rowptr.numel() - 1;
  auto N = mat.size(-2);
  auto K = mat.size(-1);
  auto B = mat.numel() / (N * K);

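  // Dispatch on scalar type, reduction type and the presence of edge values,
  // then parallelize over all (batch, row) pairs with a grain size that
  // accounts for the average work per row.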
  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, mat.scalar_type(), "_", [&] {
    scalar_t *value_data = nullptr;
    auto mat_data = mat.data_ptr<scalar_t>();
    auto out_data = out.data_ptr<scalar_t>();

    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
      AT_DISPATCH_HAS_VALUE(optional_value, [&] {
        if (HAS_VALUE) {
          value_data = optional_value.value().data_ptr<scalar_t>();
        }

        int64_t grain_size = at::internal::GRAIN_SIZE /
                             (K * std::max(col.numel() / M, (int64_t)1));
        at::parallel_for(0, B * M, grain_size, [&](int64_t begin, int64_t end) {
          scalar_t val;
          std::vector<scalar_t> vals(K);
          int64_t row_start, row_end, b, m, c;
          std::vector<int64_t> args(K);

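          // Each work item i maps to batch b = i / M and sparse row m = i % M.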
          for (auto i = begin; i < end; i++) {
            b = i / M, m = i % M;

            row_start = rowptr_data[m], row_end = rowptr_data[m + 1];

            for (auto k = 0; k < K; k++)
              vals[k] = Reducer<scalar_t, REDUCE>::init();

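            // Reduce over the non-zeros of row m: combine value[e] * mat[b, col[e], k]
            // (or just mat[b, col[e], k] when no edge values are given) for every edge e.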
            auto offset = b * N * K;
            for (auto e = row_start; e < row_end; e++) {
              c = col_data[e];
              if (HAS_VALUE)
                val = value_data[e];
              for (auto k = 0; k < K; k++) {
                if (HAS_VALUE)
                  Reducer<scalar_t, REDUCE>::update(
                      &vals[k], val * mat_data[offset + c * K + k], &args[k],
                      e);
                else
                  Reducer<scalar_t, REDUCE>::update(
                      &vals[k], mat_data[offset + c * K + k], &args[k], e);
              }
            }
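
            // Write the K reduced entries (plus arg indices for min/max) into
            // out[b, m, :]; the row's non-zero count is passed to the reducer.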
            offset = b * M * K + m * K;
            for (auto k = 0; k < K; k++)
              Reducer<scalar_t, REDUCE>::write(out_data + offset + k, vals[k],
                                               arg_out_data + offset + k,
                                               args[k], row_end - row_start);
          }
        });
      });
    });
  });

  return std::make_tuple(out, arg_out);
}

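// Backward pass with respect to the non-zero values: for every edge
// e = (row[e], col[e]), the gradient is the dot product of mat[b, col[e], :]
// and grad[b, row[e], :], summed over all batch entries b. For the mean
// reduction, each contribution is divided by the row's non-zero count.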
torch::Tensor spmm_value_bw_cpu(torch::Tensor row, torch::Tensor rowptr,
                                torch::Tensor col, torch::Tensor mat,
                                torch::Tensor grad, std::string reduce) {
  CHECK_CPU(row);
  CHECK_CPU(rowptr);
  CHECK_CPU(col);
  CHECK_CPU(mat);
  CHECK_CPU(grad);

  mat = mat.contiguous();
  grad = grad.contiguous();

  auto M = grad.size(-2);
  auto N = mat.size(-2);
  auto E = row.numel();
  auto K = mat.size(-1);
  auto B = mat.numel() / (N * K);

  auto out = torch::zeros({row.numel()}, grad.options());

  auto row_data = row.data_ptr<int64_t>();
  auto rowptr_data = rowptr.data_ptr<int64_t>();
  auto col_data = col.data_ptr<int64_t>();
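
  // Accumulate the per-edge gradient over batches and feature channels.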
  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, mat.scalar_type(), "_", [&] {
    auto mat_data = mat.data_ptr<scalar_t>();
    auto grad_data = grad.data_ptr<scalar_t>();
    auto out_data = out.data_ptr<scalar_t>();

    scalar_t val;
    int64_t row, col;
    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
      for (int b = 0; b < B; b++) {
        for (int e = 0; e < E; e++) {
          row = row_data[e], col = col_data[e], val = (scalar_t)0;
          for (int k = 0; k < K; k++) {
            val += mat_data[b * N * K + col * K + k] *
                   grad_data[b * M * K + row * K + k];
          }
          if (REDUCE == MEAN) {
            int row_start = rowptr_data[row], row_end = rowptr_data[row + 1];
            val /= (scalar_t)std::max(row_end - row_start, 1);
          }
          out_data[e] += val;
        }
      }
    });
  });

  return out;
}