// torch_bindings.cpp
#include <torch/extension.h>

// Forward declaration of the GPTQ GEMM entry point. It is not defined in this
// file as shown; it is registered with the Python module further down.
// Takes five tensors, a boolean flag and a 64-bit integer, and returns a tensor.
// NOTE(review): parameter meanings below are inferred from names only —
// confirm against the implementation:
//   a              presumably the activation (left-hand) matrix
//   b_q_weight     presumably the packed quantized weights
//   b_gptq_qzeros  presumably the quantization zero points
//   b_gptq_scales  presumably the quantization scales
//   b_g_idx        presumably the group-index tensor
//   use_exllama    presumably selects the exllama kernel path
//   bit            presumably the quantization bit width
torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
                        torch::Tensor b_gptq_qzeros,
                        torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
                        bool use_exllama, int64_t bit);

// Forward declaration of the GPTQ shuffle entry point. It is not defined in
// this file as shown; it is registered with the Python module further down.
// Returns nothing.
// NOTE(review): presumably reorders `q_weight` in place using `q_perm`, with
// `bit` as the quantization bit width — inferred from names only; confirm
// against the implementation.
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);

// Python bindings: registers the two GPTQ entry points declared above as
// functions of the extension module named by TORCH_EXTENSION_NAME.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    // The third argument to def() is the docstring shown by Python's help();
    // it must describe the function actually being bound, not another kernel.
    m.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
    m.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
}