#include #include "sparse_fwd.h" #include "sparse_decode.h" #include "dense_decode.h" #include "dense_decode_qkvfp8.h" #include "dense_decode_kvfp8.h" #include "../extension/flash_api.h" PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.doc() = "FlashMLA"; m.def("sparse_decode_fwd", &sparse_attn_decode_interface); m.def("dense_decode_fwd", &dense_attn_decode_interface); m.def("dense_decode_fwd_qkvfp8", &dense_attn_decode_qkvfp8_interface); m.def("dense_decode_fwd_kvfp8", &dense_attn_decode_kvfp8_interface); m.def("sparse_prefill_fwd", &sparse_attn_prefill_interface); m.def("get_mla_decoding_metadata_dense_fp8", &get_mla_decoding_metadata_dense_fp8); m.def("fwd_kvcache_quantization_mla", &mha_fwd_kvcache_quantization_mla); m.def("fwd_kvcache_quantization_q_nope_pe_mla", &mha_fwd_kvcache_quantization_q_nope_pe_mla); m.def("fwd_kvcache_mla_nope_pe", &mha_fwd_kvcache_mla_nope_pe); m.def("fwd_kvcache_mla_fp8", &mha_fwd_kvcache_mla_fp8); m.def("fwd_kvcache_mla_fp8_with_cat", &mha_fwd_kvcache_mla_fp8_with_cat); }