#include #include "cpu/scatter_cpu.h" #include "utils.h" #ifdef WITH_CUDA #include "cuda/scatter_cuda.h" #endif std::tuple> scatter_fw(torch::Tensor src, torch::Tensor index, int64_t dim, torch::optional optional_out, torch::optional dim_size, std::string reduce) { if (src.device().is_cuda()) { #ifdef WITH_CUDA return scatter_cuda(src, index, dim, optional_out, dim_size, reduce); #else AT_ERROR("Not compiled with CUDA support"); #endif } else { return scatter_cpu(src, index, dim, optional_out, dim_size, reduce); } } static auto registry = torch::RegisterOperators().op("torch_scatter::scatter_fw", &scatter_fw);