#include #include #include "test_ops.h" #define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \ TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE) #define TORCH_HAS_CUDA namespace test_ops_impl { // 操作符的具体实现 at::Tensor add_one(at::Tensor input) { return input + 1; } at::Tensor multiply_by_two(at::Tensor input) { return input * 2; } } // 在TORCH_LIBRARY_IMPL中注册CPU实现 TORCH_LIBRARY_IMPL_EXPAND(test_ops, CPU, cpu_ops) { cpu_ops.impl("add_one", &test_ops_impl::add_one); cpu_ops.impl("multiply_by_two", &test_ops_impl::multiply_by_two); } // 在TORCH_LIBRARY_IMPL中注册CUDA实现(如果有CUDA) #ifdef TORCH_HAS_CUDA TORCH_LIBRARY_IMPL_EXPAND(test_ops, CUDA, cuda_ops) { // 注意:这里假设CPU和CUDA使用相同的实现函数 // 如果CUDA需要不同的实现,可以定义专门的CUDA版本函数 cuda_ops.impl("add_one", &test_ops_impl::add_one); cuda_ops.impl("multiply_by_two", &test_ops_impl::multiply_by_two); } #endif