pybind.cpp 3.31 KB
Newer Older
Zhekai Zhang's avatar
Zhekai Zhang committed
1
#include "gemm.h"
muyangli's avatar
muyangli committed
2
#include "gemm88.h"
Zhekai Zhang's avatar
Zhekai Zhang committed
3
#include "flux.h"
muyangli's avatar
muyangli committed
4
5
6
#include "sana.h"
#include "ops.h"
#include "utils.h"
Samuel Tesfai's avatar
Samuel Tesfai committed
7
#include <torch/extension.h>
Zhekai Zhang's avatar
Zhekai Zhang committed
8
9
10
11
12
13
14

#include <pybind11/pybind11.h>

// Python extension entry point: registers the quantized model / GEMM wrapper
// classes and the `ops` and `utils` submodules on module `m`.
//
// NOTE(review): `py` is presumably the `pybind11` namespace alias made
// available through <torch/extension.h> — confirm, since no alias is declared
// in this file.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    // ---- QuantizedFluxModel -------------------------------------------------
    // Default-constructible from Python; `init` and `load` accept keyword
    // arguments, and `load`'s `partial` argument defaults to false.
    py::class_<QuantizedFluxModel>(m, "QuantizedFluxModel")
        .def(py::init<>())
        .def("init", &QuantizedFluxModel::init,
            py::arg("use_fp4"),
            py::arg("offload"),
            py::arg("bf16"),
            py::arg("deviceId")
        )
        .def("reset", &QuantizedFluxModel::reset)
        .def("load", &QuantizedFluxModel::load,
            py::arg("path"),
            py::arg("partial") = false
        )
        .def("forward", &QuantizedFluxModel::forward)
        .def("forward_layer", &QuantizedFluxModel::forward_layer)
        .def("forward_single_layer", &QuantizedFluxModel::forward_single_layer)
        .def("startDebug", &QuantizedFluxModel::startDebug)
        .def("stopDebug", &QuantizedFluxModel::stopDebug)
        .def("getDebugResults", &QuantizedFluxModel::getDebugResults)
        .def("setLoraScale", &QuantizedFluxModel::setLoraScale)
        .def("forceFP16Attention", &QuantizedFluxModel::forceFP16Attention)
    ;

    // ---- QuantizedSanaModel -------------------------------------------------
    // Same shape as the Flux binding; `init` additionally takes `config` and
    // `pag_layers` and has no `offload` argument.
    py::class_<QuantizedSanaModel>(m, "QuantizedSanaModel")
        .def(py::init<>())
        .def("init", &QuantizedSanaModel::init,
            py::arg("config"),
            py::arg("pag_layers"),
            py::arg("use_fp4"),
            py::arg("bf16"),
            py::arg("deviceId")
        )
        .def("reset", &QuantizedSanaModel::reset)
        .def("load", &QuantizedSanaModel::load,
            py::arg("path"),
            py::arg("partial") = false
        )
        .def("forward", &QuantizedSanaModel::forward)
        .def("forward_layer", &QuantizedSanaModel::forward_layer)
        .def("startDebug", &QuantizedSanaModel::startDebug)
        .def("stopDebug", &QuantizedSanaModel::stopDebug)
        .def("getDebugResults", &QuantizedSanaModel::getDebugResults)
    ;

    // ---- QuantizedGEMM ------------------------------------------------------
    // Bound without py::arg annotations, so these methods take positional
    // arguments only from Python.
    py::class_<QuantizedGEMM>(m, "QuantizedGEMM")
        .def(py::init<>())
        .def("init", &QuantizedGEMM::init)
        .def("reset", &QuantizedGEMM::reset)
        .def("load", &QuantizedGEMM::load)
        .def("forward", &QuantizedGEMM::forward)
        .def("quantize", &QuantizedGEMM::quantize)
        .def("startDebug", &QuantizedGEMM::startDebug)
        .def("stopDebug", &QuantizedGEMM::stopDebug)
        .def("getDebugResults", &QuantizedGEMM::getDebugResults)
    ;

    // ---- QuantizedGEMM88 ----------------------------------------------------
    // Mirrors the QuantizedGEMM binding except that no `quantize` method is
    // exposed.
    py::class_<QuantizedGEMM88>(m, "QuantizedGEMM88")
        .def(py::init<>())
        .def("init", &QuantizedGEMM88::init)
        .def("reset", &QuantizedGEMM88::reset)
        .def("load", &QuantizedGEMM88::load)
        .def("forward", &QuantizedGEMM88::forward)
        .def("startDebug", &QuantizedGEMM88::startDebug)
        .def("stopDebug", &QuantizedGEMM88::stopDebug)
        .def("getDebugResults", &QuantizedGEMM88::getDebugResults)
    ;

    // ---- ops submodule ------------------------------------------------------
    // Free functions from nunchaku::ops exposed as `<module>.ops.*`.
    m.def_submodule("ops")
        .def("gemm_awq", nunchaku::ops::gemm_awq)
        .def("gemv_awq", nunchaku::ops::gemv_awq)
    ;

    // ---- utils submodule ----------------------------------------------------
    // Logging and memory helpers exposed as `<module>.utils.*`.
    m.def_submodule("utils")
        // Sets the global spdlog level from its textual name.
        // NOTE(review): spdlog's level::from_str appears to fall back to
        // `off` for names it does not recognize, which would silently disable
        // logging on a typo — confirm whether that is intended here.
        .def("set_log_level", [](const std::string &level) {
            spdlog::set_level(spdlog::level::from_str(level));
        })
        .def("disable_memory_auto_release", nunchaku::utils::disable_memory_auto_release)
        .def("trim_memory", nunchaku::utils::trim_memory)
    ;
}