batched_gemm_bf16_tune_pybind.cu 303 Bytes
Newer Older
Xiaowei.zhang's avatar
Xiaowei.zhang committed
1
2
3
4
5
6
7
8
9
// SPDX-License-Identifier: MIT
 
#include "batched_gemm_bf16.h"

// Python binding for `batched_gemm_bf16_tune`.
//
// Registers a single function on the module named by TORCH_EXTENSION_NAME
// (a macro not defined in this file; presumably injected by the extension
// build - confirm). Keyword arguments, in positional order:
//   XQ, WQ, Out       - required
//   kernelId, splitK  - optional, both default to 0
// NOTE(review): the meaning of each argument is defined by the C++ function
// itself, assumed to be declared in batched_gemm_bf16.h - confirm there.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    // The same text is used as the exported Python name and as the docstring.
    constexpr const char* kFnName = "batched_gemm_bf16_tune";

    m.def(kFnName,
          &batched_gemm_bf16_tune,
          kFnName,
          py::arg("XQ"),
          py::arg("WQ"),
          py::arg("Out"),
          py::arg("kernelId") = 0,
          py::arg("splitK")   = 0);
}