// Declarations for the four flash_permute_* kernels. Every one is a
// `__global__ void` function taking an untyped output pointer, an untyped
// input pointer, a sequence length, and one further integer size argument.
//
// NOTE(review): each statement opens with a bare `template` keyword and no
// angle-bracketed list follows it or the function name. In that form C++
// treats the statement as an explicit instantiation whose template arguments
// must be deduced, yet no visible parameter type names a template parameter.
// Confirm whether a `<...>` list was dropped before this text reached review.
//
// NOTE(review): the name suffixes (sbhd, bhsd, bshd) presumably spell the
// source and destination dimension orders, e.g. s = sequence, b = batch,
// h = head, d = head dim -- verify against the kernel definitions.

// sbhd -> bhsd. Fourth argument is `real_headdim`, declared as plain `int`.
template __global__ void flash_permute_sbhd2bhsd(
    void* output,
    void* input,
    int32_t seqlen,
    int real_headdim);

// bhsd -> sbhd. Same parameter list as flash_permute_sbhd2bhsd.
template __global__ void flash_permute_bhsd2sbhd(
    void* output,
    void* input,
    int32_t seqlen,
    int real_headdim);

// bshd -> bhsd. Fourth argument is `num_heads`, declared as `int32_t`.
template __global__ void flash_permute_bshd2bhsd(
    void* output,
    void* input,
    int32_t seqlen,
    int32_t num_heads);

// bhsd -> bshd. Same parameter list as flash_permute_bshd2bhsd.
template __global__ void flash_permute_bhsd2bshd(
    void* output,
    void* input,
    int32_t seqlen,
    int32_t num_heads);