// api.cpp
#include <pybind11/pybind11.h>

#include "sparse_fwd.h"
#include "sparse_decode.h"
#include "dense_decode.h"
#include "dense_decode_qkvfp8.h"


// Python extension module definition.
//
// The module name comes from the TORCH_EXTENSION_NAME macro, which is not
// defined in this file (presumably supplied by the build system -- confirm).
// Each m.def() call registers one C++ function under the given Python-visible
// name; the bound functions are not defined here and are presumably declared
// in the project headers included by this file.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.doc() = "FlashMLA";

    // Decode entry points.
    m.def("sparse_decode_fwd", &sparse_attn_decode_interface);
    m.def("dense_decode_fwd", &dense_attn_decode_interface);
    m.def("dense_decode_fwd_qkvfp8", &dense_attn_decode_qkvfp8_interface);

    // Prefill entry point.
    m.def("sparse_prefill_fwd", &sparse_attn_prefill_interface);
}