#include #include #include #include #include #include #ifdef TK_COMPILE_ATTN extern torch::Tensor sta_forward( torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor o, int kernel_t_size, int kernel_w_size, int kernel_h_size, int text_length, bool process_text, bool has_text, int kernel_aspect_ratio_flag ); #endif PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.doc() = "Sliding Block Attention Kernels"; // optional module docstring #ifdef TK_COMPILE_ATTN m.def("sta_fwd", torch::wrap_pybind_function(sta_forward), "sliding tile attention, assuming tile size is (6,8,8)"); #endif }