#include void FMHACutlassSM100FwdRun(at::Tensor workspace_buffer, at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv, at::Tensor o, at::Tensor lse, int mask_mode_code, float softmax_scale, int max_seqlen_q, int max_seqlen_kv, bool is_varlen); void FMHACutlassSM100BwdRun(at::Tensor workspace_buffer, at::Tensor d_o, at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor o, at::Tensor lse, at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv, at::Tensor dq, at::Tensor dk, at::Tensor dv, int mask_mode_code, float softmax_scale, int max_seqlen_q, int max_seqlen_kv, bool is_varlen); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("fwd", &FMHACutlassSM100FwdRun); m.def("bwd", &FMHACutlassSM100BwdRun); }