#pragma once #include "../extensions.h" std::tuple> spmm_cuda(torch::Tensor rowptr, torch::Tensor col, torch::optional optional_value, torch::Tensor mat, std::string reduce); torch::Tensor spmm_value_bw_cuda(torch::Tensor row, torch::Tensor rowptr, torch::Tensor col, torch::Tensor mat, torch::Tensor grad, std::string reduce); template __device__ T __ldg(const T* ptr) { return *ptr; }