// grouped_gemm_ck_pybind.cu
// SPDX-License-Identifier: MIT

#include "grouped_gemm_ck.h"
#include "rocm_ops.hpp"
#include <pybind11/stl.h>

// Python extension entry point: exposes ck_grouped_gemm and ck_grouped_gemm_out
// to Python under the same names as the C++ functions they are bound to.
// NOTE(review): the `py` namespace alias and both bound functions are not
// declared in this file; presumably they come from the included headers — confirm.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    // Keyword-argument descriptors, named once and reused by both bindings.
    const py::arg a_tensors_kw("a_tensors");
    const py::arg b_tensors_kw("b_tensors");
    const py::arg c_tensors_kw("c_tensors");

    // Python signature: ck_grouped_gemm(a_tensors, b_tensors)
    m.def("ck_grouped_gemm", &ck_grouped_gemm, a_tensors_kw, b_tensors_kw);

    // Python signature: ck_grouped_gemm_out(a_tensors, b_tensors, c_tensors)
    m.def("ck_grouped_gemm_out",
          &ck_grouped_gemm_out,
          a_tensors_kw,
          b_tensors_kw,
          c_tensors_kw);
}