// api.cpp
#include <pybind11/pybind11.h>

#include "sparse_fwd.h"
#include "sparse_decode.h"
#include "dense_decode.h"
#include "dense_decode_qkvfp8.h"
#include "dense_decode_kvfp8.h"
#include "../extension/flash_api.h"
// Python bindings for the FlashMLA extension.
//
// Each m.def(name, &fn) call below exposes one C++ function to Python under
// `name`; this block performs registration only and contains no logic of its
// own. The module name comes from the TORCH_EXTENSION_NAME macro, which is
// expected to be supplied by the build (e.g. torch.utils.cpp_extension) —
// NOTE(review): confirm against the build configuration.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.doc() = "FlashMLA";

    // Decode entry points (sparse and dense).
    m.def("sparse_decode_fwd", &sparse_attn_decode_interface);
    m.def("dense_decode_fwd", &dense_attn_decode_interface);
    // NOTE(review): the `_qkvfp8` / `_kvfp8` suffixes presumably select FP8
    // storage for Q/K/V vs. only K/V — verify against the kernel headers.
    m.def("dense_decode_fwd_qkvfp8", &dense_attn_decode_qkvfp8_interface);
    m.def("dense_decode_fwd_kvfp8", &dense_attn_decode_kvfp8_interface);

    // Prefill entry point.
    m.def("sparse_prefill_fwd", &sparse_attn_prefill_interface);

    // Metadata helper and the mha_fwd_kvcache_* family. These are presumably
    // declared in ../extension/flash_api.h — confirm. Note that the Python names
    // drop the `mha_` prefix of the C++ functions they bind.
    m.def("get_mla_decoding_metadata_dense_fp8", &get_mla_decoding_metadata_dense_fp8);
    m.def("fwd_kvcache_quantization_mla", &mha_fwd_kvcache_quantization_mla);
    m.def("fwd_kvcache_quantization_q_nope_pe_mla", &mha_fwd_kvcache_quantization_q_nope_pe_mla);
    m.def("fwd_kvcache_mla_nope_pe", &mha_fwd_kvcache_mla_nope_pe);
    m.def("fwd_kvcache_mla_fp8", &mha_fwd_kvcache_mla_fp8);
    m.def("fwd_kvcache_mla_fp8_with_cat", &mha_fwd_kvcache_mla_fp8_with_cat);
}