#include <torch/extension.h>

#include "compat.h"

#define CHECK_CPU(x) AT_ASSERTM(!x.type().is_cuda(), #x " must be CPU tensor")

enum ReductionType { SUM, MEAN, MIN, MAX };

const std::map<std::string, ReductionType> reduce2REDUCE = {
    {"sum", SUM}, {"add", SUM}, {"mean", MEAN}, {"min", MIN}, {"max", MAX},
};

// Dispatch on the reduction string at runtime while exposing the chosen
// reduction as a compile-time constant (`REDUCE`) inside the lambda body.
#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...)                               \
  [&] {                                                                        \
    switch (reduce2REDUCE.at(reduce)) {                                        \
    case SUM: {                                                                \
      const ReductionType REDUCE = SUM;                                        \
      return __VA_ARGS__();                                                    \
    }                                                                          \
    case MEAN: {                                                               \
      const ReductionType REDUCE = MEAN;                                       \
      return __VA_ARGS__();                                                    \
    }                                                                          \
    case MIN: {                                                                \
      const ReductionType REDUCE = MIN;                                        \
      return __VA_ARGS__();                                                    \
    }                                                                          \
    case MAX: {                                                                \
      const ReductionType REDUCE = MAX;                                        \
      return __VA_ARGS__();                                                    \
    }                                                                          \
    }                                                                          \
  }()

// Dispatch on whether optional edge values are present (`HAS_VAL`).
#define AT_DISPATCH_HAS_VAL(value_opt, ...)                                    \
  [&] {                                                                        \
    switch (value_opt.has_value()) {                                           \
    case true: {                                                               \
      const bool HAS_VAL = true;                                               \
      return __VA_ARGS__();                                                    \
    }                                                                          \
    case false: {                                                              \
      const bool HAS_VAL = false;                                              \
      return __VA_ARGS__();                                                    \
    }                                                                          \
    }                                                                          \
  }()

// Per-reduction helpers: initial accumulator value, element-wise update, and
// final write. MIN/MAX additionally track the argmin/argmax edge index, and
// MEAN normalizes by the number of contributing entries.
template <typename scalar_t, ReductionType REDUCE> struct Reducer {
  static inline scalar_t init() {
    if (REDUCE == MIN) {
      return std::numeric_limits<scalar_t>::max();
    } else if (REDUCE == MAX) {
      return std::numeric_limits<scalar_t>::lowest();
    } else {
      return (scalar_t)0;
    }
  }

  static inline void update(scalar_t *val, scalar_t new_val, int64_t *arg,
                            int64_t new_arg) {
    if (REDUCE == SUM || REDUCE == MEAN) {
      *val = *val + new_val;
    } else if ((REDUCE == MIN && new_val < *val) ||
               (REDUCE == MAX && new_val > *val)) {
      *val = new_val;
      *arg = new_arg;
    }
  }

  static inline void write(scalar_t *address, scalar_t val,
                           int64_t *arg_address, int64_t arg, int count) {
    if (REDUCE == SUM) {
      *address = val;
    } else if (REDUCE == MEAN) {
      *address = val / (count > 0 ? count : (scalar_t)1);
    } else if (REDUCE == MIN || REDUCE == MAX) {
      if (count > 0) {
        *address = val;
        *arg_address = arg;
      } else {
        *address = (scalar_t)0;
      }
    }
  }
};

std::tuple<at::Tensor, at::optional<at::Tensor>>
spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
     at::Tensor mat, std::string reduce) {
  CHECK_CPU(rowptr);
  CHECK_CPU(col);
  if (value_opt.has_value())
    CHECK_CPU(value_opt.value());
  CHECK_CPU(mat);
  mat = mat.contiguous();

  AT_ASSERTM(rowptr.dim() == 1, "Input mismatch");
  AT_ASSERTM(col.dim() == 1, "Input mismatch");
  if (value_opt.has_value())
    AT_ASSERTM(value_opt.value().dim() == 1, "Input mismatch");
  AT_ASSERTM(mat.dim() >= 2, "Input mismatch");

  AT_ASSERTM(rowptr.numel() - 1 == mat.size(-2), "Input mismatch");

  auto sizes = mat.sizes().vec();
  sizes[mat.dim() - 2] = rowptr.numel() - 1;
  auto out = at::empty(sizes, mat.options());

  // For min/max reductions, also return the index of the winning edge.
  at::optional<at::Tensor> arg_out = at::nullopt;
  int64_t *arg_out_data = nullptr;
  if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
    arg_out = at::full_like(out, mat.size(-2), rowptr.options());
    arg_out_data = arg_out.value().DATA_PTR<int64_t>();
  }

  auto rowptr_data = rowptr.DATA_PTR<int64_t>();
  auto col_data = col.DATA_PTR<int64_t>();

  int N = rowptr.numel() - 1; // number of sparse rows (CSR)
  int M = mat.size(-2);       // number of dense rows
  int K = mat.size(-1);       // feature dimension
  int B = mat.numel() / (M * K); // leading batch dimension of `mat`

  AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm", [&] {
    scalar_t *value_data = nullptr;
    auto mat_data = mat.DATA_PTR<scalar_t>();
    auto out_data = out.DATA_PTR<scalar_t>();

    scalar_t val;
    std::vector<scalar_t> vals(K);
    int64_t row_start, row_end, col_idx;
    std::vector<int64_t> args(K);

    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
      AT_DISPATCH_HAS_VAL(value_opt, [&] {
        if (HAS_VAL) {
          value_data = value_opt.value().DATA_PTR<scalar_t>();
        }

        for (int b = 0; b < B; b++) {
          for (int n = 0; n < N; n++) {
            row_start = rowptr_data[n], row_end = rowptr_data[n + 1];

            for (int k = 0; k < K; k++)
              vals[k] = Reducer<scalar_t, REDUCE>::init();

            // Accumulate all nonzeros of row `n` into the K-sized buffer.
            int offset = b * M * K;
            for (int e = row_start; e < row_end; e++) {
              col_idx = col_data[e];
              if (HAS_VAL)
                val = value_data[e];
              for (int k = 0; k < K; k++) {
                if (HAS_VAL)
                  Reducer<scalar_t, REDUCE>::update(
                      &vals[k], val * mat_data[offset + col_idx * K + k],
                      &args[k], e);
                else
                  Reducer<scalar_t, REDUCE>::update(
                      &vals[k], mat_data[offset + col_idx * K + k], &args[k],
                      e);
              }
            }

            offset = b * N * K + n * K;
            for (int k = 0; k < K; k++)
              Reducer<scalar_t, REDUCE>::write(out_data + offset + k, vals[k],
                                               arg_out_data + offset + k,
                                               args[k], row_end - row_start);
          }
        }
      });
    });
  });

  return std::make_tuple(out, arg_out);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("spmm", &spmm, "Sparse-Dense Matrix Multiplication (CPU)");
}
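// Usage sketch (illustrative only, not part of the extension itself): the
// tensor contents below are hypothetical and the function is called directly
// from C++; when built as a Torch extension, the same call is exposed to
// Python via the `spmm` binding above. Note that the asserts above require
// rowptr.numel() - 1 == mat.size(-2).
//
//   auto rowptr = torch::tensor({0, 2, 3}, torch::kLong);  // 2 sparse rows
//   auto col    = torch::tensor({0, 1, 1}, torch::kLong);  // column indices
//   auto value  = at::optional<at::Tensor>(torch::tensor({1.0f, 2.0f, 3.0f}));
//   auto mat    = torch::rand({2, 4});                     // dense 2 x 4 input
//   at::Tensor out;
//   at::optional<at::Tensor> arg_out;
//   std::tie(out, arg_out) = spmm(rowptr, col, value, mat, "max");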