#include #include using torch::Tensor; namespace at { namespace native { void rms_rotary_embedding_fuse( Tensor& positions, Tensor& query, Tensor& key, int64_t head_size, Tensor& cos_sin_cache, bool is_neox, Tensor weight_q, Tensor weight_k, std::optional residual_q, std::optional residual_k, double epsilon); } // namespace native } // namespace at PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("rms_rotary_embedding_fuse", &at::native::rms_rotary_embedding_fuse, "rms_rotary_embedding_fuse"); }