// SPDX-License-Identifier: MIT #include "grouped_gemm_ck.h" #include "rocm_ops.hpp" #include PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("ck_grouped_gemm", &ck_grouped_gemm, py::arg("a_tensors"), py::arg("b_tensors")); m.def("ck_grouped_gemm_out", &ck_grouped_gemm_out, py::arg("a_tensors"), py::arg("b_tensors"), py::arg("c_tensors")); }