#include "spmm_cpu.h"

#include <algorithm>

#include <ATen/Parallel.h>

#include "reducer.h"
#include "utils.h"

std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
rusty1s's avatar
matmul  
rusty1s committed
9
10
11
spmm_cpu(torch::Tensor rowptr, torch::Tensor col,
         torch::optional<torch::Tensor> optional_value, torch::Tensor mat,
         std::string reduce) {
12
13
  CHECK_CPU(rowptr);
  CHECK_CPU(col);
rusty1s's avatar
matmul  
rusty1s committed
14
15
  if (optional_value.has_value())
    CHECK_CPU(optional_value.value());
16
17
  CHECK_CPU(mat);

rusty1s's avatar
matmul  
rusty1s committed
18
19
20
21
22
23
24
  CHECK_INPUT(rowptr.dim() == 1);
  CHECK_INPUT(col.dim() == 1);
  if (optional_value.has_value()) {
    CHECK_INPUT(optional_value.value().dim() == 1);
    CHECK_INPUT(optional_value.value().size(0) == col.size(0));
  }
  CHECK_INPUT(mat.dim() >= 2);
25

rusty1s's avatar
rusty1s committed
26
27
  mat = mat.contiguous();

28
29
  auto sizes = mat.sizes().vec();
  sizes[mat.dim() - 2] = rowptr.numel() - 1;
rusty1s's avatar
rusty1s committed
30
  auto out = torch::empty(sizes, mat.options());
31

rusty1s's avatar
rusty1s committed
32
  torch::optional<torch::Tensor> arg_out = torch::nullopt;
33
34
  int64_t *arg_out_data = nullptr;
  if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
rusty1s's avatar
rusty1s committed
35
    arg_out = torch::full_like(out, col.numel(), rowptr.options());
rusty1s's avatar
matmul  
rusty1s committed
36
    arg_out_data = arg_out.value().data_ptr<int64_t>();
37
38
  }

rusty1s's avatar
matmul  
rusty1s committed
39
40
  auto rowptr_data = rowptr.data_ptr<int64_t>();
  auto col_data = col.data_ptr<int64_t>();
41

rusty1s's avatar
rusty1s committed
42
43
  auto M = rowptr.numel() - 1;
  auto N = mat.size(-2);
rusty1s's avatar
rusty1s committed
44
  auto K = mat.size(-1);
rusty1s's avatar
rusty1s committed
45
  auto B = mat.numel() / (N * K);
46

rusty1s's avatar
rusty1s committed
47
48
49
50
51
52
53
54
55
56
  AT_DISPATCH_ALL_TYPES_AND(
      at::ScalarType::Half, mat.scalar_type(), "spmm", [&] {
        scalar_t *value_data = nullptr;
        auto mat_data = mat.data_ptr<scalar_t>();
        auto out_data = out.data_ptr<scalar_t>();

        AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
          AT_DISPATCH_HAS_VALUE(optional_value, [&] {
            if (HAS_VALUE) {
              value_data = optional_value.value().data_ptr<scalar_t>();
57
            }
rusty1s's avatar
rusty1s committed
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100

            int64_t grain_size =
                at::internal::GRAIN_SIZE / (K * (col.numel() / M));
            at::parallel_for(
                0, B * M, grain_size, [&](int64_t begin, int64_t end) {
                  scalar_t val;
                  std::vector<scalar_t> vals(K);
                  int64_t row_start, row_end, b, m, c;
                  std::vector<int64_t> args(K);

                  for (auto i = begin; i < end; i++) {
                    b = i / M, m = i % M;

                    row_start = rowptr_data[m], row_end = rowptr_data[m + 1];

                    for (auto k = 0; k < K; k++)
                      vals[k] = Reducer<scalar_t, REDUCE>::init();

                    auto offset = b * N * K;
                    for (auto e = row_start; e < row_end; e++) {
                      c = col_data[e];
                      if (HAS_VALUE)
                        val = value_data[e];
                      for (auto k = 0; k < K; k++) {
                        if (HAS_VALUE)
                          Reducer<scalar_t, REDUCE>::update(
                              &vals[k], val * mat_data[offset + c * K + k],
                              &args[k], e);
                        else
                          Reducer<scalar_t, REDUCE>::update(
                              &vals[k], mat_data[offset + c * K + k], &args[k],
                              e);
                      }
                    }
                    offset = b * M * K + m * K;
                    for (auto k = 0; k < K; k++)
                      Reducer<scalar_t, REDUCE>::write(
                          out_data + offset + k, vals[k],
                          arg_out_data + offset + k, args[k],
                          row_end - row_start);
                  }
                });
          });
rusty1s's avatar
rusty1s committed
101
        });
102
103
104
105
106
      });

  return std::make_tuple(out, arg_out);
}
// Backward pass of spmm w.r.t. the sparse (edge) values on CPU.
//
// For each edge e = (row[e], col[e]), its value-gradient is the inner product
// of mat[.., col[e], :] and grad[.., row[e], :], summed over all batch
// dimensions. For the "mean" reduction, the gradient is additionally divided
// by the degree of the edge's row.
//
// Arguments:
//   row    - COO row indices, int64 tensor of length E.
//   rowptr - CSR row pointers (used only to compute degrees for "mean").
//   col    - COO column indices, int64 tensor of length E.
//   mat    - dense forward input of shape [..., N, K] (contiguous copy made).
//   grad   - upstream gradient of shape [..., M, K] (contiguous copy made).
//   reduce - reduction name used in the forward pass.
//
// Returns a 1-dim tensor of length E holding the gradient per edge value.
torch::Tensor spmm_value_bw_cpu(torch::Tensor row, torch::Tensor rowptr,
                                torch::Tensor col, torch::Tensor mat,
                                torch::Tensor grad, std::string reduce) {
  CHECK_CPU(row);
  CHECK_CPU(rowptr);
  CHECK_CPU(col);
  CHECK_CPU(mat);
  CHECK_CPU(grad);

  mat = mat.contiguous();
  grad = grad.contiguous();

  auto M = grad.size(-2);
  auto N = mat.size(-2);
  auto E = row.numel();
  auto K = mat.size(-1);
  auto B = mat.numel() / (N * K);

  auto out = torch::zeros(row.numel(), grad.options());

  auto row_data = row.data_ptr<int64_t>();
  auto rowptr_data = rowptr.data_ptr<int64_t>();
  auto col_data = col.data_ptr<int64_t>();

  AT_DISPATCH_ALL_TYPES_AND(
      at::ScalarType::Half, mat.scalar_type(), "spmm_value_bw", [&] {
        auto mat_data = mat.data_ptr<scalar_t>();
        auto grad_data = grad.data_ptr<scalar_t>();
        auto out_data = out.data_ptr<scalar_t>();

        scalar_t val;
        // Named `r`/`c` to avoid shadowing the `row`/`col` tensor arguments.
        int64_t r, c;
        AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
          // int64_t loop indices: `B`, `E`, `K` are int64_t, and E can exceed
          // the range of `int` for large graphs.
          for (int64_t b = 0; b < B; b++) {
            for (int64_t e = 0; e < E; e++) {
              r = row_data[e], c = col_data[e], val = (scalar_t)0;
              for (int64_t k = 0; k < K; k++) {
                val += mat_data[b * N * K + c * K + k] *
                       grad_data[b * M * K + r * K + k];
              }
              if (REDUCE == MEAN) {
                // Degree of row r; clamp to 1 so empty rows do not divide
                // by zero.
                int64_t row_start = rowptr_data[r],
                        row_end = rowptr_data[r + 1];
                val /= (scalar_t)std::max<int64_t>(row_end - row_start, 1);
              }
              // Accumulate across batches (+=, since out starts at zero).
              out_data[e] += val;
            }
          }
        });
      });

  return out;
}