Commit 34e67b1e authored by zhangshao's avatar zhangshao
Browse files

first commit

parents
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
template<>
void run_mha_fwd_<cutlass::half_t, 96, false>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim96<cutlass::half_t, false>(params, stream);
}
This source diff could not be displayed because it is too large. You can view the blob instead.
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include "static_switch.h"
#include "flash.h"
#include "flash_fwd_kernel.h"
// Determine if the architecture supports FLASH and define a macro to handle parameter modifiers
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#define ARCH_SUPPORTS_FLASH
#define KERNEL_PARAM_MODIFIER __grid_constant__
#else
#define KERNEL_PARAM_MODIFIER
#endif
#if defined(DCU_ASM)
#define ARCH_SUPPORTS_FLASH
#endif
// Define a macro for unsupported architecture handling to centralize the error message
#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashAttention requires building with sm version sm80-sm90, but was built for < 8.0!");
static const bool prefetch = get_env_("FLASH_ATTENTION_FWD_PREFETCH");
static const bool force_blockm128 = get_env_("FLASH_ATTENTION_force_blockm128");
static const bool force_blockm64 = get_env_("FLASH_ATTENTION_force_blockm64");
static const bool prefix_cache_force_use_128= get_env_("FLASH_ATTENTION_prefix_cache_force_use_128");
static const bool debug_env = get_env_("FLASH_ATTENTION_debug_env");
static const int sm_count = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
// Use a macro to clean up kernel definitions
#define DEFINE_FLASH_FORWARD_KERNEL(kernelName, ...) \
template<typename Kernel_traits, __VA_ARGS__> \
__launch_bounds__(Kernel_traits::kNThreads) \
__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params)
DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax) {
#if defined(ARCH_SUPPORTS_FLASH)
static_assert(!(Is_causal && Is_local)); // Enforce constraints
flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Return_softmax>(params);
#else
FLASH_UNSUPPORTED_ARCH
#endif
}
DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel_16x64, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax) {
#if defined(ARCH_SUPPORTS_FLASH)
static_assert(!(Is_causal && Is_local)); // Enforce constraints
flash::compute_attn_16x64<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Return_softmax>(params);
#else
FLASH_UNSUPPORTED_ARCH
#endif
}
DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel_16x64_prefetch, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax) {
#if defined(ARCH_SUPPORTS_FLASH)
static_assert(!(Is_causal && Is_local)); // Enforce constraints
flash::compute_attn_16x64_prefetch<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Return_softmax>(params);
#else
FLASH_UNSUPPORTED_ARCH
#endif
}
DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel_16x64_prefetch_fp8, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax) {
#if defined(ARCH_SUPPORTS_FLASH)
static_assert(!(Is_causal && Is_local)); // Enforce constraints
flash::compute_attn_16x64_prefetch_fp8<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Return_softmax>(params);
#else
FLASH_UNSUPPORTED_ARCH
#endif
}
DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel_16x64_prefetch_padding_mask, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax) {
#if defined(ARCH_SUPPORTS_FLASH)
static_assert(!(Is_causal && Is_local)); // Enforce constraints
flash::compute_attn_16x64_prefetch_padding_mask<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Return_softmax>(params);
#else
FLASH_UNSUPPORTED_ARCH
#endif
}
DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV) {
#if defined(ARCH_SUPPORTS_FLASH)
flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV>(params);
#else
FLASH_UNSUPPORTED_ARCH
#endif
}
DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel_16x64_vllm_kvcache_prefetch, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, bool Has_block_table) {
#if defined(ARCH_SUPPORTS_FLASH)
flash::compute_attn_splitkv_16x64_vllm_kvcache_prefetch<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV, Has_block_table>(params);
#else
FLASH_UNSUPPORTED_ARCH
#endif
}
DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel_16x64_vllm_kvcache_prefetch_fp8, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, bool Has_block_table) {
#if defined(ARCH_SUPPORTS_FLASH)
flash::compute_attn_splitkv_16x64_vllm_kvcache_prefetch_fp8<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV, Has_block_table>(params);
#else
FLASH_UNSUPPORTED_ARCH
#endif
}
DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel_16x64_vllm_kvcache_prefetch_kv_fp8, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, bool Has_block_table) {
#if defined(ARCH_SUPPORTS_FLASH)
flash::compute_attn_splitkv_16x64_vllm_kvcache_prefetch_kv_fp8<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV, Has_block_table>(params);
#else
FLASH_UNSUPPORTED_ARCH
#endif
}
DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel_16x64_vllm_kvcache_gfx928, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, bool Has_block_table) {
#if defined(ARCH_SUPPORTS_FLASH)
flash::compute_attn_splitkv_16x64_vllm_kvcache_gfx928<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV, Has_block_table>(params);
#else
FLASH_UNSUPPORTED_ARCH
#endif
}
DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_unified_kernel_16x64_prefetch, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, bool Has_block_table, bool Is_need_balance, bool Use_alibi_sqrt, bool Use_qq_bias, bool Use_mm_prefix) {
#if defined(ARCH_SUPPORTS_FLASH)
flash::compute_attn_unified_16x64_prefetch<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV, Has_block_table, Is_need_balance, Use_alibi_sqrt, Use_qq_bias, Use_mm_prefix>(params);
#else
FLASH_UNSUPPORTED_ARCH
#endif
}
DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_combine_kernel, int kBlockM, int Log_max_splits, bool Is_even_K) {
static_assert(Log_max_splits >= 1);
flash::combine_attn_seqk_parallel<Kernel_traits, kBlockM, Log_max_splits, Is_even_K>(params);
}
template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
static_assert(Kernel_traits::Share_Q_K_smem, "Share_Q_K_smem must be true");
constexpr size_t smem_size = Kernel_traits::kSmemSize;
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;//2
dim3 grid;
if constexpr(Is_causal)grid=dim3(params.h, params.b,num_m_block);
else grid=dim3(num_m_block,params.h, params.b);
const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
const bool is_even_K = params.d == Kernel_traits::kHeadDim;
const bool return_softmax = params.p_ptr != nullptr;
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, ReturnSoftmaxConst && Is_dropout && !Is_softcap>;
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
});
});
});
});
}
template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
void run_flash_fwd_16x64(Flash_fwd_params &params, cudaStream_t stream) {
constexpr size_t smem_size = Kernel_traits::kSmemSize;
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;//2
dim3 grid;
if constexpr(Is_causal)grid=dim3(params.h, params.b,num_m_block);
else grid=dim3(num_m_block,params.h, params.b);
const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
const bool is_even_K = params.d == Kernel_traits::kHeadDim;
const bool return_softmax = params.p_ptr != nullptr;
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
auto kernel = &flash_fwd_kernel_16x64<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, ReturnSoftmaxConst && Is_dropout && !Is_softcap>;
// printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
});
});
});
});
}
template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
void run_flash_fwd_16x64_prefetch(Flash_fwd_params &params, cudaStream_t stream) {
static_assert(Kernel_traits::Share_Q_K_smem, "Share_Q_K_smem must be true");
constexpr size_t smem_size = Kernel_traits::kSmemSize;
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;//2
dim3 grid;
if constexpr(Is_causal)grid=dim3(params.h, params.b,num_m_block);
else grid=dim3(num_m_block,params.h, params.b);
const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
const bool is_even_K = params.d == Kernel_traits::kHeadDim;
const bool return_softmax = params.p_ptr != nullptr;
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
// constexpr static bool IsEvenKConst = true;
LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
// constexpr static bool Is_local = false;
BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
// constexpr static bool ReturnSoftmaxConst = true;
ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
// constexpr static bool Has_alibi = false;
SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
// constexpr static bool Is_softcap = false;
auto kernel = &flash_fwd_kernel_16x64_prefetch<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, ReturnSoftmaxConst && Is_dropout && !Is_softcap>;
// printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
});
});
});
});
}
template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
void run_flash_fwd_16x64_prefetch_fp8(Flash_fwd_params &params, cudaStream_t stream) {
static_assert(Kernel_traits::Share_Q_K_smem, "Share_Q_K_smem must be true");
constexpr size_t smem_size = Kernel_traits::kSmemSize;
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;//2
dim3 grid;
if constexpr(Is_causal)grid=dim3(params.h, params.b,num_m_block);
else grid=dim3(num_m_block,params.h, params.b);
const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
const bool is_even_K = params.d == Kernel_traits::kHeadDim;
const bool return_softmax = params.p_ptr != nullptr;
// printf("%d,%d\n",Is_causal, is_even_MN);
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
auto kernel = &flash_fwd_kernel_16x64_prefetch_fp8<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, ReturnSoftmaxConst && Is_dropout && !Is_softcap>;
// printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
});
});
});
});
}
template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
void run_flash_fwd_16x64_prefetch_padding_mask(Flash_fwd_params &params, cudaStream_t stream) {
static_assert(Kernel_traits::Share_Q_K_smem, "Share_Q_K_smem must be true");
constexpr size_t smem_size = Kernel_traits::kSmemSize;
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;//2
dim3 grid;
if constexpr(Is_causal)grid=dim3(params.h, params.b,num_m_block);
else grid=dim3(num_m_block,params.h, params.b);
const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
const bool is_even_K = params.d == Kernel_traits::kHeadDim;
const bool return_softmax = params.p_ptr != nullptr;
// BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
constexpr static bool IsEvenMNConst = false;
EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
auto kernel = &flash_fwd_kernel_16x64_prefetch_padding_mask<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, ReturnSoftmaxConst && Is_dropout && !Is_softcap>;
// printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
});
});
});
// });
}
template<typename Kernel_traits, bool Is_causal>
void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
static_assert(Kernel_traits::Is_Q_in_regs, "SplitKV implementation must support Is_Q_in_regs");
static_assert(Kernel_traits::Share_Q_K_smem, "SplitKV implementation must support Share_Q_K_smem");
// params.num_splits大于1的时候,输出值是float类型,是大于Q的。这里改动的本质原因是q与kv共享lds导致的
const size_t smem_size = params.num_splits > 1 ? std::max(Kernel_traits::kSmemQSize * 2, Kernel_traits::kSmemSize) : Kernel_traits::kSmemSize;
// printf("smem_size = %d\n", smem_size);
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
dim3 grid(num_m_block, params.num_splits > 1 ? params.num_splits : params.b, params.num_splits > 1 ? params.b * params.h : params.h);
const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
const bool is_even_K = params.d == Kernel_traits::kHeadDim;
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
BOOL_SWITCH(params.num_splits > 1, Split, [&] {
BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, Split, Append_KV>;
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
});
});
});
});
});
// printf(" run_flash_splitkv_fwd params.num_splits = %d\n", params.num_splits);
if (params.num_splits > 1) {
// We want kBlockM to be as small as possible for more parallelism.
// With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4.
// If headdim is divisible by 64, then we set kBlockM = 8, etc.
constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 32 : (Kernel_traits::kHeadDim % 64 == 0 ? 32 : 32);
dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM);
EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
if (params.num_splits <= 2) {
flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 1, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
}
else if (params.num_splits <= 4) {
flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 2, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
} else if (params.num_splits <= 8) {
flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 3, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
} else if (params.num_splits <= 16) {
flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 4, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
} else if (params.num_splits <= 32) {
flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 5, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
} else if (params.num_splits <= 64) {
flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 6, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
} else if (params.num_splits <= 128) {
flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 7, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
}
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
}
}
template<typename Kernel_traits, typename Combine_Kernel_traits, bool Is_causal>
void run_flash_splitkv_fwd_16x64_vllm_kvcache_prefetch(Flash_fwd_params &params, cudaStream_t stream) {
// params.num_splits大于1的时候,输出值是float类型,是大于Q的。这里改动的本质原因是q与kv共享lds导致的
params.num_splits=1;
const size_t smem_size = params.num_splits > 1 ? std::max(Kernel_traits::kSmemOSize * 2, Kernel_traits::kSmemSize) : Kernel_traits::kSmemSize;
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;//2
dim3 grid;
if constexpr(Is_causal)grid=dim3(params.h, params.b,num_m_block);
else grid=dim3(num_m_block,params.h, params.b);
const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
const bool is_even_K = true;
// const bool is_even_K = params.d == Kernel_traits::kHeadDim;
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
BOOL_SWITCH(params.num_splits > 1, Split, [&] {
ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
constexpr static bool IsEvenKConst = true;
// constexpr static bool Is_local = false;
constexpr static bool Is_softcap = false;
constexpr static bool Has_block_table = true;
constexpr static bool Append_KV = false;
constexpr static bool Split = false;
auto kernel = &flash_fwd_splitkv_kernel_16x64_vllm_kvcache_prefetch<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, Split, Append_KV, Has_block_table>;
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
});
});
});
});
});
});
}
template<typename Kernel_traits, typename Combine_Kernel_traits, bool Is_causal>
void run_flash_splitkv_fwd_16x64_vllm_kvcache_prefetch_kv_fp8(Flash_fwd_params &params, cudaStream_t stream) {
// params.num_splits大于1的时候,输出值是float类型,是大于Q的。这里改动的本质原因是q与kv共享lds导致的
params.num_splits=1;
const size_t smem_size = params.num_splits > 1 ? std::max(Kernel_traits::kSmemOSize * 2, Kernel_traits::kSmemSize) : Kernel_traits::kSmemSize;
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;//2
dim3 grid;
if constexpr(Is_causal)grid=dim3(params.h, params.b,num_m_block);
else grid=dim3(num_m_block,params.h, params.b);
const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
const bool is_even_K = true;
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
BOOL_SWITCH(params.num_splits > 1, Split, [&] {
ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
constexpr static bool IsEvenKConst = true;
// constexpr static bool Is_local = false;
constexpr static bool Is_softcap = false;
constexpr static bool Has_block_table = true;
constexpr static bool Append_KV = false;
constexpr static bool Split = false;
auto kernel = &flash_fwd_splitkv_kernel_16x64_vllm_kvcache_prefetch_kv_fp8<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, Split, Append_KV, Has_block_table>;
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
});
});
});
});
}
template<typename Kernel_traits, typename Combine_Kernel_traits, bool Is_causal>
void run_flash_splitkv_fwd_16x64_vllm_kvcache_prefetch_fp8(Flash_fwd_params &params, cudaStream_t stream) {
// params.num_splits大于1的时候,输出值是float类型,是大于Q的。这里改动的本质原因是q与kv共享lds导致的
params.num_splits=1;
const size_t smem_size = params.num_splits > 1 ? std::max(Kernel_traits::kSmemOSize * 2, Kernel_traits::kSmemSize) : Kernel_traits::kSmemSize;
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;//2
dim3 grid;
if constexpr(Is_causal)grid=dim3(params.h, params.b,num_m_block);
else grid=dim3(num_m_block,params.h, params.b);
const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
const bool is_even_K = true;
// const bool is_even_K = params.d == Kernel_traits::kHeadDim;
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
BOOL_SWITCH(params.num_splits > 1, Split, [&] {
ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
constexpr static bool IsEvenKConst = true;
// constexpr static bool Is_local = false;
constexpr static bool Is_softcap = false;
constexpr static bool Has_block_table = true;
constexpr static bool Append_KV = false;
constexpr static bool Split = false;
// constexpr static bool Has_alibi = false;
auto kernel = &flash_fwd_splitkv_kernel_16x64_vllm_kvcache_prefetch_fp8<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, Split, Append_KV, Has_block_table>;
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
});
});
});
});
}
template<typename Kernel_traits, typename Combine_Kernel_traits, bool Is_causal>
void run_flash_splitkv_fwd_16x64_unified_prefetch(Flash_fwd_params &params, cudaStream_t stream) {
// params.num_splits大于1的时候,输出值是float类型,是大于Q的。这里改动的本质原因是q与kv共享lds导致的
const size_t smem_size = params.num_splits > 1 ? std::max(Kernel_traits::kSmemOSize * 2, Kernel_traits::kSmemSize) : Kernel_traits::kSmemSize;
// printf("smem_size = %d\n", smem_size);
#ifdef NO_CAUSAL_OPT
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
#else
const int non_causal_num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
const int grid_y = params.num_splits > 1 ? params.num_splits : params.b;
const int grid_z = params.num_splits > 1 ? params.b * params.h : params.h;
const bool need_balance = Is_causal && (non_causal_num_m_block * grid_y * grid_z > 80);
const int num_m_block = need_balance ? (non_causal_num_m_block + 1 ) >> 1 :
non_causal_num_m_block;
#endif
dim3 grid(num_m_block, params.num_splits > 1 ? params.num_splits : params.b, params.num_splits > 1 ? params.b * params.h : params.h);
const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
const bool is_even_K = params.d == Kernel_traits::kHeadDim;
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
BOOL_SWITCH(params.num_splits > 1, Split, [&] {
// BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
BOOL_SWITCH(need_balance, Is_need_balance, [&] {
BOOL_SWITCH(params.use_alibi_sqrt, Use_alibi_sqrt, [&] {
BOOL_SWITCH(params.qq_bias_ptr != nullptr, Use_qq_bias, [&] {
BOOL_SWITCH(params.mm_prefix_range_ptr != nullptr, Use_mm_prefix, [&] {
constexpr static bool Has_block_table = true;
constexpr static bool Append_KV = false;
auto kernel = &flash_fwd_unified_kernel_16x64_prefetch<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, Split, Append_KV, Has_block_table, Is_need_balance, Use_alibi_sqrt, Use_qq_bias, Use_mm_prefix>;
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
});
});
});
});
// });
});
});
});
});
// printf(" run_flash_splitkv_fwd params.num_splits = %d\n", params.num_splits);
if (params.num_splits > 1) {
// We want kBlockM to be as small as possible for more parallelism.
// With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4.
// If headdim is divisible by 64, then we set kBlockM = 8, etc.
constexpr static int kBlockM = Combine_Kernel_traits::kHeadDim % 128 == 0 ? 32 : (Kernel_traits::kHeadDim % 64 == 0 ? 32 : 32);
dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM);
params.d = params.d_value;
params.d_rounded = params.d_value_rounded;
EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
if (params.num_splits <= 2) {
flash_fwd_splitkv_combine_kernel<Combine_Kernel_traits, kBlockM, 1, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
}
else if (params.num_splits <= 4) {
flash_fwd_splitkv_combine_kernel<Combine_Kernel_traits, kBlockM, 2, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
} else if (params.num_splits <= 8) {
flash_fwd_splitkv_combine_kernel<Combine_Kernel_traits, kBlockM, 3, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
} else if (params.num_splits <= 16) {
flash_fwd_splitkv_combine_kernel<Combine_Kernel_traits, kBlockM, 4, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
} else if (params.num_splits <= 32) {
flash_fwd_splitkv_combine_kernel<Combine_Kernel_traits, kBlockM, 5, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
} else if (params.num_splits <= 64) {
flash_fwd_splitkv_combine_kernel<Combine_Kernel_traits, kBlockM, 6, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
} else if (params.num_splits <= 128) {
flash_fwd_splitkv_combine_kernel<Combine_Kernel_traits, kBlockM, 7, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
}
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
}
}
template<typename Kernel_traits, typename Combine_Kernel_traits, bool Is_causal>
void run_flash_splitkv_fwd_16x64_vllm_kvcache_gfx928(Flash_fwd_params &params, cudaStream_t stream) {
params.num_splits=1;
const size_t smem_size = params.num_splits > 1 ? std::max(Kernel_traits::kSmemOSize * 2, Kernel_traits::kSmemSize) : Kernel_traits::kSmemSize;
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;//2
dim3 grid;
if constexpr(Is_causal)grid=dim3(params.h, params.b,num_m_block);
else grid=dim3(num_m_block,params.h, params.b);
const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
const bool is_even_K = true;
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
BOOL_SWITCH(params.num_splits > 1, Split, [&] {
BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
BOOL_SWITCH(params.block_table != nullptr, Has_block_table, [&] {
constexpr static bool IsEvenKConst = true;
// constexpr static bool Is_local = false;
constexpr static bool Is_softcap = false;
constexpr static bool Has_block_table = true;
constexpr static bool Append_KV = false;
constexpr static bool Split = false;
auto kernel = &flash_fwd_splitkv_kernel_16x64_vllm_kvcache_gfx928<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, Split, Append_KV, Has_block_table>;
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
});
});
});
});
});
});
}
template<typename T, int Headdim, bool Is_causal>
void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int kBlockM = 64; // Fixed for all head dimensions
// TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256,
// and for headdim 192 with block size 64 x 128.
// Also for headdim 160 with block size 64 x 128 after the rotary addition.
constexpr static int kBlockN = Headdim <= 128 ? 64 : (Headdim % 64 == 0 ? 32 : 64);
int mblocks = (params.seqlen_q + 64 - 1) / 64;
bool is_small = (params.seqlen_q <= 64||params.h * params.b * mblocks< 4*sm_count);
if (params.is_vllm_kvcache) {
if constexpr(Headdim == 64) {
if (get_device_name() == "gfx936" || get_device_name() == "gfx938") {
using combine_kernel_traits = Flash_fwd_kernel_16x64_traits_splitkv<64, kBlockM, kBlockN, 4, false, false, T, 64>;
if (is_small) {
using prefetch_kernel_traits = Flash_fwd_kernel_16x64_splitkv_prefetch_vllm_kvcache_traits_dim64<64, 64, 64, 4, T, 1, 64>;
run_flash_splitkv_fwd_16x64_vllm_kvcache_prefetch<prefetch_kernel_traits, combine_kernel_traits, Is_causal>(params, stream);
} else {
using prefetch_kernel_traits = Flash_fwd_kernel_16x64_splitkv_prefetch_vllm_kvcache_traits_dim64<64, 128, 64, 4, T, 1, 64>;
run_flash_splitkv_fwd_16x64_vllm_kvcache_prefetch<prefetch_kernel_traits, combine_kernel_traits, Is_causal>(params, stream);
}
}
else {
using combine_kernel_traits = Flash_fwd_kernel_16x64_traits_splitkv<64, kBlockM, kBlockN, 4, false, false, T, 64>;
if (is_small) {
using prefetch_kernel_traits = Flash_fwd_kernel_16x64_splitkv_vllm_kvcache_traits<64, 64, 64, 4, T, 3, 64>;
run_flash_splitkv_fwd_16x64_vllm_kvcache_gfx928<prefetch_kernel_traits, combine_kernel_traits, Is_causal>(params, stream);
} else {
using prefetch_kernel_traits = Flash_fwd_kernel_16x64_splitkv_vllm_kvcache_traits<64, 128, 64, 4, T, 3, 64>;
run_flash_splitkv_fwd_16x64_vllm_kvcache_gfx928<prefetch_kernel_traits, combine_kernel_traits, Is_causal>(params, stream);
}
}
}else if constexpr(Headdim == 128) {
if (get_device_name() == "gfx936"||get_device_name() == "gfx938") {
assert(params.knew_ptr == nullptr && params.block_table != nullptr);
using combine_kernel_traits = Flash_fwd_kernel_16x64_traits_splitkv<128, kBlockM, kBlockN, 4, false, false, T, 128>;
if (is_small) {
using prefetch_kernel_traits = Flash_fwd_kernel_16x64_splitkv_prefetch_vllm_kvcache_traits<128, 64, 64, 4, T, 3, 128>;
run_flash_splitkv_fwd_16x64_vllm_kvcache_prefetch<prefetch_kernel_traits, combine_kernel_traits, Is_causal>(params, stream);
} else {
using prefetch_kernel_traits = Flash_fwd_kernel_16x64_splitkv_prefetch_vllm_kvcache_traits<128, 128, 64, 4, T, 3, 128>;
run_flash_splitkv_fwd_16x64_vllm_kvcache_prefetch<prefetch_kernel_traits, combine_kernel_traits, Is_causal>(params, stream);
}
}
else {
using combine_kernel_traits = Flash_fwd_kernel_16x64_traits_splitkv<128, kBlockM, kBlockN, 4, false, false, T, 128>;
if (is_small) {
using prefetch_kernel_traits = Flash_fwd_kernel_16x64_splitkv_vllm_kvcache_traits<128, 64, 64, 4, T, 3, 128>;
run_flash_splitkv_fwd_16x64_vllm_kvcache_gfx928<prefetch_kernel_traits, combine_kernel_traits, Is_causal>(params, stream);
} else {
using prefetch_kernel_traits = Flash_fwd_kernel_16x64_splitkv_vllm_kvcache_traits<128, 128, 64, 4, T, 3, 128>;
run_flash_splitkv_fwd_16x64_vllm_kvcache_gfx928<prefetch_kernel_traits, combine_kernel_traits, Is_causal>(params, stream);
}
}
}else if constexpr(Headdim == 192) {
if (get_device_name() == "gfx936" || get_device_name() == "gfx938") {
using combine_kernel_traits = Flash_fwd_kernel_16x64_traits_splitkv<192, kBlockM, kBlockN, 4, false, false, T, 192>;
using prefetch_kernel_traits = Flash_fwd_kernel_16x64_splitkv_prefetch_vllm_kvcache_traits_dim192<192, 64, 64, 4, T, 3, 192>;
run_flash_splitkv_fwd_16x64_vllm_kvcache_prefetch<prefetch_kernel_traits, combine_kernel_traits, Is_causal>(params, stream);
}
}else if constexpr(Headdim == 256) {
if (get_device_name() == "gfx936" || get_device_name() == "gfx938") {
using combine_kernel_traits = Flash_fwd_kernel_16x64_traits_splitkv<256, kBlockM, kBlockN, 4, false, false, T, 256>;
using prefetch_kernel_traits = Flash_fwd_kernel_16x64_splitkv_prefetch_vllm_kvcache_traits_dim256<256, 64, 64, 4, T, 3, 256>;
run_flash_splitkv_fwd_16x64_vllm_kvcache_prefetch<prefetch_kernel_traits, combine_kernel_traits, Is_causal>(params, stream);
}
}
}else{
run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, true, T>, Is_causal>(params, stream);
}
}
template<typename Q,typename KV, int Headdim, bool Is_causal>
void run_mha_fwd_splitkv_dispatch_kv_fp8(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int kBlockM = 64; // Fixed for all head dimensions
// TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256,
// and for headdim 192 with block size 64 x 128.
// Also for headdim 160 with block size 64 x 128 after the rotary addition.
constexpr static int kBlockN = Headdim <= 128 ? 64 : (Headdim % 64 == 0 ? 32 : 64);
int mblocks = (params.seqlen_q + 64 - 1) / 64;
bool is_small = (params.seqlen_q <= 64||params.h * params.b * mblocks< 4*sm_count);
// printf("kBlockM = %d, kBlockN = %d", kBlockM, kBlockN);
#ifndef FLASHATTENTION_DISABLE_SPLITKV
if constexpr(Headdim == 64) {
if (get_device_name() == "gfx936") {
if (params.is_vllm_kvcache) {
using combine_kernel_traits = Flash_fwd_kernel_16x64_traits_splitkv<64, kBlockM, kBlockN, 4, false, false, Q, 64>;
if (is_small) {
using prefetch_kernel_traits = Flash_fwd_kernel_16x64_splitkv_prefetch_vllm_kvcache_traits_kv_fp8_dim64<64, 64, 64, 4, Q, KV, 1, 64>;
run_flash_splitkv_fwd_16x64_vllm_kvcache_prefetch_kv_fp8<prefetch_kernel_traits, combine_kernel_traits, Is_causal>(params, stream);
}else {
using prefetch_kernel_traits = Flash_fwd_kernel_16x64_splitkv_prefetch_vllm_kvcache_traits_kv_fp8_dim64<64, 128, 64, 4, Q, KV, 1, 64>;
run_flash_splitkv_fwd_16x64_vllm_kvcache_prefetch_kv_fp8<prefetch_kernel_traits, combine_kernel_traits, Is_causal>(params, stream);
}
}
}
}else if constexpr(Headdim == 128) {
if (get_device_name() == "gfx936") {
if (params.is_vllm_kvcache) {
using combine_kernel_traits = Flash_fwd_kernel_16x64_traits_splitkv<128, kBlockM, kBlockN, 4, false, false, Q, 128>;
if (is_small) {
using prefetch_kernel_traits = Flash_fwd_kernel_16x64_splitkv_prefetch_vllm_kvcache_traits_kv_fp8<128, 64, 64, 4, Q, KV, 1, 128>;
run_flash_splitkv_fwd_16x64_vllm_kvcache_prefetch_kv_fp8<prefetch_kernel_traits, combine_kernel_traits, Is_causal>(params, stream);
} else {
using prefetch_kernel_traits = Flash_fwd_kernel_16x64_splitkv_prefetch_vllm_kvcache_traits_kv_fp8<128, 128, 64, 4, Q, KV, 1, 128>;
run_flash_splitkv_fwd_16x64_vllm_kvcache_prefetch_kv_fp8<prefetch_kernel_traits, combine_kernel_traits, Is_causal>(params, stream);
}
}
}
}else if constexpr(Headdim == 256) {
if (get_device_name() == "gfx936") {
if (params.is_vllm_kvcache) {
using combine_kernel_traits = Flash_fwd_kernel_16x64_traits_splitkv<256, kBlockM, kBlockN, 4, false, false, Q, 256>;
if (is_small) {
using prefetch_kernel_traits = Flash_fwd_kernel_16x64_splitkv_prefetch_vllm_kvcache_traits_kv_fp8_dim256<256, 64, 64, 4, Q, KV, 1, 256>;
run_flash_splitkv_fwd_16x64_vllm_kvcache_prefetch_kv_fp8<prefetch_kernel_traits, combine_kernel_traits, Is_causal>(params, stream);
} else {
using prefetch_kernel_traits = Flash_fwd_kernel_16x64_splitkv_prefetch_vllm_kvcache_traits_kv_fp8_dim256<256, 64, 64, 4, Q, KV, 1, 256>;
run_flash_splitkv_fwd_16x64_vllm_kvcache_prefetch_kv_fp8<prefetch_kernel_traits, combine_kernel_traits, Is_causal>(params, stream);
}
}
}
}
#endif
}
template<typename T,typename TO, int Headdim, bool Is_causal>
void run_mha_fwd_splitkv_dispatch_fp8(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int kBlockM = 64; // Fixed for all head dimensions
// TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256,
// and for headdim 192 with block size 64 x 128.
// Also for headdim 160 with block size 64 x 128 after the rotary addition.
constexpr static int kBlockN = Headdim <= 128 ? 64 : (Headdim % 64 == 0 ? 32 : 64);
// printf("kBlockM = %d, kBlockN = %d", kBlockM, kBlockN);
#ifndef FLASHATTENTION_DISABLE_SPLITKV
if constexpr(Headdim == 64) {
if (get_device_name() == "gfx938") {
if (params.is_vllm_kvcache) {
using combine_kernel_traits = Flash_fwd_kernel_16x64_traits_splitkv<64, kBlockM, kBlockN, 4, false, false, TO, 64>;
if (params.seqlen_q < 64) {
using prefetch_kernel_traits = Flash_fwd_kernel_16x64_splitkv_prefetch_vllm_kvcache_traits_fp8_dim64<64, 64, 64, 4, T, TO, 1, 64>;
run_flash_splitkv_fwd_16x64_vllm_kvcache_prefetch_fp8<prefetch_kernel_traits, combine_kernel_traits, Is_causal>(params, stream);
} else {
using prefetch_kernel_traits = Flash_fwd_kernel_16x64_splitkv_prefetch_vllm_kvcache_traits_fp8_dim64<64, 128, 64, 4, T, TO, 1, 64>;
run_flash_splitkv_fwd_16x64_vllm_kvcache_prefetch_fp8<prefetch_kernel_traits, combine_kernel_traits, Is_causal>(params, stream);
}
}
}
}else if constexpr(Headdim == 128) {
if (get_device_name() == "gfx938") {
if (params.is_vllm_kvcache) {
using combine_kernel_traits = Flash_fwd_kernel_16x64_traits_splitkv<128, kBlockM, kBlockN, 4, false, false, TO, 128>;
if (params.seqlen_q < 64) {
using prefetch_kernel_traits = Flash_fwd_kernel_16x64_splitkv_prefetch_vllm_kvcache_traits_fp8<128, 64, 64, 4, T, TO, 1, 128>;
run_flash_splitkv_fwd_16x64_vllm_kvcache_prefetch_fp8<prefetch_kernel_traits, combine_kernel_traits, Is_causal>(params, stream);
} else {
using prefetch_kernel_traits = Flash_fwd_kernel_16x64_splitkv_prefetch_vllm_kvcache_traits_fp8<128, 128, 64, 4, T, TO, 1, 128>;
run_flash_splitkv_fwd_16x64_vllm_kvcache_prefetch_fp8<prefetch_kernel_traits, combine_kernel_traits, Is_causal>(params, stream);
}
}
}
}else if constexpr(Headdim == 192) {
if (get_device_name() == "gfx938") {
if (params.is_vllm_kvcache) {
using combine_kernel_traits = Flash_fwd_kernel_16x64_traits_splitkv<192, kBlockM, kBlockN, 4, false, false, TO, 192>;
if (params.seqlen_q < 64) {
using prefetch_kernel_traits = Flash_fwd_kernel_16x64_splitkv_prefetch_vllm_kvcache_traits_fp8_dim192<192, 64, 64, 4, T, TO, 1, 192>;
run_flash_splitkv_fwd_16x64_vllm_kvcache_prefetch_fp8<prefetch_kernel_traits, combine_kernel_traits, Is_causal>(params, stream);
} else {
using prefetch_kernel_traits = Flash_fwd_kernel_16x64_splitkv_prefetch_vllm_kvcache_traits_fp8_dim192<192, 128, 64, 4, T, TO, 1, 192>;
run_flash_splitkv_fwd_16x64_vllm_kvcache_prefetch_fp8<prefetch_kernel_traits, combine_kernel_traits, Is_causal>(params, stream);
}
}
}
}else if constexpr(Headdim == 256) {
if (get_device_name() == "gfx938") {
if (params.is_vllm_kvcache) {
using combine_kernel_traits = Flash_fwd_kernel_16x64_traits_splitkv<256, kBlockM, kBlockN, 4, false, false, TO, 256>;
if (params.seqlen_q < 64) {
using prefetch_kernel_traits = Flash_fwd_kernel_16x64_splitkv_prefetch_vllm_kvcache_traits_fp8_dim256<256, 64, 64, 4, T, TO, 1, 256>;
run_flash_splitkv_fwd_16x64_vllm_kvcache_prefetch_fp8<prefetch_kernel_traits, combine_kernel_traits, Is_causal>(params, stream);
} else {
using prefetch_kernel_traits = Flash_fwd_kernel_16x64_splitkv_prefetch_vllm_kvcache_traits_fp8_dim256<256, 128, 64, 4, T, TO, 1, 256>;
run_flash_splitkv_fwd_16x64_vllm_kvcache_prefetch_fp8<prefetch_kernel_traits, combine_kernel_traits, Is_causal>(params, stream);
}
}
}
}
#endif
}
template<typename T, int Headdim, bool Is_causal>
void run_mha_fwd_unified_dispatch(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int kBlockM = 64;
constexpr static int kBlockN = Headdim <= 128 ? 64 : (Headdim % 64 == 0 ? 32 : 64);
if constexpr(Headdim == 256) {
if (get_device_name() == "gfx938" || get_device_name() == "gfx936") {
using prefetch_kernel_traits = Flash_fwd_kernel_16x64_splitkv_prefetch_unified_traits_dim256<256, 64, 64, 4, T, 3, 256>;
using combine_kernel_traits = Flash_fwd_kernel_16x64_traits_splitkv<256, kBlockM, kBlockN, 4, false, false, T, 256>;
run_flash_splitkv_fwd_16x64_unified_prefetch<prefetch_kernel_traits, combine_kernel_traits, Is_causal>(params, stream);
}
}else{
assert(false && "unified attn only supported headdim=256");
}
}
template<typename T, bool Is_causal>
void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
#if 1
constexpr static int Headdim = 32;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, true, T>, Is_dropout, Is_causal>(params, stream);
});
#endif
}
template<typename T, bool Is_causal>
void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 64;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
#if 0
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 256, 64, 4, false, /*Share_Q_K_smem_=*/true, T>, Is_dropout, Is_causal>(params, stream);
#else
if (get_device_name() == "gfx936" || get_device_name() == "gfx938") {
int mblocks = (params.seqlen_q + 64 - 1) / 64;
if (params.seqlen_q <= 64||params.h * params.b * mblocks< 4*sm_count) {
run_flash_fwd_16x64_prefetch<Flash_fwd_kernel_16x64_prefetch_traits_dim64<Headdim, 64, 64, 4, T>, Is_dropout, Is_causal>(params, stream);
} else {
run_flash_fwd_16x64_prefetch<Flash_fwd_kernel_16x64_prefetch_traits_dim64<Headdim, 256, 64, 4, T>, Is_dropout, Is_causal>(params, stream);
}
} else {
run_flash_fwd_16x64<Flash_fwd_kernel_16x64_traits<Headdim, 256, 64, 4, /*Is_Q_use_smem_=*/false, /*Share_K_V_smem_=*/false, T>, Is_dropout, Is_causal>(params, stream);
}
#endif
});
}
template<typename T, bool Is_causal>
void run_mha_fwd_padding_mask_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 64;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if (get_device_name() == "gfx936" || get_device_name() == "gfx938") {
run_flash_fwd_16x64_prefetch_padding_mask<Flash_fwd_kernel_16x64_prefetch_traits_dim64<Headdim, 128, 64, 4, T>, Is_dropout, Is_causal>(params, stream);
}
else {
run_flash_fwd_16x64<Flash_fwd_kernel_16x64_traits<Headdim, 128, 64, 4, /*Is_Q_use_smem_=*/false, /*Share_K_V_smem_=*/false, T>, Is_dropout, Is_causal>(params, stream);
}
});
}
template<typename T, bool Is_causal>
void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 96;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if (get_device_name() == "gfx936" || get_device_name() == "gfx938") {
int mblocks = (params.seqlen_q + 64 - 1) / 64;
if(params.seqlen_q <= 64||params.h * params.b * mblocks< 4*sm_count){
run_flash_fwd_16x64_prefetch<Flash_fwd_kernel_16x64_prefetch_traits_dim96<Headdim, 64, 64, 4, T, 3>, Is_dropout, Is_causal>(params, stream);
}
else{
run_flash_fwd_16x64_prefetch<Flash_fwd_kernel_16x64_prefetch_traits_dim96<Headdim, 128, 64, 4, T, 3>, Is_dropout, Is_causal>(params, stream);
}
} else {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, true, T>, Is_dropout, Is_causal>(params, stream);
}
});
}
template<typename T, bool Is_causal, bool Is_skip_softmax = false>
void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 128;
// printf("run_mha_fwd_hdim128\n");
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if (get_device_name() == "gfx936" || get_device_name() == "gfx938") {
int mblocks = (params.seqlen_q + 64 - 1) / 64;
if (params.seqlen_q <= 64||params.h * params.b * mblocks< 4*sm_count) {
run_flash_fwd_16x64_prefetch<Flash_fwd_kernel_16x64_prefetch_traits<Headdim, 64, 64, 4, T, 3, Is_skip_softmax>, Is_dropout, Is_causal>(params, stream);
} else {
run_flash_fwd_16x64_prefetch<Flash_fwd_kernel_16x64_prefetch_traits<Headdim, 128, 64, 4, T, 3, Is_skip_softmax>, Is_dropout, Is_causal>(params, stream);
}
} else {
run_flash_fwd_16x64<Flash_fwd_kernel_16x64_traits<Headdim, 128, 64, 4, /*Is_Q_use_smem_=*/true, /*Share_K_V_smem_=*/false, T>, Is_dropout, Is_causal>(params, stream);
}
});
}
template<typename T, bool Is_causal>
void run_mha_fwd_hdim128_fp8(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 128;
static constexpr bool Is_FP8 = cute::is_same_v<T, cutlass::float_e4m3_t> || cute::is_same_v<T, cutlass::float_e5m2_t>;
using T_out = std::conditional_t<!Is_FP8, T, cutlass::bfloat16_t>;
// printf("run_mha_fwd_hdim128\n");
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if (get_device_name() == "gfx938") {
int mblocks = (params.seqlen_q + 64 - 1) / 64;
if (params.seqlen_q <= 64||params.h * params.b * mblocks< 4*sm_count) {
run_flash_fwd_16x64_prefetch_fp8<Flash_fwd_kernel_16x64_prefetch_traits_fp8<Headdim, 64, 64, 4, T,T_out, 3>, Is_dropout, Is_causal>(params, stream);
} else {
run_flash_fwd_16x64_prefetch_fp8<Flash_fwd_kernel_16x64_prefetch_traits_fp8<Headdim, 128, 64, 4, T,T_out, 3>, Is_dropout, Is_causal>(params, stream);
}
} else {
printf("this device is not supoort fp8");
}
});
}
template<typename T, bool Is_causal>
void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 160;
auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, true, T>, Is_dropout, Is_causal>(params, stream);
});
}
template<typename T, bool Is_causal>
void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 192;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, true, T>, Is_dropout, Is_causal>(params, stream);
});
}
template<typename T, bool Is_causal>
void run_mha_fwd_hdim192_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if (get_device_name() == "gfx936" || get_device_name() == "gfx938") {
run_flash_fwd_16x64_prefetch<Flash_fwd_kernel_16x64_prefetch_mla_traits</*Headdim*/192, 128, 64, 4, T, 3, /*HeaddimV*/128>, Is_dropout, Is_causal>(params, stream);
}
else {
run_flash_fwd_16x64<Flash_fwd_kernel_16x64_traits_MLA</*Headdim*/192, 128, 64, 4, /*Is_Q_use_smem_=*/false, /*Share_K_V_smem_=*/true, T, 128>, Is_dropout, Is_causal>(params, stream);
}
});
}
template<typename T, bool Is_causal>
void run_mha_fwd_hdim192_hdim128_fp8(Flash_fwd_params &params, cudaStream_t stream) {
static constexpr bool Is_FP8 = cute::is_same_v<T, cutlass::float_e4m3_t> || cute::is_same_v<T, cutlass::float_e5m2_t>;
using T_out = std::conditional_t<!Is_FP8, T, cutlass::bfloat16_t>;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if (get_device_name() == "gfx938") {
run_flash_fwd_16x64_prefetch_fp8<Flash_fwd_kernel_16x64_prefetch_mla_traits_fp8</*Headdim*/192, 128, 64, 4, T,T_out, 3, /*HeaddimV*/128>, Is_dropout, Is_causal>(params, stream);
}
else {
printf("this device is not supoort fp8");
}
});
}
template<typename T, bool Is_causal>
void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 224;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, true, T>, Is_dropout, Is_causal>(params, stream);
});
}
template<typename T, bool Is_causal>
void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 256;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// constexpr static int Is_dropout = false;
if (get_device_name() == "gfx936" || get_device_name() == "gfx938") {
run_flash_fwd_16x64_prefetch<Flash_fwd_kernel_16x64_prefetch_traits_dim256<Headdim, 64, 64, 4, T, 3>, Is_dropout, Is_causal>(params, stream);
} else {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, true, T>, Is_dropout, Is_causal>(params, stream);
}
});
}
template<typename T, bool Is_causal>
void run_mha_fwd_hdim512(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 512;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// constexpr static int Is_dropout = false;
if (get_device_name() == "gfx936" || get_device_name() == "gfx938") {
run_flash_fwd_16x64_prefetch<Flash_fwd_kernel_16x64_prefetch_traits_dim512<Headdim, 64, 64, 4, T, 3>, Is_dropout, Is_causal>(params, stream);
} else {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, true, T>, Is_dropout, Is_causal>(params, stream);
}
});
}
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
template<>
void run_mha_fwd_padding_mask_<cutlass::half_t, 64, false>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_padding_mask_hdim64<cutlass::half_t, false>(params, stream);
}
template<>
void run_mha_fwd_padding_mask_<cutlass::bfloat16_t, 64, false>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_padding_mask_hdim64<cutlass::bfloat16_t, false>(params, stream);
}
\ No newline at end of file
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
// #include "namespace_config.h"
#include "flash_fwd_sparse_launch_template.h"
// namespace FLASH_NAMESPACE {
template<>
void run_mha_fwd_sparse_<cutlass::bfloat16_t, 128, true>(Flash_fwd_params_sparse &params, cudaStream_t stream) {
run_mha_fwd_sparse_hdim128<cutlass::bfloat16_t, true>(params, stream);
}
// } // namespace FLASH_NAMESPACE
\ No newline at end of file
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
// #include "namespace_config.h"
#include "flash_fwd_sparse_launch_template.h"
// namespace FLASH_NAMESPACE {
template<>
void run_mha_fwd_sparse_<cutlass::bfloat16_t, 128, false>(Flash_fwd_params_sparse &params, cudaStream_t stream) {
run_mha_fwd_sparse_hdim128<cutlass::bfloat16_t, false>(params, stream);
}
template<>
void run_mha_fwd_sparse_sla_<cutlass::bfloat16_t, 128>(Flash_fwd_params_sparse &params, cudaStream_t stream) {
run_mha_fwd_sparse_sla_hdim128<cutlass::bfloat16_t>(params, stream);
}
template<>
void run_mha_fwd_sparse_sla_fp8_<cutlass::float_e5m2_t, 128>(Flash_fwd_params_sparse &params, cudaStream_t stream) {
run_mha_fwd_sparse_sla_hdim128_fp8<cutlass::float_e5m2_t>(params, stream);
}
// } // namespace FLASH_NAMESPACE
\ No newline at end of file
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
// #include "namespace_config.h"
#include "flash_fwd_sparse_launch_template.h"
// namespace FLASH_NAMESPACE {
template<>
void run_mha_fwd_sparse_<cutlass::half_t, 128, true>(Flash_fwd_params_sparse &params, cudaStream_t stream) {
run_mha_fwd_sparse_hdim128<cutlass::half_t, true>(params, stream);
}
// } // namespace FLASH_NAMESPACE
\ No newline at end of file
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
// #include "namespace_config.h"
#include "flash_fwd_sparse_launch_template.h"
// namespace FLASH_NAMESPACE {
template<>
void run_mha_fwd_sparse_<cutlass::half_t, 128, false>(Flash_fwd_params_sparse &params, cudaStream_t stream) {
run_mha_fwd_sparse_hdim128<cutlass::half_t, false>(params, stream);
}
template<>
void run_mha_fwd_sparse_sla_<cutlass::half_t, 128>(Flash_fwd_params_sparse &params, cudaStream_t stream) {
run_mha_fwd_sparse_sla_hdim128<cutlass::half_t>(params, stream);
}
template<>
void run_mha_fwd_sparse_sla_fp8_<cutlass::float_e4m3_t, 128>(Flash_fwd_params_sparse &params, cudaStream_t stream) {
run_mha_fwd_sparse_sla_hdim128_fp8<cutlass::float_e4m3_t>(params, stream);
}
// } // namespace FLASH_NAMESPACE
\ No newline at end of file
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
// #include "namespace_config.h"
#include "flash_fwd_sparse_launch_template.h"
// namespace FLASH_NAMESPACE {
// template<>
// void run_mha_fwd_sparse_<cutlass::bfloat16_t, 64, false>(Flash_fwd_params_sparse &params, cudaStream_t stream) {
// run_mha_fwd_sparse_hdim64<cutlass::bfloat16_t, false>(params, stream);
// }
template<>
void run_mha_fwd_sparse_sla_<cutlass::bfloat16_t, 64>(Flash_fwd_params_sparse &params, cudaStream_t stream) {
run_mha_fwd_sparse_sla_hdim64<cutlass::bfloat16_t>(params, stream);
}
// template<>
// void run_mha_fwd_sparse_sla_fp8_<cutlass::float_e5m2_t, 64>(Flash_fwd_params_sparse &params, cudaStream_t stream) {
// run_mha_fwd_sparse_sla_hdim64_fp8<cutlass::float_e5m2_t>(params, stream);
// }
// } // namespace FLASH_NAMESPACE
\ No newline at end of file
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
// #include "namespace_config.h"
#include "flash_fwd_sparse_launch_template.h"
// namespace FLASH_NAMESPACE {
// template<>
// void run_mha_fwd_sparse_<cutlass::half_t, 64, false>(Flash_fwd_params_sparse &params, cudaStream_t stream) {
// run_mha_fwd_sparse_hdim64<cutlass::half_t, false>(params, stream);
// }
template<>
void run_mha_fwd_sparse_sla_<cutlass::half_t, 64>(Flash_fwd_params_sparse &params, cudaStream_t stream) {
run_mha_fwd_sparse_sla_hdim64<cutlass::half_t>(params, stream);
}
// template<>
// void run_mha_fwd_sparse_sla_fp8_<cutlass::float_e4m3_t, 64>(Flash_fwd_params_sparse &params, cudaStream_t stream) {
// run_mha_fwd_sparse_sla_hdim64_fp8<cutlass::float_e4m3_t>(params, stream);
// }
// } // namespace FLASH_NAMESPACE
\ No newline at end of file
/******************************************************************************
* Copyright (c) 2024, PAI, Alibaba Cloud.
******************************************************************************/
#pragma once
#include "flash_fwd_kernel.h"
namespace flash {
using namespace cute;
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax, typename Params>
inline __device__ void sparse_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {
#define S_WAITCNT asm volatile("s_waitcnt vmcnt(3) \n s_barrier")
#define S_BARRIER asm volatile("s_barrier")
using Element = typename Kernel_traits::Element;
using ElementAccum = typename Kernel_traits::ElementAccum;
using index_t = typename Kernel_traits::index_t;
// Shared memory.
extern __shared__ char smem_[];
// The thread index.
const int tidx = threadIdx.x;
constexpr int kBlockM = Kernel_traits::kBlockM;
constexpr int kBlockN = Kernel_traits::kBlockN;
constexpr int kHeadDim = Kernel_traits::kHeadDim;
constexpr int kNWarps = Kernel_traits::kNWarps;
constexpr int kHeadDimV = Kernel_traits::kHeadDimV;
auto seed_offset = at::cuda::philox::unpack(params.philox_args);
flash::Dropout dropout(std::get<0>(seed_offset), std::get<1>(seed_offset), params.p_dropout_in_uint8_t,
bidb, bidh, tidx, params.h);
// Save seed and offset for backward, before any early exiting. Otherwise the 0-th thread block might
// exit early and no one saves the rng states.
if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) {
params.rng_state[0] = std::get<0>(seed_offset);
params.rng_state[1] = std::get<1>(seed_offset);
}
const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN);
int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN);
if (Is_causal || Is_local) {
n_block_max = std::min(n_block_max,
cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN));
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
// printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max);
// }
}
const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded
+ m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN;
Tensor mQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.q_ptr)
+ binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)),
make_shape(binfo.actual_seqlen_q, params.h, params.d),
make_stride(params.q_row_stride, params.q_head_stride, _1{}));
Tensor gQ = local_tile(mQ(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_coord(m_block, 0)); // (kBlockM, kHeadDim)
Tensor mK = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.k_ptr)
+ binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)),
make_shape(binfo.actual_seqlen_k, params.h_k, params.d),
make_stride(params.k_row_stride, params.k_head_stride, _1{}));
Tensor gK = local_tile(mK(_, bidh / params.h_h_k_ratio, _), Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_coord(0, 0)); // (kBlockN, kHeadDim, nblocksN)
Tensor mV = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.v_ptr)
+ binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)),
make_shape(binfo.actual_seqlen_k, params.h_k, params.d),
make_stride(params.v_row_stride, params.v_head_stride, _1{}));
Tensor gV = local_tile(mV(_, bidh / params.h_h_k_ratio, _), Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_coord(0, 0)); // (kBlockN, kHeadDim, nblocksN)
const auto gK_data = gK.data();
const auto gV_data = gV.data();
const index_t row_offset_k_token =
binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)
+ (bidh / params.h_h_k_ratio) * params.k_head_stride;
const index_t row_offset_v_token =
binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)
+ (bidh / params.h_h_k_ratio) * params.v_head_stride;
Tensor gKToken = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.k_ptr) + row_offset_k_token),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(_0{}, _1{}));
Tensor gVToken = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.v_ptr) + row_offset_v_token),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(_0{}, _1{}));
Tensor gP = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.p_ptr) + row_offset_p),
Shape<Int<kBlockM>, Int<kBlockN>>{},
make_stride(params.seqlen_k_rounded, _1{}));
Tensor sK = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
typename Kernel_traits::SmemLayoutK{});
Tensor sV = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutV{});
Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
Tensor sVtSplit = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransSplit{});
typename Kernel_traits::TiledMma16x64 tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(tidx);
typename Kernel_traits::TiledMma16x32 tiled_mma_for_gemm1;
auto thr_mma_for_gemm1 = tiled_mma_for_gemm1.get_thread_slice(tidx);
Tensor tGrQ = thr_mma.partition_fragment_A(gQ); // (MMA,MMA_M,MMA_K)
Tensor tSrK = thr_mma.partition_fragment_B(gK); // (MMA,MMA_N,MMA_K)
Tensor tOrVt = thr_mma_for_gemm1.partition_fragment_B(sVt); // (MMA, MMA_K,MMA_N)
Tensor tSgS = thr_mma.partition_C(gP);
auto gmem_tiled_copy_Q = make_tiled_copy_A(Copy_Atom<DefaultCopy, Element>{}, tiled_mma);
auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx);
Tensor tSgQ = gmem_thr_copy_Q.partition_S(gQ);
// auto smem_tiled_copy_K = make_tiled_copy_B(Copy_Atom<GFX928_DS_READ_B128, Element>{}, tiled_mma);
auto smem_tiled_copy_K = make_tiled_copy_B(Copy_Atom<DefaultCopy, Element>{}, tiled_mma);
auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
Tensor tSsK = smem_thr_copy_K.partition_S(sK);
auto smem_tiled_copy_V = make_tiled_copy_B(Copy_Atom<GFX928_DS_READ_DS_M32x16_B16, Element>{}, tiled_mma_for_gemm1);
auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
Tensor tOsVt8x64 = smem_thr_copy_V.partition_S(sVtSplit);
Tensor tOsVt = make_tensor(tOsVt8x64.data(), convert_layout_B_rowcol_<_16x128, kHeadDimV/32>(tOsVt8x64.layout()));
Tensor cQ = make_identity_tensor(make_shape(size<0>(gQ), size<1>(gQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
// Repeat the partitioning with identity layouts
Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
// Allocate predicate tensors for k
Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tSgQ)));
// Set predicates for k bounds
if constexpr (!Is_even_K) {
#pragma unroll
for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; }
}
flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_Q, tSgQ, tGrQ, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM);
int n_block = n_block_max - 1;
int num_blks = params.block_count[(bidb * params.h + bidh) * params.NUM_ROWS + m_block];
auto* blks_ptr = params.block_offset + ((bidb * params.h + bidh) * params.NUM_ROWS + m_block) * params.NNZ_S;
Tensor acc_o = partition_fragment_C(tiled_mma_for_gemm1, Shape<Int<kBlockM>, Int<kHeadDimV>>{}); // MMA, MMA_M, MMA_K
clear(acc_o);
flash::Softmax<size<1>(acc_o)> softmax;
const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
flash::Mask<Is_causal, Is_local, Has_alibi> mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope);
constexpr int n_masking_steps = (!Is_causal && !Is_local)
? 1
: ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1);
int num_cols = params.column_count[(bidb * params.h + bidh) * params.NUM_ROWS + m_block];
int num_cols_block = (num_cols + kBlockN - 1)/ kBlockN;
constexpr int kStages = Kernel_traits::kStages;
constexpr int k0_loops = size<2>(tSsK);
constexpr int k1_loops = size<2>(tOsVt);
if (num_blks <= 0 && num_cols_block <= 0) {
Tensor mO = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.o_ptr)
+ binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)),
make_shape(binfo.actual_seqlen_q, params.h, params.d),
make_stride(params.o_row_stride, params.o_head_stride, _1{}));
Tensor gO = local_tile(mO(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_coord(m_block, 0)); // (kBlockM, kHeadDim)
Tensor gLSE = get_lse_tile<ElementAccum, Params, kBlockM, Is_even_MN>(params, bidb, bidh, m_block, binfo);
typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
Tensor tOrO = make_tensor<Element>(shape(tOgO));
clear(tOrO);
// Construct identity layout for sO
Tensor cO = make_identity_tensor(make_shape(size<0>(gO), size<1>(gO))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
// Repeat the partitioning with identity layouts
Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
if (!Is_even_K) {
#pragma unroll
for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
}
// Clear_OOB_K must be false since we don't want to write zeros to gmem
flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
);
#pragma unroll
for (int m = 0; m < size<1>(tOgO); ++m) {
const int row = get<0>(tOcO(0, m, 0));
if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSE(row) = INFINITY; }
}
return;
}
if (num_blks > 0) {
int block_index = num_blks - 1;
int actual_block = blks_ptr[block_index];
gK.data() = gK_data + actual_block * int64_t(params.k_row_stride);
gV.data() = gV_data + actual_block * int64_t(params.v_row_stride);
#pragma unroll
for (int i = 0; i < kStages; ++i) {
lds_direct_copy<Is_even_K, Is_even_MN>(gK, sK, i, params.k_row_stride, params.d, binfo.actual_seqlen_k - actual_block);
}
for (int masking_step = 0; masking_step < n_masking_steps && block_index >= 0; ++masking_step, --block_index) {
Tensor acc_s_ori = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});
clear(acc_s_ori);
#pragma unroll
for (int i = 0; i < k0_loops - kStages; ++i) {
lds_direct_copy<Is_even_K, Is_even_MN>(gK, sK, kStages + i, params.k_row_stride, params.d, binfo.actual_seqlen_k - actual_block);
S_WAITCNT;
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, i);
S_BARRIER;
}
#pragma unroll
for (int i = 0; i < kStages; ++i) { // tail kStages
lds_direct_copy<Is_even_K, Is_even_MN, _16x128>(gV, sV, i, params.v_row_stride, params.d, binfo.actual_seqlen_k - actual_block);
S_WAITCNT;
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, k0_loops - kStages + i);
S_BARRIER;
}
Tensor acc_s = make_tensor(acc_s_ori.data(), convert_layout_acc(acc_s_ori.layout()));
if constexpr (Is_softcap){
apply_softcap(acc_s, params.softcap);
}
{
const int wave_id = (tidx / 64);
const int wave_id_to_row_block_id = wave_id;
const int warp_row_stride = 16;
const int row_idx_offset_in_block = (tidx % 16) + (wave_id_to_row_block_id * 16);
const int row_idx_offset_ = m_block * kBlockM + row_idx_offset_in_block;
mask.template apply_mask_continuous<Is_causal, Is_even_MN>(
acc_s, actual_block, row_idx_offset_, (kNWarps << 4)
);
}
// Sparse Warp Online Softmax: compute max_diff for dynamic PV skip
float max_diff = -INFINITY;
bool skip_pv = false;
bool is_first_block = (masking_step == 0);
if (is_first_block) {
max_diff = softmax.template softmax_rescale_o_with_diff</*Is_first=*/true, /*Check_inf=*/Is_causal || Is_local>(acc_s, acc_o, params.scale_softmax_log2);
} else {
max_diff = softmax.template softmax_rescale_o_with_diff</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local>(acc_s, acc_o, params.scale_softmax_log2);
// Check if we can skip P@V computation based on dynamic threshold
// Skip when max_diff + pv_threshold <= 0 (current block's contribution is negligible)
if (params.enable_dynamic_skip) {
// Reduce max_diff across the warp to get the maximum value
MaxOp<float> max_op;
float warp_max_diff = Allreduce<64>::run(max_diff, max_op);
// All threads in the warp must agree on skip decision
skip_pv = (warp_max_diff + params.pv_threshold <= 0.0f);
}
}
Tensor rP = flash::convert_type<Element>(acc_s);
{ // dropout
const int wave_id = (tidx >> 6);
const int wave_id_to_row_block_id = wave_id;
const int warp_row_stride = 16;
const int row_idx_offset_in_block = (tidx & (warp_row_stride - 1)) + (wave_id_to_row_block_id << 4);
const int row_idx_offset_ = m_block * kBlockM + row_idx_offset_in_block;
const int block_row_idx = row_idx_offset_;
const int block_col_idx = n_block * (kBlockN);
if constexpr (Return_softmax) {
Tensor rP_drop = make_fragment_like(rP);
cute::copy(rP, rP_drop);
dropout.template apply_dropout_continuous</*encode_dropout_in_sign_bit=*/true>(
rP_drop, block_row_idx, block_col_idx, kNWarps * 16
);
Tensor rP_drop_back = make_tensor(rP_drop.data(), convert_layout_acc_back(rP_drop.layout()));
cute::copy(rP_drop_back, tSgS);
tSgS.data() = tSgS.data() + (-kBlockN);
}
if constexpr (Is_dropout) {
dropout.apply_dropout_continuous(rP, block_row_idx, block_col_idx, kNWarps * 16);
}
}
// Sparse Warp Online Softmax: Skip P@V if contribution is negligible
if (!skip_pv) {
// Accumulate softmax sum (must be done AFTER confirming we're not skipping)
// This aligns with SpargeAttn's accumulate_d() pattern
if (is_first_block) {
softmax.template accumulate_softmax_sum</*Is_first=*/true>(acc_s);
} else {
softmax.template accumulate_softmax_sum</*Is_first=*/false>(acc_s);
}
lds_direct_copy<Is_even_K, Is_even_MN, _16x128>(gV, sV, kStages + 0, params.v_row_stride, params.d, binfo.actual_seqlen_k - actual_block);
S_WAITCNT;
flash::gemm_k_rs_ds_read_m32x16<0>(acc_o, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
#if 1
asm volatile("s_waitcnt vmcnt(2) \n s_barrier");
flash::gemm_k_rs_ds_read_m32x16<1>(acc_o, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
// S_BARRIER;
// k = 2
asm volatile("s_waitcnt vmcnt(1) \n s_barrier");
flash::gemm_k_rs_ds_read_m32x16<2>(acc_o, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
// S_BARRIER;
// k = 3
asm volatile("s_waitcnt vmcnt(0) \n s_barrier");
flash::gemm_k_rs_ds_read_m32x16<3>(acc_o, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
S_BARRIER;
#endif
// if (thread0())
// {
// // asm volatile("s_waitcnt vmcnt(0) \n s_barrier");
// printf("block_index = %d actual_block = %d \n \n", block_index, actual_block);
// for (int i = 0; i < 64; i++)
// {
// for (int j = 0; j < 128; j++) {
// printf(" %.2f ", float(sV(i, j)));
// }
// printf("\n");
// }
// }
}
// BUGFIX: Prefetch next block data regardless of skip decision
// This ensures pipeline is always filled for next iteration
if (block_index > 0) {
actual_block = blks_ptr[block_index - 1];
gK.data() = gK_data + actual_block * int64_t(params.k_row_stride);
gV.data() = gV_data + actual_block * int64_t(params.v_row_stride);
lds_direct_copy<Is_even_K>(gK, sK, 0, params.k_row_stride, params.d);
lds_direct_copy<Is_even_K>(gK, sK, 1, params.k_row_stride, params.d);
lds_direct_copy<Is_even_K>(gK, sK, 2, params.k_row_stride, params.d);
}
}
#if 1
for (; block_index >= 0; --block_index) {
Tensor acc_s_ori = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});
clear(acc_s_ori);
#pragma unroll
for (int i = 0; i < k0_loops - kStages; ++i) {
lds_direct_copy<Is_even_K, true>(gK, sK, kStages + i, params.k_row_stride, params.d);
S_WAITCNT;
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, i);
S_BARRIER;
}
#pragma unroll
for (int i = 0; i < kStages; ++i) { // tail kStages
lds_direct_copy<Is_even_K, true, _16x128>(gV, sV, i, params.v_row_stride, params.d);
S_WAITCNT;
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, k0_loops - kStages + i);
S_BARRIER;
}
Tensor acc_s = make_tensor(acc_s_ori.data(), convert_layout_acc(acc_s_ori.layout()));
if constexpr (Is_softcap){
apply_softcap(acc_s, params.softcap);
}
// {
// const int wave_id = (tidx / 64);
// const int wave_id_to_row_block_id = wave_id;
// const int warp_row_stride = 16;
// const int row_idx_offset_in_block = (tidx % 16) + (wave_id_to_row_block_id * 16);
// const int row_idx_offset_ = m_block * kBlockM + row_idx_offset_in_block;
// mask.template apply_mask_continuous<false>(
// acc_s, actual_block, row_idx_offset_, (kNWarps << 4)
// );
// }
// Sparse Warp Online Softmax: compute max_diff for dynamic PV skip
float max_diff = softmax.template softmax_rescale_o_with_diff</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local>(acc_s, acc_o, params.scale_softmax_log2);
bool skip_pv = false;
if (params.enable_dynamic_skip) {
// Reduce max_diff across the warp to get the maximum value
MaxOp<float> max_op;
float warp_max_diff = Allreduce<64>::run(max_diff, max_op);
// All threads in the warp must agree on skip decision
skip_pv = (warp_max_diff + params.pv_threshold <= 0.0f);
}
Tensor rP = flash::convert_type<Element>(acc_s);
{ // dropout
const int wave_id = (tidx >> 6);
const int wave_id_to_row_block_id = wave_id;
const int warp_row_stride = 16;
const int row_idx_offset_in_block = (tidx & (warp_row_stride - 1)) + (wave_id_to_row_block_id << 4);
const int row_idx_offset_ = m_block * kBlockM + row_idx_offset_in_block;
const int block_row_idx = row_idx_offset_;
const int block_col_idx = n_block * (kBlockN);
if constexpr (Return_softmax) {
Tensor rP_drop = make_fragment_like(rP);
cute::copy(rP, rP_drop);
dropout.template apply_dropout_continuous</*encode_dropout_in_sign_bit=*/true>(
rP_drop, block_row_idx, block_col_idx, kNWarps * 16
);
Tensor rP_drop_back = make_tensor(rP_drop.data(), convert_layout_acc_back(rP_drop.layout()));
cute::copy(rP_drop_back, tSgS);
tSgS.data() = tSgS.data() + (-kBlockN);
}
if constexpr (Is_dropout) {
dropout.apply_dropout_continuous(rP, block_row_idx, block_col_idx, kNWarps * 16);
}
}
// Sparse Warp Online Softmax: Skip P@V if contribution is negligible
if (!skip_pv) {
// Accumulate softmax sum (must be done AFTER confirming we're not skipping)
softmax.template accumulate_softmax_sum</*Is_first=*/false>(acc_s);
lds_direct_copy<Is_even_K, true, _16x128>(gV, sV, kStages + 0, params.v_row_stride, params.d);
S_WAITCNT;
flash::gemm_k_rs_ds_read_m32x16<0>(acc_o, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
#if 1
asm volatile("s_waitcnt vmcnt(2) \n s_barrier");
flash::gemm_k_rs_ds_read_m32x16<1>(acc_o, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
// S_BARRIER;
// k = 2
asm volatile("s_waitcnt vmcnt(1) \n s_barrier");
flash::gemm_k_rs_ds_read_m32x16<2>(acc_o, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
// S_BARRIER;
// k = 3
asm volatile("s_waitcnt vmcnt(0) \n s_barrier");
flash::gemm_k_rs_ds_read_m32x16<3>(acc_o, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
S_BARRIER;
#endif
}
// BUGFIX: Prefetch next block data regardless of skip decision
// This ensures pipeline is always filled for next iteration
#if 1
if (block_index > 0) {
const int actual_block = blks_ptr[block_index - 1];
gK.data() = gK_data + actual_block * int64_t(params.k_row_stride);
gV.data() = gV_data + actual_block * int64_t(params.v_row_stride);
lds_direct_copy<Is_even_K>(gK, sK, 0, params.k_row_stride, params.d);
lds_direct_copy<Is_even_K>(gK, sK, 1, params.k_row_stride, params.d);
lds_direct_copy<Is_even_K>(gK, sK, 2, params.k_row_stride, params.d);
}
#endif
}
#endif
}
#if 1
if (num_cols > 0) {
const int* cols_ptr = params.column_index + ((bidb * params.h + bidh) * params.NUM_ROWS + m_block) * params.NNZ_V;
int row = tidx % 16;
int lane = tidx % 64;
// int col = lane / 16;
int k_row_offset = row * 4 + tidx / 64 < num_cols ? cols_ptr[row * 4 + tidx / 64] : binfo.actual_seqlen_k;
int v_row_offset_0 = (tidx % 64) / 4 < num_cols ? cols_ptr[(tidx % 64) / 4] : binfo.actual_seqlen_k;
int v_row_offset_1 = (tidx % 64) / 4 + 16 < num_cols ? cols_ptr[(tidx % 64) / 4 + 16] : binfo.actual_seqlen_k;
int v_row_offset_2 = (tidx % 64) / 4 + 32 < num_cols ? cols_ptr[(tidx % 64) / 4 + 2 * 16] : binfo.actual_seqlen_k;
int v_row_offset_3 = (tidx % 64) / 4 + 48 < num_cols ? cols_ptr[(tidx % 64) / 4 + 3 * 16] : binfo.actual_seqlen_k;
#pragma unroll
for (int i = 0; i < kStages; ++i) {
lds_direct_copy_for_vertical_sparse<Is_even_K, Is_even_MN>(gKToken, sK, i, params.k_row_stride, k_row_offset, params.d, binfo.actual_seqlen_k);
}
// if (thread0())
// {
// printf(" num_cols_block = %d n_masking_steps = %d \n ", num_cols_block, n_masking_steps);
// }
// asm volatile("s_waitcnt vmcnt(0) \n s_barrier");
#if 1
for (int n = 0; n < num_cols_block; ++n) {
Tensor acc_s_ori = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});
clear(acc_s_ori);
#pragma unroll
for (int i = 0; i < k0_loops - kStages; ++i) {
lds_direct_copy_for_vertical_sparse<Is_even_K, Is_even_MN>(gKToken, sK, kStages + i, params.k_row_stride, k_row_offset, params.d, binfo.actual_seqlen_k);
S_WAITCNT;
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, i);
S_BARRIER;
}
{
lds_direct_copy_for_vertical_sparse<Is_even_K, Is_even_MN, _16x128>(gVToken, sV, 0, params.v_row_stride, v_row_offset_0, params.d, binfo.actual_seqlen_k);
S_WAITCNT;
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, k0_loops - kStages + 0);
S_BARRIER;
lds_direct_copy_for_vertical_sparse<Is_even_K, Is_even_MN, _16x128>(gVToken, sV, 1, params.v_row_stride, v_row_offset_1, params.d, binfo.actual_seqlen_k);
S_WAITCNT;
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, k0_loops - kStages + 1);
S_BARRIER;
lds_direct_copy_for_vertical_sparse<Is_even_K, Is_even_MN, _16x128>(gVToken, sV, 2, params.v_row_stride, v_row_offset_2, params.d, binfo.actual_seqlen_k);
S_WAITCNT;
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, k0_loops - kStages + 2);
S_BARRIER;
}
Tensor acc_s = make_tensor(acc_s_ori.data(), convert_layout_acc(acc_s_ori.layout()));
if constexpr (Is_softcap){
apply_softcap(acc_s, params.softcap);
}
if (n >= num_cols_block - n_masking_steps) {
Tensor tensor = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
const int lane_id = threadIdx.x & 63;
const int col_idx_offset = n * kBlockN + ((lane_id >> 4) << 2);
const int wave_id = (tidx >> 6);
const int wave_id_to_row_block_id = wave_id;
const int row_idx_offset_in_block = (tidx & (16 - 1)) + (wave_id_to_row_block_id << 4);
const int row_idx_offset = m_block * kBlockM + row_idx_offset_in_block;
const int warp_row_stride = kNWarps * 16;
const int max_seqlen_k = binfo.actual_seqlen_k;
const int max_seqlen_q = binfo.actual_seqlen_q;
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
const int row_idx = row_idx_offset + mi * warp_row_stride;
const int col_idx_limit = Is_causal ? std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q) : max_seqlen_k;
#pragma unroll
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
const int col_idx_base = col_idx_offset + (nj << 4);
for (int j = 0; j < size<1, 0>(tensor); ++j) {
const int col_idx = col_idx_base + j;
// if (block0() && threadIdx.x < 64) {
// printf(" threadIdx.x = %d num_cols = %d col_idx = %d cols_ptr[col_idx] = %d col_idx_limit = %d\n", threadIdx.x, num_cols, col_idx, cols_ptr[col_idx], col_idx_limit);
// }
if (col_idx >= num_cols || cols_ptr[col_idx] >= col_idx_limit) {
tensor(mi, make_coord(j, nj)) = -INFINITY;
}
}
}
}
}
// Sparse Warp Online Softmax: compute max_diff for dynamic PV skip
float max_diff = -INFINITY;
bool skip_pv = false;
bool is_first_block = (num_blks <= 0 && n == 0);
if (is_first_block) {
max_diff = softmax.template softmax_rescale_o_with_diff</*Is_first=*/true, /*Check_inf=*/Is_causal || Is_local>(acc_s, acc_o, params.scale_softmax_log2);
} else {
max_diff = softmax.template softmax_rescale_o_with_diff</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local>(acc_s, acc_o, params.scale_softmax_log2);
// Check if we can skip P@V computation based on dynamic threshold
if (params.enable_dynamic_skip) {
// Reduce max_diff across the warp to get the maximum value
MaxOp<float> max_op;
float warp_max_diff = Allreduce<64>::run(max_diff, max_op);
// All threads in the warp must agree on skip decision
skip_pv = (warp_max_diff + params.pv_threshold <= 0.0f);
}
}
Tensor rP = flash::convert_type<Element>(acc_s);
// if (block(1))
// {
// printf("block1 tidx %d rP %.2f %.2f %.2f %.2f \n", tidx, float(rP(0)), float(rP(1)), float(rP(2)), float(rP(3)));
// }
{ // dropout
const int wave_id = (tidx >> 6);
const int wave_id_to_row_block_id = wave_id;
const int warp_row_stride = 16;
const int row_idx_offset_in_block = (tidx & (warp_row_stride - 1)) + (wave_id_to_row_block_id << 4);
const int row_idx_offset_ = m_block * kBlockM + row_idx_offset_in_block;
const int block_row_idx = row_idx_offset_;
const int block_col_idx = n_block * (kBlockN);
if constexpr (Return_softmax) {
Tensor rP_drop = make_fragment_like(rP);
cute::copy(rP, rP_drop);
dropout.template apply_dropout_continuous</*encode_dropout_in_sign_bit=*/true>(
rP_drop, block_row_idx, block_col_idx, kNWarps * 16
);
Tensor rP_drop_back = make_tensor(rP_drop.data(), convert_layout_acc_back(rP_drop.layout()));
cute::copy(rP_drop_back, tSgS);
tSgS.data() = tSgS.data() + (-kBlockN);
}
if constexpr (Is_dropout) {
dropout.apply_dropout_continuous(rP, block_row_idx, block_col_idx, kNWarps * 16);
}
}
// Sparse Warp Online Softmax: Skip P@V if contribution is negligible
if (!skip_pv) {
// Accumulate softmax sum (must be done AFTER confirming we're not skipping)
// This aligns with SpargeAttn's accumulate_d() pattern
if (is_first_block) {
softmax.template accumulate_softmax_sum</*Is_first=*/true>(acc_s);
} else {
softmax.template accumulate_softmax_sum</*Is_first=*/false>(acc_s);
}
lds_direct_copy_for_vertical_sparse<Is_even_K, Is_even_MN, _16x128>(gVToken, sV, 3, params.v_row_stride, v_row_offset_3, params.d, binfo.actual_seqlen_k);
S_WAITCNT;
flash::gemm_k_rs_ds_read_m32x16<0>(acc_o, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
// if (block0() && threadIdx.x < 64)
// {
// printf("tidx = %d acc_s rp = %.2f %.2f %.2f binfo.actual_seqlen_k = %d v_row_offset_0 = %d v_row_offset_1 = %d v_row_offset_2 = %d v_row_offset_3 = %d\n", threadIdx.x, float(rP(0)), float(tOrVt(0)), acc_o(0), binfo.actual_seqlen_k, v_row_offset_0, v_row_offset_1, v_row_offset_2, v_row_offset_3);
// }
#if 1
asm volatile("s_waitcnt vmcnt(2) \n s_barrier");
flash::gemm_k_rs_ds_read_m32x16<1>(acc_o, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
// S_BARRIER;
// k = 2
asm volatile("s_waitcnt vmcnt(1) \n s_barrier");
flash::gemm_k_rs_ds_read_m32x16<2>(acc_o, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
// S_BARRIER;
// k = 3
asm volatile("s_waitcnt vmcnt(0) \n s_barrier");
flash::gemm_k_rs_ds_read_m32x16<3>(acc_o, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
S_BARRIER;
#endif
}
// BUGFIX: Prefetch next block data regardless of skip decision
// This ensures pipeline is always filled for next iteration
#if 1
if (n + 1 < num_cols_block) {
// cols_ptr += kBlockN;
int row = tidx % 16;
// int col = lane / 16;
k_row_offset = (n + 1) * kBlockN + row * 4 + tidx / 64 < num_cols ? cols_ptr[ (n + 1) * kBlockN + row * 4 + tidx / 64] : binfo.actual_seqlen_k;
v_row_offset_0 =(n + 1) * kBlockN + (tidx % 64) / 4 < num_cols ? cols_ptr[(n + 1) * kBlockN + (tidx % 64) / 4] : binfo.actual_seqlen_k;
v_row_offset_1 = (n + 1) * kBlockN + (tidx % 64) / 4 + 16 < num_cols ? cols_ptr[(n + 1) * kBlockN + (tidx % 64) / 4 + 16] : binfo.actual_seqlen_k;
v_row_offset_2 = (n + 1) * kBlockN + (tidx % 64) / 4 + 32 < num_cols ? cols_ptr[(n + 1) * kBlockN + (tidx % 64) / 4 + 2 * 16] : binfo.actual_seqlen_k;
v_row_offset_3 = (n + 1) * kBlockN + (tidx % 64) / 4 + 48 < num_cols ? cols_ptr[(n + 1) * kBlockN + (tidx % 64) / 4 + 3 * 16] : binfo.actual_seqlen_k;
lds_direct_copy_for_vertical_sparse<Is_even_K, Is_even_MN>(gKToken, sK, 0, params.k_row_stride, k_row_offset, params.d, binfo.actual_seqlen_k);
lds_direct_copy_for_vertical_sparse<Is_even_K, Is_even_MN>(gKToken, sK, 1, params.k_row_stride, k_row_offset, params.d, binfo.actual_seqlen_k);
lds_direct_copy_for_vertical_sparse<Is_even_K, Is_even_MN>(gKToken, sK, 2, params.k_row_stride, k_row_offset, params.d, binfo.actual_seqlen_k);
// lds_direct_copy_for_vertical_sparse<Is_even_K, Is_even_MN>(gKToken, sK, 3, params.k_row_stride, cols_ptr, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
}
#endif
}
#endif
}
// if (thread0())
// {
// printf(" acc_o %.2f %.2f \n", float(acc_o(0))), float(acc_o(1));
// }
#endif
Tensor lse = softmax.template normalize_softmax_lse<Is_dropout>(acc_o, params.scale_softmax, params.rp_dropout);
Tensor mO = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.o_ptr)
+ binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)),
make_shape(binfo.actual_seqlen_q, params.h, params.d),
make_stride(params.o_row_stride, params.o_head_stride, _1{}));
Tensor gO = local_tile(mO(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_coord(m_block, 0)); // (kBlockM, kHeadDim)
Tensor gLSE = get_lse_tile<ElementAccum, Params, kBlockM, Is_even_MN>(params, bidb, bidh, m_block, binfo);
Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDimV>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K)
if (get<1>(taccOcO(0)) == 0) {
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) {
const int row = get<0>(taccOcO(0, mi, 0));
if (row < binfo.actual_seqlen_q - m_block * kBlockM) { gLSE(row) = lse(mi); }
}
}
int row, col;
const int warpId = tidx / 64;
const int laneId = tidx % 64;
#pragma unroll
for (int mi = 0; mi < size<1>(acc_o); ++mi) {
row = (mi*kNWarps + warpId) * 16 + (laneId % 16);
if (Is_even_MN || row < binfo.actual_seqlen_q - m_block * kBlockM) {
#pragma unroll
for (int ni = 0; ni < size<2>(acc_o); ++ni) {
col = (laneId / 16) + ni * 32;
#pragma unroll
for (int ei = 0; ei < size<0>(acc_o); ++ei) {
// wangaq debug
// if(thread(0, 0)) {
// printf("mi:%d ni:%d ei:%d row:%d col:%d acc_o:%.4f\n",
// mi, ni, ei, row, col, acc_o(ei, mi, ni));
// }
if (Is_even_K || col < params.d_value) {
gO(row, col) = flash::convert_type<Element>(acc_o(ei, mi, ni));
}
// else
// gO(row, col) = Element(0.0);
col += 4;
}
}
}
}
}
#if 1
template<typename Kernel_traits, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
inline __device__ void sparse_attn_1rowblock_sla(const Params &params, const int bidb, const int bidh, const int m_block) {
using Element = typename Kernel_traits::Element;
using ElementAccum = typename Kernel_traits::ElementAccum;
using index_t = typename Kernel_traits::index_t;
// Shared memory.
extern __shared__ char smem_[];
// The thread index.
const int tidx = threadIdx.x;
constexpr int kBlockM = Kernel_traits::kBlockM;
constexpr int kBlockN = Kernel_traits::kBlockN;
constexpr int kHeadDim = Kernel_traits::kHeadDim;
constexpr int kNWarps = Kernel_traits::kNWarps;
constexpr int kHeadDimV = Kernel_traits::kHeadDimV;
const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
const int n_block_min = 0;
int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN);
const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded
+ m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN;
Tensor mQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.q_ptr)
+ binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)),
make_shape(binfo.actual_seqlen_q, params.h, params.d),
make_stride(params.q_row_stride, params.q_head_stride, _1{}));
Tensor gQ = local_tile(mQ(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_coord(m_block, 0)); // (kBlockM, kHeadDim)
Tensor mK = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.k_ptr)
+ binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)),
make_shape(binfo.actual_seqlen_k, params.h_k, params.d),
make_stride(params.k_row_stride, params.k_head_stride, _1{}));
Tensor gK = local_tile(mK(_, bidh / params.h_h_k_ratio, _), Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_coord(0, 0)); // (kBlockN, kHeadDim, nblocksN)
Tensor mV = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.v_ptr)
+ binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)),
make_shape(binfo.actual_seqlen_k, params.h_k, params.d),
make_stride(params.v_row_stride, params.v_head_stride, _1{}));
Tensor gV = local_tile(mV(_, bidh / params.h_h_k_ratio, _), Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_coord(0, 0)); // (kBlockN, kHeadDim, nblocksN)
const auto gK_data = gK.data();
const auto gV_data = gV.data();
Tensor gP = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.p_ptr) + row_offset_p),
Shape<Int<kBlockM>, Int<kBlockN>>{}, make_stride(params.seqlen_k_rounded, _1{}));
Tensor sK = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)), typename Kernel_traits::SmemLayoutK{});
Tensor sV = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutV{});
Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
Tensor sVtSplit = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransSplit{});
typename Kernel_traits::TiledMma16x64 tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(tidx);
typename Kernel_traits::TiledMma16x32 tiled_mma_for_gemm1;
auto thr_mma_for_gemm1 = tiled_mma_for_gemm1.get_thread_slice(tidx);
Tensor tGrQ = thr_mma.partition_fragment_A(gQ); // (MMA,MMA_M,MMA_K)
Tensor tSrK = thr_mma.partition_fragment_B(gK); // (MMA,MMA_N,MMA_K)
Tensor tOrVt = thr_mma_for_gemm1.partition_fragment_B(sVt); // (MMA, MMA_K,MMA_N)
Tensor tSgS = thr_mma.partition_C(gP);
auto gmem_tiled_copy_Q = make_tiled_copy_A(Copy_Atom<DefaultCopy, Element>{}, tiled_mma);
auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx);
Tensor tSgQ = gmem_thr_copy_Q.partition_S(gQ);
// auto smem_tiled_copy_K = make_tiled_copy_B(Copy_Atom<GFX928_DS_READ_B128, Element>{}, tiled_mma);
auto smem_tiled_copy_K = make_tiled_copy_B(Copy_Atom<DefaultCopy, Element>{}, tiled_mma);
auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
Tensor tSsK = smem_thr_copy_K.partition_S(sK);
auto smem_tiled_copy_V = make_tiled_copy_B(Copy_Atom<GFX928_DS_READ_DS_M32x16_B16, Element>{}, tiled_mma_for_gemm1);
auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
Tensor tOsVt8x64 = smem_thr_copy_V.partition_S(sVtSplit);
Tensor tOsVt = make_tensor(tOsVt8x64.data(), convert_layout_B_rowcol_<_16x128, kHeadDimV/32>(tOsVt8x64.layout()));
Tensor cQ = make_identity_tensor(make_shape(size<0>(gQ), size<1>(gQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
// Repeat the partitioning with identity layouts
Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
// Allocate predicate tensors for k
Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tSgQ)));
// Set predicates for k bounds
if constexpr (!Is_even_K) {
#pragma unroll
for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; }
}
flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_Q, tSgQ, tGrQ, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM);
int n_block = n_block_max - 1;
int num_blks = params.block_count[(bidb * params.h + bidh) * params.NUM_ROWS + m_block];
auto* blks_ptr = params.block_offset + ((bidb * params.h + bidh) * params.NUM_ROWS + m_block) * params.NNZ_S;
Tensor acc_o = partition_fragment_C(tiled_mma_for_gemm1, Shape<Int<kBlockM>, Int<kHeadDimV>>{}); // MMA, MMA_M, MMA_K
clear(acc_o);
flash::Softmax<size<1>(acc_o)> softmax;
flash::Mask<false, false, false> mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, 0.0);
constexpr int n_masking_steps = 1;
constexpr int kStages = Kernel_traits::kStages;
constexpr int k0_loops = size<2>(tSsK);
constexpr int k1_loops = size<2>(tOsVt);
if (num_blks <= 0) {
Tensor mO = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.o_ptr)
+ binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)),
make_shape(binfo.actual_seqlen_q, params.h, params.d),
make_stride(params.o_row_stride, params.o_head_stride, _1{}));
Tensor gO = local_tile(mO(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_coord(m_block, 0)); // (kBlockM, kHeadDim)
Tensor gLSE = get_lse_tile<ElementAccum, Params, kBlockM, Is_even_MN>(params, bidb, bidh, m_block, binfo);
typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
Tensor tOrO = make_tensor<Element>(shape(tOgO));
clear(tOrO);
// Construct identity layout for sO
Tensor cO = make_identity_tensor(make_shape(size<0>(gO), size<1>(gO))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
// Repeat the partitioning with identity layouts
Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
if (!Is_even_K) {
#pragma unroll
for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
}
// Clear_OOB_K must be false since we don't want to write zeros to gmem
flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
);
#pragma unroll
for (int m = 0; m < size<1>(tOgO); ++m) {
const int row = get<0>(tOcO(0, m, 0));
if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSE(row) = INFINITY; }
}
return;
}
int block_index = num_blks - 1;
int actual_block = blks_ptr[block_index];
gK.data() = gK_data + actual_block * int64_t(params.k_row_stride);
gV.data() = gV_data + actual_block * int64_t(params.v_row_stride);
#pragma unroll
for (int i = 0; i < kStages; ++i) {
lds_direct_copy<Is_even_K, Is_even_MN>(gK, sK, i, params.k_row_stride, params.d, binfo.actual_seqlen_k - actual_block);
}
for (int masking_step = 0; masking_step < n_masking_steps && block_index >= 0; ++masking_step, --block_index) {
Tensor acc_s_ori = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});
clear(acc_s_ori);
#pragma unroll
for (int i = 0; i < k0_loops - kStages; ++i) {
lds_direct_copy<Is_even_K, Is_even_MN>(gK, sK, kStages + i, params.k_row_stride, params.d, binfo.actual_seqlen_k - actual_block);
s_waitcnt<3>();
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, i);
s_barrier();
}
#pragma unroll
for (int i = 0; i < kStages; ++i) { // tail kStages
lds_direct_copy<Is_even_K, Is_even_MN, _16x128>(gV, sV, i, params.v_row_stride, params.d, binfo.actual_seqlen_k - actual_block);
s_waitcnt<3>();
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, k0_loops - kStages + i);
s_barrier();
}
Tensor acc_s = make_tensor(acc_s_ori.data(), convert_layout_acc(acc_s_ori.layout()));
{
const int wave_id = (tidx / 64);
const int wave_id_to_row_block_id = wave_id;
const int warp_row_stride = 16;
const int row_idx_offset_in_block = (tidx % 16) + (wave_id_to_row_block_id * 16);
const int row_idx_offset_ = m_block * kBlockM + row_idx_offset_in_block;
mask.template apply_mask_continuous<false, Is_even_MN>(
acc_s, actual_block, row_idx_offset_, (kNWarps << 4)
);
}
masking_step == 0
? softmax.template softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/false>(acc_s, acc_o, params.scale_softmax_log2)
: softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/false>(acc_s, acc_o, params.scale_softmax_log2);
Tensor rP = flash::convert_type<Element>(acc_s);
lds_direct_copy<Is_even_K, Is_even_MN, _16x128>(gV, sV, kStages + 0, params.v_row_stride, params.d, binfo.actual_seqlen_k - actual_block);
s_waitcnt<3>();
flash::gemm_k_rs_ds_read_m32x16<0>(acc_o, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
s_waitcnt<2>();
flash::gemm_k_rs_ds_read_m32x16<1>(acc_o, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
// S_BARRIER;
// k = 2
s_waitcnt<1>();
flash::gemm_k_rs_ds_read_m32x16<2>(acc_o, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
// S_BARRIER;
// k = 3
s_waitcnt<0>();
flash::gemm_k_rs_ds_read_m32x16<3>(acc_o, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
s_barrier();
// if (thread0())
// {
// // asm volatile("s_waitcnt vmcnt(0) \n s_barrier");
// printf("block_index = %d actual_block = %d \n \n", block_index, actual_block);
// for (int i = 0; i < 64; i++)
// {
// for (int j = 0; j < 128; j++) {
// printf(" %.2f ", float(sV(i, j)));
// }
// printf("\n");
// }
// }
if (block_index > 0) {
actual_block = blks_ptr[block_index - 1];
gK.data() = gK_data + actual_block * int64_t(params.k_row_stride);
gV.data() = gV_data + actual_block * int64_t(params.v_row_stride);
lds_direct_copy<Is_even_K>(gK, sK, 0, params.k_row_stride, params.d);
lds_direct_copy<Is_even_K>(gK, sK, 1, params.k_row_stride, params.d);
lds_direct_copy<Is_even_K>(gK, sK, 2, params.k_row_stride, params.d);
}
}
#if 1
for (; block_index >= 0; --block_index) {
Tensor acc_s_ori = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});
clear(acc_s_ori);
#pragma unroll
for (int i = 0; i < k0_loops - kStages; ++i) {
lds_direct_copy<Is_even_K, true>(gK, sK, kStages + i, params.k_row_stride, params.d);
s_waitcnt<3>();
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, i);
s_barrier();
}
#pragma unroll
for (int i = 0; i < kStages; ++i) { // tail kStages
lds_direct_copy<Is_even_K, true, _16x128>(gV, sV, i, params.v_row_stride, params.d);
s_waitcnt<3>();
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, k0_loops - kStages + i);
s_barrier();
}
Tensor acc_s = make_tensor(acc_s_ori.data(), convert_layout_acc(acc_s_ori.layout()));
softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/false>(acc_s, acc_o, params.scale_softmax_log2);
Tensor rP = flash::convert_type<Element>(acc_s);
lds_direct_copy<Is_even_K, true, _16x128>(gV, sV, kStages + 0, params.v_row_stride, params.d);
s_waitcnt<3>();
flash::gemm_k_rs_ds_read_m32x16<0>(acc_o, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
#if 1
s_waitcnt<2>();
flash::gemm_k_rs_ds_read_m32x16<1>(acc_o, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
// S_BARRIER;
// k = 2
s_waitcnt<1>();
flash::gemm_k_rs_ds_read_m32x16<2>(acc_o, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
// S_BARRIER;
// k = 3
s_waitcnt<0>();
flash::gemm_k_rs_ds_read_m32x16<3>(acc_o, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
s_barrier();
if (block_index > 0) {
const int actual_block = blks_ptr[block_index - 1];
gK.data() = gK_data + actual_block * int64_t(params.k_row_stride);
gV.data() = gV_data + actual_block * int64_t(params.v_row_stride);
lds_direct_copy<Is_even_K>(gK, sK, 0, params.k_row_stride, params.d);
lds_direct_copy<Is_even_K>(gK, sK, 1, params.k_row_stride, params.d);
lds_direct_copy<Is_even_K>(gK, sK, 2, params.k_row_stride, params.d);
}
#endif
}
#endif
Tensor lse = softmax.template normalize_softmax_lse<false>(acc_o, params.scale_softmax, params.rp_dropout);
Tensor mO = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.o_ptr)
+ binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)),
make_shape(binfo.actual_seqlen_q, params.h, params.d),
make_stride(params.o_row_stride, params.o_head_stride, _1{}));
Tensor gO = local_tile(mO(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_coord(m_block, 0)); // (kBlockM, kHeadDim)
Tensor gLSE = get_lse_tile<ElementAccum, Params, kBlockM, Is_even_MN>(params, bidb, bidh, m_block, binfo);
Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDimV>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K)
if (get<1>(taccOcO(0)) == 0) {
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) {
const int row = get<0>(taccOcO(0, mi, 0));
if (row < binfo.actual_seqlen_q - m_block * kBlockM) { gLSE(row) = lse(mi); }
}
}
int row, col;
const int warpId = tidx / 64;
const int laneId = tidx % 64;
#pragma unroll
for (int mi = 0; mi < size<1>(acc_o); ++mi) {
row = (mi*kNWarps + warpId) * 16 + (laneId % 16);
if (Is_even_MN || row < binfo.actual_seqlen_q - m_block * kBlockM) {
#pragma unroll
for (int ni = 0; ni < size<2>(acc_o); ++ni) {
col = (laneId / 16) + ni * 32;
#pragma unroll
for (int ei = 0; ei < size<0>(acc_o); ++ei) {
// wangaq debug
// if(thread(0, 0)) {
// printf("mi:%d ni:%d ei:%d row:%d col:%d acc_o:%.4f\n",
// mi, ni, ei, row, col, acc_o(ei, mi, ni));
// }
if (Is_even_K || col < params.d_value) {
gO(row, col) = flash::convert_type<Element>(acc_o(ei, mi, ni));
}
// else
// gO(row, col) = Element(0.0);
col += 4;
}
}
}
}
}
#endif
template<typename Kernel_traits, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
inline __device__ void sparse_attn_1rowblock_sla_dim64(const Params &params, const int bidb, const int bidh, const int m_block) {
using Element = typename Kernel_traits::Element;
using ElementAccum = typename Kernel_traits::ElementAccum;
using index_t = typename Kernel_traits::index_t;
// Shared memory.
extern __shared__ char smem_[];
// The thread index.
const int tidx = threadIdx.x;
constexpr int kBlockM = Kernel_traits::kBlockM;
constexpr int kBlockN = Kernel_traits::kBlockN;
constexpr int kHeadDim = Kernel_traits::kHeadDim;
constexpr int kHeadDimV = Kernel_traits::kHeadDimV;
constexpr int kNWarps = Kernel_traits::kNWarps;
// constexpr int kStages = Kernel_traits::kStages;
const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
const int n_block_min = 0;
int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN);
const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded
+ m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN;
Tensor mQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.q_ptr)
+ binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)),
make_shape(binfo.actual_seqlen_q, params.h, params.d),
make_stride(params.q_row_stride, params.q_head_stride, _1{}));
Tensor gQ = local_tile(mQ(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_coord(m_block, 0)); // (kBlockM, kHeadDim)
Tensor mK = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.k_ptr)
+ binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)),
make_shape(binfo.actual_seqlen_k, params.h_k, params.d),
make_stride(params.k_row_stride, params.k_head_stride, _1{}));
Tensor gK = local_tile(mK(_, bidh / params.h_h_k_ratio, _), Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_coord(0, 0)); // (kBlockN, kHeadDim, nblocksN)
Tensor mV = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.v_ptr)
+ binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)),
make_shape(binfo.actual_seqlen_k, params.h_k, params.d),
make_stride(params.v_row_stride, params.v_head_stride, _1{}));
Tensor gV = local_tile(mV(_, bidh / params.h_h_k_ratio, _), Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_coord(0, 0)); // (kBlockN, kHeadDim, nblocksN)
const auto gK_data = gK.data();
const auto gV_data = gV.data();
Tensor gP = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.p_ptr) + row_offset_p),
Shape<Int<kBlockM>, Int<kBlockN>>{},
make_stride(params.seqlen_k_rounded, _1{}));
Tensor sK = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)), typename Kernel_traits::SmemLayoutK{});
Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutV{});
Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
Tensor sVtSplit = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransSplit{});
typename Kernel_traits::TiledMma16x64 tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(tidx);
typename Kernel_traits::TiledMma16x32 tiled_mma_for_gemm1;
auto thr_mma_for_gemm1 = tiled_mma_for_gemm1.get_thread_slice(tidx);
Tensor tGrQ = thr_mma.partition_fragment_A(gQ); // (MMA,MMA_M,MMA_K)
Tensor tSrK = thr_mma.partition_fragment_B(gK); // (MMA,MMA_N,MMA_K)
Tensor tOrVt = thr_mma_for_gemm1.partition_fragment_B(sVt); // (MMA, MMA_K,MMA_N)
Tensor tSgS = thr_mma.partition_C(gP);
//
// Copy Atom retiling
//
auto gmem_tiled_copy_Q = make_tiled_copy_A(Copy_Atom<DefaultCopy, Element>{}, tiled_mma);
auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx);
Tensor tSgQ = gmem_thr_copy_Q.partition_S(gQ);
// auto smem_tiled_copy_K = make_tiled_copy_B(Copy_Atom<GFX928_DS_READ_B128, Element>{}, tiled_mma);
auto smem_tiled_copy_K = make_tiled_copy_B(Copy_Atom<DefaultCopy, Element>{}, tiled_mma);
auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
Tensor tSsK = smem_thr_copy_K.partition_S(sK);
auto smem_tiled_copy_V = make_tiled_copy_B(Copy_Atom<GFX928_DS_READ_DS_M32x16_B16, Element>{}, tiled_mma_for_gemm1);
auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
Tensor tOsVt8x64 = smem_thr_copy_V.partition_S(sVtSplit);
Tensor tOsVt = make_tensor(tOsVt8x64.data(), convert_layout_B_rowcol_<_16x64_64, kHeadDimV/32>(tOsVt8x64.layout()));
//
// PREDICATES
//
// Construct identity layout for sQ and sK
Tensor cQ = make_identity_tensor(make_shape(size<0>(gQ), size<1>(gQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
// Repeat the partitioning with identity layouts
Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
// Allocate predicate tensors for k
Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tSgQ)));
// Set predicates for k bounds
if constexpr (!Is_even_K) {
#pragma unroll
for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; }
}
// Prologue
flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_Q, tSgQ, tGrQ, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM);
int num_blks = params.block_count[(bidb * params.h + bidh) * params.NUM_ROWS + m_block];
auto* blks_ptr = params.block_offset + ((bidb * params.h + bidh) * params.NUM_ROWS + m_block) * params.NNZ_S;
Tensor acc_o = partition_fragment_C(tiled_mma_for_gemm1, Shape<Int<kBlockM>, Int<kHeadDimV>>{}); // MMA, MMA_M, MMA_K
clear(acc_o);
flash::Softmax<size<1>(acc_o)> softmax;
flash::Mask<false, false, false> mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, 0.0);
constexpr int n_masking_steps = 1;
constexpr int k0_loops = size<2>(tSsK);
constexpr int k1_loops = size<2>(tOsVt);
static_assert(k0_loops == 2 && k1_loops == 4);
if (num_blks <= 0) {
Tensor mO = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.o_ptr)
+ binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)),
make_shape(binfo.actual_seqlen_q, params.h, params.d_value),
make_stride(params.o_row_stride, params.o_head_stride, _1{}));
Tensor gO = local_tile(mO(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDimV>>{},
make_coord(m_block, 0)); // (kBlockM, kHeadDim)
Tensor gLSE = get_lse_tile<ElementAccum, Params, kBlockM, Is_even_MN>(params, bidb, bidh, m_block, binfo);
typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
Tensor tOrO = make_tensor<Element>(shape(tOgO));
clear(tOrO);
// Construct identity layout for sO
Tensor cO = make_identity_tensor(make_shape(size<0>(gO), size<1>(gO))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
// Repeat the partitioning with identity layouts
Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
if (!Is_even_K) {
#pragma unroll
for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d_value; }
}
// Clear_OOB_K must be false since we don't want to write zeros to gmem
flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
);
#pragma unroll
for (int m = 0; m < size<1>(tOgO); ++m) {
const int row = get<0>(tOcO(0, m, 0));
if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSE(row) = INFINITY; }
}
return;
}
int block_index = num_blks - 1;
int actual_block = blks_ptr[block_index];
gK.data() = gK_data + actual_block * int64_t(params.k_row_stride);
gV.data() = gV_data + actual_block * int64_t(params.v_row_stride);
#pragma unroll
for (int i = 0; i < k0_loops; ++i) {
lds_direct_copy<Is_even_K, Is_even_MN>(gK, sK, i, params.k_row_stride, params.d, binfo.actual_seqlen_k - actual_block);
}
#pragma unroll
for (int i = 0; i < k1_loops; ++i) {
lds_direct_copy<Is_even_K, Is_even_MN, _16x64_64>(gV, sV, i, params.v_row_stride, params.d, binfo.actual_seqlen_k - actual_block);
}
#if 1
#pragma unroll
for (int masking_step = 0; masking_step < n_masking_steps && block_index >= 0; ++masking_step, --block_index) {
Tensor acc_s_ori = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});
clear(acc_s_ori);
s_waitcnt<5>();
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 0);
s_waitcnt<4>();
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 1);
s_barrier();
Tensor acc_s = make_tensor(acc_s_ori.data(), convert_layout_acc(acc_s_ori.layout()));
// wangaq debug
// __syncthreads();
// if (blockIdx.x == 0) {
// float * tmp = reinterpret_cast<float*>(acc_s.data());
// printf("acc_s tid:%3d n_block:%d %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f "
// "%10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f "
// "%10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f "
// "%10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f\n", tidx, n_block,
// tmp[0], tmp[1], tmp[2], tmp[3],
// tmp[4], tmp[5], tmp[6], tmp[7],
// tmp[8], tmp[9], tmp[10], tmp[11],
// tmp[12], tmp[13], tmp[14], tmp[15],
// tmp[16], tmp[17], tmp[18], tmp[19],
// tmp[20], tmp[21], tmp[22], tmp[23],
// tmp[24], tmp[25], tmp[26], tmp[27],
// tmp[28], tmp[29], tmp[30], tmp[31]
// );
// }
#if 1
{
const int wave_id = (tidx >> 6);
const int wave_id_to_row_block_id = wave_id;
const int warp_row_stride = 16;
const int row_idx_offset_in_block = (tidx & (warp_row_stride - 1)) + (wave_id_to_row_block_id << 4);
const int row_idx_offset_ = m_block * kBlockM + row_idx_offset_in_block;
mask.template apply_mask_continuous<false, Is_even_MN>(
acc_s, actual_block, row_idx_offset_, (kNWarps << 4)
);
}
// wangaq debug
// __syncthreads();
// if (thread0() && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
// Element * tmp = reinterpret_cast<Element*>(sV.data().get());
// int col = 32;
// for (int i = 0; i < size(sV)/col; ++i) {
// printf("V:%d nblock:%d ", i, n_block);
// for (int j = 0; j < col; ++j) {
// printf("%.4f ", float(tmp[i*col+j]));
// }
// printf("\n");
// }
// }
softmax.template softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/false>(acc_s, acc_o, params.scale_softmax_log2);
// wangaq debug
// __syncthreads();
// if (blockIdx.x == 0) {
// float * tmp = reinterpret_cast<float*>(acc_s.data());
// printf("exp_s tid:%d n_block:%d row_max:%10.4f %10.4f row_sum:%10.4f %10.4f | %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f "
// "%10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f "
// "%10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f "
// "%10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f\n", tidx, n_block, softmax.row_max(0), softmax.row_max(1), softmax.row_sum(0), softmax.row_sum(1),
// tmp[0], tmp[1], tmp[2], tmp[3],
// tmp[4], tmp[5], tmp[6], tmp[7],
// tmp[8], tmp[9], tmp[10], tmp[11],
// tmp[12], tmp[13], tmp[14], tmp[15],
// tmp[16], tmp[17], tmp[18], tmp[19],
// tmp[20], tmp[21], tmp[22], tmp[23],
// tmp[24], tmp[25], tmp[26], tmp[27],
// tmp[28], tmp[29], tmp[30], tmp[31]
// );
// }
// Convert acc_s from fp32 to fp16/bf16
Tensor rP = flash::convert_type<Element>(acc_s);
s_waitcnt<3>();
flash::gemm_k_rs_ds_read_m32x16<0>(acc_o, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
s_waitcnt<2>();
flash::gemm_k_rs_ds_read_m32x16<1>(acc_o, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
s_waitcnt<1>();
flash::gemm_k_rs_ds_read_m32x16<2>(acc_o, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
s_waitcnt<0>();
flash::gemm_k_rs_ds_read_m32x16<3>(acc_o, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
s_barrier();
if (block_index > 0) {
actual_block = blks_ptr[block_index - 1];
gK.data() = gK_data + actual_block * int64_t(params.k_row_stride);
gV.data() = gV_data + actual_block * int64_t(params.v_row_stride);
lds_direct_copy<Is_even_K>(gK, sK, 0, params.k_row_stride, params.d);
lds_direct_copy<Is_even_K>(gK, sK, 1, params.k_row_stride, params.d);
lds_direct_copy<Is_even_K, /*Is_even_MN=*/true, _16x64_64>(gV, sV, 0, params.v_row_stride, params.d);
lds_direct_copy<Is_even_K, /*Is_even_MN=*/true, _16x64_64>(gV, sV, 1, params.v_row_stride, params.d);
lds_direct_copy<Is_even_K, /*Is_even_MN=*/true, _16x64_64>(gV, sV, 2, params.v_row_stride, params.d);
lds_direct_copy<Is_even_K, /*Is_even_MN=*/true, _16x64_64>(gV, sV, 3, params.v_row_stride, params.d);
}
// wangaq debug
// __syncthreads();
// if (blockIdx.x == 0) {
// float * tmp = reinterpret_cast<float*>(acc_o.data());
// printf("acc_o tid:%d n_block:%d %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f "
// "%10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f "
// "%10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f "
// "%10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f\n", tidx, n_block,
// tmp[0], tmp[1], tmp[2], tmp[3],
// tmp[4], tmp[5], tmp[6], tmp[7],
// tmp[8], tmp[9], tmp[10], tmp[11],
// tmp[12], tmp[13], tmp[14], tmp[15],
// tmp[16], tmp[17], tmp[18], tmp[19],
// tmp[20], tmp[21], tmp[22], tmp[23],
// tmp[24], tmp[25], tmp[26], tmp[27],
// tmp[28], tmp[29], tmp[30], tmp[31]
// );
// }
#endif
}
#pragma unroll
for (; block_index >= 0; --block_index) {
Tensor acc_s_ori = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});
clear(acc_s_ori);
s_waitcnt<5>();
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 0);
s_waitcnt<4>();
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 1);
s_barrier();
// __builtin_amdgcn_sched_barrier(1);
Tensor acc_s = make_tensor(acc_s_ori.data(), convert_layout_acc(acc_s_ori.layout()));
// wangaq debug
// __syncthreads();
// if (blockIdx.x == 0) {
// float * tmp = reinterpret_cast<float*>(acc_s.data());
// printf("acc_s tid:%d n_block:%d %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f "
// "%10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f "
// "%10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f "
// "%10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f\n", tidx, n_block,
// tmp[0], tmp[1], tmp[2], tmp[3],
// tmp[4], tmp[5], tmp[6], tmp[7],
// tmp[8], tmp[9], tmp[10], tmp[11],
// tmp[12], tmp[13], tmp[14], tmp[15],
// tmp[16], tmp[17], tmp[18], tmp[19],
// tmp[20], tmp[21], tmp[22], tmp[23],
// tmp[24], tmp[25], tmp[26], tmp[27],
// tmp[28], tmp[29], tmp[30], tmp[31]
// );
// }
softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/false>(acc_s, acc_o, params.scale_softmax_log2);
// wangaq debug
// __syncthreads();
// if (blockIdx.x == 0) {
// float * tmp = reinterpret_cast<float*>(acc_s.data());
// printf("exp_s tid:%d n_block:%d row_max:%10.4f %10.4f row_sum:%10.4f %10.4f | %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f "
// "%10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f "
// "%10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f "
// "%10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f\n", tidx, n_block, softmax.row_max(0), softmax.row_max(1), softmax.row_sum(0), softmax.row_sum(1),
// tmp[0], tmp[1], tmp[2], tmp[3],
// tmp[4], tmp[5], tmp[6], tmp[7],
// tmp[8], tmp[9], tmp[10], tmp[11],
// tmp[12], tmp[13], tmp[14], tmp[15],
// tmp[16], tmp[17], tmp[18], tmp[19],
// tmp[20], tmp[21], tmp[22], tmp[23],
// tmp[24], tmp[25], tmp[26], tmp[27],
// tmp[28], tmp[29], tmp[30], tmp[31]
// );
// }
Tensor rP = flash::convert_type<Element>(acc_s);
s_waitcnt<3>();
flash::gemm_k_rs_ds_read_m32x16<0>(acc_o, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
s_waitcnt<2>();
flash::gemm_k_rs_ds_read_m32x16<1>(acc_o, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
s_waitcnt<1>();
flash::gemm_k_rs_ds_read_m32x16<2>(acc_o, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
s_waitcnt<0>();
flash::gemm_k_rs_ds_read_m32x16<3>(acc_o, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
s_barrier();
if (block_index > 0) {
actual_block = blks_ptr[block_index - 1];
gK.data() = gK_data + actual_block * int64_t(params.k_row_stride);
gV.data() = gV_data + actual_block * int64_t(params.v_row_stride);
lds_direct_copy<Is_even_K>(gK, sK, 0, params.k_row_stride, params.d);
lds_direct_copy<Is_even_K>(gK, sK, 1, params.k_row_stride, params.d);
lds_direct_copy<Is_even_K, /*Is_even_MN=*/true, _16x64_64>(gV, sV, 0, params.v_row_stride, params.d);
lds_direct_copy<Is_even_K, /*Is_even_MN=*/true, _16x64_64>(gV, sV, 1, params.v_row_stride, params.d);
lds_direct_copy<Is_even_K, /*Is_even_MN=*/true, _16x64_64>(gV, sV, 2, params.v_row_stride, params.d);
lds_direct_copy<Is_even_K, /*Is_even_MN=*/true, _16x64_64>(gV, sV, 3, params.v_row_stride, params.d);
}
// wangaq debug
// __syncthreads();
// if (blockIdx.x == 0) {
// float * tmp = reinterpret_cast<float*>(acc_o.data());
// printf("acc_o tid:%d n_block:%d %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f "
// "%10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f "
// "%10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f "
// "%10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f\n", tidx, n_block,
// tmp[0], tmp[1], tmp[2], tmp[3],
// tmp[4], tmp[5], tmp[6], tmp[7],
// tmp[8], tmp[9], tmp[10], tmp[11],
// tmp[12], tmp[13], tmp[14], tmp[15],
// tmp[16], tmp[17], tmp[18], tmp[19],
// tmp[20], tmp[21], tmp[22], tmp[23],
// tmp[24], tmp[25], tmp[26], tmp[27],
// tmp[28], tmp[29], tmp[30], tmp[31]
// );
// }
}
// Epilogue
Tensor lse = softmax.template normalize_softmax_lse<false>(acc_o, params.scale_softmax, params.rp_dropout);
// S_BARRIER;
// wangaq debug
// __syncthreads();
// if (blockIdx.x == 0) {
// float * tmp = reinterpret_cast<float*>(acc_o.data());
// printf("acc_o tid:%d n_block:%d 0:%10.4f 1:%10.4f 2:%10.4f 3:%10.4f 4:%10.4f 5:%10.4f 6:%10.4f 7:%10.4f "
// "8:%10.4f 9:%10.4f 10:%10.4f 11:%10.4f 12:%10.4f 13:%10.4f 14:%10.4f 15:%10.4f "
// "16:%10.4f 17:%10.4f 18:%10.4f 19:%10.4f 20:%10.4f 21:%10.4f 22:%10.4f 23:%10.4f "
// "24:%10.4f 25:%10.4f 26:%10.4f 27:%10.4f 28:%10.4f 29:%10.4f 30:%10.4f 31:%10.4f\n", tidx, n_block,
// // tmp[0], tmp[1], tmp[2], tmp[3], tmp[4], tmp[5], tmp[6], tmp[7],
// // tmp[8], tmp[9], tmp[10], tmp[11], tmp[12], tmp[13], tmp[14], tmp[15],
// // tmp[16], tmp[17], tmp[18], tmp[19], tmp[20], tmp[21], tmp[22], tmp[23],
// // tmp[24], tmp[25], tmp[26], tmp[27], tmp[28], tmp[29], tmp[30], tmp[31],
// tmp[32], tmp[33], tmp[34], tmp[35], tmp[36], tmp[37], tmp[38], tmp[30],
// tmp[40], tmp[41], tmp[42], tmp[43], tmp[44], tmp[45], tmp[46], tmp[47],
// tmp[48], tmp[49], tmp[50], tmp[51], tmp[52], tmp[53], tmp[54], tmp[55],
// tmp[56], tmp[57], tmp[58], tmp[59], tmp[60], tmp[61], tmp[62], tmp[63]
// );
// }
Tensor mO = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.o_ptr)
+ binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)),
make_shape(binfo.actual_seqlen_q, params.h, params.d_value),
make_stride(params.o_row_stride, params.o_head_stride, _1{}));
Tensor gO = local_tile(mO(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDimV>>{},
make_coord(m_block, 0)); // (kBlockM, kHeadDim)
Tensor gLSE = get_lse_tile<ElementAccum, Params, kBlockM, Is_even_MN>(params, bidb, bidh, m_block, binfo);
Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDimV>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K)
if (get<1>(taccOcO(0)) == 0) {
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) {
const int row = get<0>(taccOcO(0, mi, 0));
if (row < binfo.actual_seqlen_q - m_block * kBlockM) { gLSE(row) = lse(mi); }
}
}
int row, col;
const int warpId = tidx / 64;
const int laneId = tidx % 64;
#pragma unroll
for (int mi = 0; mi < size<1>(acc_o); ++mi) {
row = (mi*kNWarps + warpId) * 16 + (laneId % 16);
if (Is_even_MN || row < binfo.actual_seqlen_q - m_block * kBlockM) {
#pragma unroll
for (int ni = 0; ni < size<2>(acc_o); ++ni) {
col = (laneId / 16) + ni * 32;
#pragma unroll
for (int ei = 0; ei < size<0>(acc_o); ++ei) {
// wangaq debug
// if(thread(0, 0)) {
// printf("mi:%d ni:%d ei:%d row:%d col:%d acc_o:%.4f\n",
// mi, ni, ei, row, col, acc_o(ei, mi, ni));
// }
if (Is_even_K || col < params.d_value) {
gO(row, col) = flash::convert_type<Element>(acc_o(ei, mi, ni));
}
// else
// gO(row, col) = Element(0.0);
col += 4;
}
}
}
}
#endif
}
template<typename Kernel_traits, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
inline __device__ void sparse_attn_1rowblock_sla_fp8(const Params &params, const int bidb, const int bidh, const int m_block) {
using Element = typename Kernel_traits::Element;
using ElementOUT = typename Kernel_traits::ElementO;
using ElementAccum = typename Kernel_traits::ElementAccum;
using index_t = typename Kernel_traits::index_t;
// Shared memory.
extern __shared__ char smem_[];
// The thread index.
const int tidx = threadIdx.x;
constexpr int kBlockM = Kernel_traits::kBlockM;
constexpr int kBlockN = Kernel_traits::kBlockN;
constexpr int kHeadDim = Kernel_traits::kHeadDim;
constexpr int kHeadDimV = Kernel_traits::kHeadDimV;
constexpr int kNWarps = Kernel_traits::kNWarps;
constexpr int kStages = Kernel_traits::kStages;
const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
const index_t binfo_o_offset = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb);
if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
const int n_block_min = 0;
int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN);
const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded
+ m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN;
Tensor mQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.q_ptr) + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)),
make_shape(binfo.actual_seqlen_q, params.h, params.d),
make_stride(params.q_row_stride, params.q_head_stride, _1{}));
Tensor gQ = local_tile(mQ(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{}, make_coord(m_block, 0)); // (kBlockM, kHeadDim)128,128
Tensor mK = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.k_ptr) + binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)),
make_shape(binfo.actual_seqlen_k, params.h_k, params.d),
make_stride(params.k_row_stride, params.k_head_stride, _1{}));
Tensor gK = local_tile(mK(_, bidh / params.h_h_k_ratio, _), Shape<Int<kBlockN>, Int<kHeadDim>>{}, make_coord(0, 0)); // (kBlockN, kHeadDim, nblocksN)
Tensor mV = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.v_ptr) + binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)),
make_shape(binfo.actual_seqlen_k, params.h_k, params.d),
make_stride(params.v_row_stride, params.v_head_stride, _1{}));
Tensor gV = local_tile(mV(_, bidh / params.h_h_k_ratio, _), Shape<Int<kBlockN>, Int<kHeadDim>>{}, make_coord(0, 0)); // (kBlockN, kHeadDim, nblocksN)
const auto gK_data = gK.data();
const auto gV_data = gV.data();
Tensor gP = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.p_ptr) + row_offset_p),
Shape<Int<kBlockM>, Int<kBlockN>>{}, make_stride(params.seqlen_k_rounded, _1{}));
Tensor sK = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)), typename Kernel_traits::SmemLayoutKV{});
Tensor sV = make_tensor(sK.data() + 8192, typename Kernel_traits::SmemLayoutV{});
Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
// Tensor sVtSplit = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransSplit{});
typename Kernel_traits::TiledMma16x64 tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(tidx);
typename Kernel_traits::TiledMma16x32 tiled_mma_for_gemm1;
auto thr_mma_for_gemm1 = tiled_mma_for_gemm1.get_thread_slice(tidx);
Tensor tGrQ = thr_mma.partition_fragment_A(gQ); // (MMA,MMA_M,MMA_K)
Tensor tSrK = thr_mma.partition_fragment_B(gK); // (MMA,MMA_N,MMA_K)
Tensor tOrVt = thr_mma_for_gemm1.partition_fragment_B(sVt); // (MMA, MMA_K,MMA_N)
Tensor tSgS = thr_mma.partition_C(gP);
//
// Copy Atom retiling
//
auto gmem_tiled_copy_Q = make_tiled_copy_A(Copy_Atom<DefaultCopy, Element>{}, tiled_mma);
auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx);
Tensor tSgQ = gmem_thr_copy_Q.partition_S(gQ);
// auto smem_tiled_copy_K = make_tiled_copy_B(Copy_Atom<GFX928_DS_READ_B128, Element>{}, tiled_mma);
auto smem_tiled_copy_K = make_tiled_copy_B(Copy_Atom<DefaultCopy, Element>{}, tiled_mma);
auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
typename Kernel_traits::TiledMma16x64BLayout tiled_mma_BLayout;
auto smem_tiled_copy_BLayout = make_tiled_copy_B(Copy_Atom<DefaultCopy, Element>{}, tiled_mma_BLayout);
auto smem_thr_copy_BLayout = smem_tiled_copy_BLayout.get_thread_slice(tidx);
Tensor sKtemp = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutK{});
Tensor tSsKBLayout = smem_thr_copy_BLayout.partition_S(sKtemp);
Tensor tSsK = make_tensor(tSsKBLayout.data(), convert_layout_B_rowcol_fp8<_64x64, 128/64>(tSsKBLayout.layout()));
auto smem_tiled_copy_V = make_tiled_copy_B(Copy_Atom<GFX928_DS_READ_DS_M32x32_B8, Element>{}, tiled_mma_for_gemm1);
auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
Tensor tOsVt8x64 = smem_thr_copy_V.partition_S(sVt);
Tensor tOsVt = make_tensor(tOsVt8x64.data(), convert_layout_B_rowcol_fp8<_32x128>(tOsVt8x64.layout()));
//
// PREDICATES
//
// Construct identity layout for sQ and sK
Tensor cQ = make_identity_tensor(make_shape(size<0>(gQ), size<1>(gQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
// Repeat the partitioning with identity layouts
Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
// Allocate predicate tensors for k
Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tSgQ)));
// Set predicates for k bounds
if constexpr (!Is_even_K) {
#pragma unroll
for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; }
}
// Prologue
flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_Q, tSgQ, tGrQ, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM);
int n_block = n_block_max - 1;
int num_blks = params.block_count[(bidb * params.h + bidh) * params.NUM_ROWS + m_block];
auto* blks_ptr = params.block_offset + ((bidb * params.h + bidh) * params.NUM_ROWS + m_block) * params.NNZ_S;
Tensor acc_o = partition_fragment_C(tiled_mma_for_gemm1, Shape<Int<kBlockM>, Int<kHeadDimV>>{}); // MMA, MMA_M, MMA_K
clear(acc_o);
flash::Softmax<size<1>(acc_o)> softmax;
flash::Mask<false, false, false> mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, 0.0);
if (num_blks <= 0) {
Tensor mO = make_tensor(make_gmem_ptr(reinterpret_cast<ElementOUT*>(params.o_ptr) + binfo_o_offset),
make_shape(binfo.actual_seqlen_q, params.h, params.d_value),
make_stride(params.o_row_stride, params.o_head_stride, _1{}));
Tensor gO = local_tile(mO(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDimV>>{},
make_coord(m_block, 0)); // (kBlockM, kHeadDim)
Tensor gLSE = get_lse_tile<ElementAccum, Params, kBlockM, Is_even_MN>(params, bidb, bidh, m_block, binfo);
typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
Tensor tOrO = make_tensor<ElementOUT>(shape(tOgO));
clear(tOrO);
// Construct identity layout for sO
Tensor cO = make_identity_tensor(make_shape(size<0>(gO), size<1>(gO))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
// Repeat the partitioning with identity layouts
Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
if (!Is_even_K) {
#pragma unroll
for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d_value; }
}
// Clear_OOB_K must be false since we don't want to write zeros to gmem
flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
);
#pragma unroll
for (int m = 0; m < size<1>(tOgO); ++m) {
const int row = get<0>(tOcO(0, m, 0));
if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSE(row) = INFINITY; }
}
return;
}
float q_descale = params.q_descale_ptr == nullptr ? 1.0f : params.q_descale_ptr[bidb * params.q_descale_batch_stride + bidh * params.q_descale_head_stride];
float k_descale = params.k_descale_ptr == nullptr ? 1.0f : params.k_descale_ptr[bidb * params.k_descale_batch_stride + bidh * params.k_descale_head_stride];
float v_descale = params.v_descale_ptr == nullptr ? 1.0f : params.v_descale_ptr[bidb * params.v_descale_batch_stride + bidh * params.v_descale_head_stride];
float scale_softmax_log2 = params.scale_softmax_log2*q_descale*k_descale;
float scale_softmax = params.scale_softmax*q_descale*k_descale;
constexpr int n_masking_steps = 1;
int block_index = num_blks - 1;
int actual_block = blks_ptr[block_index];
gK.data() = gK_data + actual_block * int64_t(params.k_row_stride);
gV.data() = gV_data + actual_block * int64_t(params.v_row_stride);
Tensor tCrK_copy_view = smem_thr_copy_K.retile_D(tSrK);
Tensor tOrVt_copy_view = smem_thr_copy_V.retile_D(tOrVt);
lds_direct_copy_fp8<Is_even_K, Is_even_MN, _64x64>(gK, sK, 0, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy_fp8<Is_even_K, Is_even_MN, _64x64>(gK, sK, 1, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy_fp8<Is_even_K, Is_even_MN, _32x128>(gV, sV, 0, params.v_row_stride, params.d_value, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy_fp8<Is_even_K, Is_even_MN, _32x128>(gV, sV, 1, params.v_row_stride, params.d_value, binfo.actual_seqlen_k - n_block * kBlockN);
#if 1
#pragma unroll
for (int masking_step = 0; masking_step < n_masking_steps && block_index >= 0; ++masking_step, --block_index) {
Tensor acc_s_ori = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});
clear(acc_s_ori);
s_waitcnt<3>();
cute::copy(smem_tiled_copy_K, tSsK(_, _, 0), tCrK_copy_view(_, _, 0));
cute::gemm(tiled_mma, tGrQ(_, _, 0), tSrK(_, _, 0), acc_s_ori);
s_waitcnt<2>();
cute::copy(smem_tiled_copy_K, tSsK(_, _, 1), tCrK_copy_view(_, _, 1));
cute::gemm(tiled_mma, tGrQ(_, _, 1), tSrK(_, _, 1), acc_s_ori);
s_barrier();
Tensor acc_s = make_tensor(acc_s_ori.data(), convert_layout_acc(acc_s_ori.layout()));
#if 1
{
const int wave_id = (tidx >> 6);
const int wave_id_to_row_block_id = wave_id;
const int warp_row_stride = 16;
const int row_idx_offset_in_block = (tidx & (warp_row_stride - 1)) + (wave_id_to_row_block_id << 4);
const int row_idx_offset_ = m_block * kBlockM + row_idx_offset_in_block;
mask.template apply_mask_continuous_fp8<false, Is_even_MN>(
acc_s, n_block * kBlockN, row_idx_offset_, (kNWarps << 4)
);
}
softmax.template softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/false>(acc_s, acc_o, scale_softmax_log2);
// Convert acc_s from fp32 to fp16/bf16
Tensor rP = flash::convert_type_fp8<Element>(acc_s);
s_waitcnt<1>();
cute::copy(smem_tiled_copy_V, tOsVt(_, _, 0), tOrVt_copy_view(_, _, 0));
cute::gemm(tiled_mma_for_gemm1, rP(_, _, 0), tOrVt(_, _, 0), acc_o);
s_waitcnt<0>();
cute::copy(smem_tiled_copy_V, tOsVt(_, _, 1), tOrVt_copy_view(_, _, 1));
cute::gemm(tiled_mma_for_gemm1, rP(_, _, 1), tOrVt(_, _, 1), acc_o);
s_barrier();
if (block_index > 0) {
actual_block = blks_ptr[block_index - 1];
gK.data() = gK_data + actual_block * int64_t(params.k_row_stride);
gV.data() = gV_data + actual_block * int64_t(params.v_row_stride);
lds_direct_copy_fp8<Is_even_K, true, _64x64>(gK, sK, 0, params.k_row_stride, params.d);
lds_direct_copy_fp8<Is_even_K, true, _64x64>(gK, sK, 1, params.k_row_stride, params.d);
lds_direct_copy_fp8<Is_even_K, true, _32x128>(gV, sV, 0, params.v_row_stride, params.d_value);
lds_direct_copy_fp8<Is_even_K, true, _32x128>(gV, sV, 1, params.v_row_stride, params.d_value);
}
#endif
}
#endif
#if 1
#pragma unroll
for (; block_index >= 0; --block_index) {
Tensor acc_s_ori = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});
clear(acc_s_ori);
s_waitcnt<3>();
cute::copy(smem_tiled_copy_K, tSsK(_, _, 0), tCrK_copy_view(_, _, 0));
cute::gemm(tiled_mma, tGrQ(_, _, 0), tSrK(_, _, 0), acc_s_ori);
s_waitcnt<2>();
cute::copy(smem_tiled_copy_K, tSsK(_, _, 1), tCrK_copy_view(_, _, 1));
cute::gemm(tiled_mma, tGrQ(_, _, 1), tSrK(_, _, 1), acc_s_ori);
s_barrier();
Tensor acc_s = make_tensor(acc_s_ori.data(), convert_layout_acc(acc_s_ori.layout()));
softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/false>(acc_s, acc_o, scale_softmax_log2);
Tensor rP = flash::convert_type_fp8<Element>(acc_s);
s_waitcnt<1>();
cute::copy(smem_tiled_copy_V, tOsVt(_, _, 0), tOrVt_copy_view(_, _, 0));
cute::gemm(tiled_mma_for_gemm1, rP(_, _, 0), tOrVt(_, _, 0), acc_o);
s_waitcnt<0>();
cute::copy(smem_tiled_copy_V, tOsVt(_, _, 1), tOrVt_copy_view(_, _, 1));
cute::gemm(tiled_mma_for_gemm1, rP(_, _, 1), tOrVt(_, _, 1), acc_o);
s_barrier();
if (block_index > 0) {
const int actual_block = blks_ptr[block_index - 1];
gK.data() = gK_data + actual_block * int64_t(params.k_row_stride);
gV.data() = gV_data + actual_block * int64_t(params.v_row_stride);
lds_direct_copy_fp8<Is_even_K, true, _64x64>(gK, sK, 0, params.k_row_stride, params.d);
lds_direct_copy_fp8<Is_even_K, true, _64x64>(gK, sK, 1, params.k_row_stride, params.d);
lds_direct_copy_fp8<Is_even_K, true, _32x128>(gV, sV, 0, params.v_row_stride, params.d_value);
lds_direct_copy_fp8<Is_even_K, true, _32x128>(gV, sV, 1, params.v_row_stride, params.d_value);
}
}
#endif
#if 1
// Epilogue
Tensor lse = softmax.template normalize_softmax_lse_fp8<false>(acc_o, scale_softmax,v_descale, params.rp_dropout);
Tensor mO = make_tensor(make_gmem_ptr(reinterpret_cast<ElementOUT*>(params.o_ptr) + binfo_o_offset),
make_shape(binfo.actual_seqlen_q, params.h, params.d_value),
make_stride(params.o_row_stride, params.o_head_stride, _1{}));
Tensor gO = local_tile(mO(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDimV>>{},
make_coord(m_block, 0)); // (kBlockM, kHeadDim)
Tensor gLSE = get_lse_tile<ElementAccum, Params, kBlockM, Is_even_MN>(params, bidb, bidh, m_block, binfo);
Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDimV>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K)
if (get<1>(taccOcO(0)) == 0) {
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) {
const int row = get<0>(taccOcO(0, mi, 0));
if (row < binfo.actual_seqlen_q - m_block * kBlockM) { gLSE(row) = lse(mi); }
}
}
using result_type = cutlass::Array<bfloat16_t, 2>;
int row, col;
const int warpId = tidx / 64;
const int laneId = tidx % 64;
#pragma unroll
for (int mi = 0; mi < size<1>(acc_o); ++mi) {
row = (mi*kNWarps + warpId) * 16 + (laneId % 16);
if (Is_even_MN || row < binfo.actual_seqlen_q - m_block * kBlockM) {
#pragma unroll
for (int ni = 0; ni < size<2>(acc_o); ++ni) {
col = (laneId / 16)*4 + ni * 32;
{
auto d0 = __builtin_hcu_cvt_pk_bf16_f32(0, acc_o(0, mi, ni), 0, acc_o(1, mi, ni), 0);
auto d1 = __builtin_hcu_cvt_pk_bf16_f32(0, acc_o(2, mi, ni), 0, acc_o(3, mi, ni), 0);
auto res0 = reinterpret_cast<result_type const &>(d0);
auto res1 = reinterpret_cast<result_type const &>(d1);
gO(row, col) = res0[0];
gO(row, col + 1) = res0[1];
gO(row, col + 2) = res1[0];
gO(row, col + 3) = res1[1];
col += 16;
}
{
auto d0 = __builtin_hcu_cvt_pk_bf16_f32(0, acc_o(4, mi, ni), 0, acc_o(5, mi, ni), 0);
auto d1 = __builtin_hcu_cvt_pk_bf16_f32(0, acc_o(6, mi, ni), 0, acc_o(7, mi, ni), 0);
auto res0 = reinterpret_cast<result_type const &>(d0);
auto res1 = reinterpret_cast<result_type const &>(d1);
gO(row, col) = res0[0];
gO(row, col + 1) = res0[1];
gO(row, col + 2) = res1[0];
gO(row, col + 3) = res1[1];
}
}
}
}
#endif
}
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax, typename Params>
inline __device__ void compute_sparse_attn(const Params &params) {
const int m_block = blockIdx.x;
// The block index for the batch.
const int bidb = blockIdx.z;
// The block index for the head.
const int bidh = blockIdx.y;
// We want the fwd and bwd to generate the same dropout pattern (RNG), without restricting
// them to have the same number of threads or have to traverse the attention matrix
// in the same order.
// In the Philox RNG, we use the offset to store the batch, head, and the lane id
// (within a warp). We use the subsequence to store the location of the 16 x 32 blocks within
// the attention matrix. This way, as long as we have the batch, head, and the location of
// the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern.
#if (defined(__gfx936__) || defined(__gfx938__) )
flash::sparse_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Return_softmax>(params, bidb, bidh, m_block);
#endif
}
#if 1
template<typename Kernel_traits, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
inline __device__ void compute_sparse_attn_sla(const Params &params) {
const int m_block = blockIdx.x;
// The block index for the batch.
const int bidb = blockIdx.z;
// The block index for the head.
const int bidh = blockIdx.y;
#if (defined(__gfx936__) || defined(__gfx938__) )
if constexpr (Kernel_traits::kHeadDim == 128) {
flash::sparse_attn_1rowblock_sla<Kernel_traits, Is_even_MN, Is_even_K, Return_softmax>(params, bidb, bidh, m_block);
} else if constexpr (Kernel_traits::kHeadDim == 64) {
flash::sparse_attn_1rowblock_sla_dim64<Kernel_traits, Is_even_MN, Is_even_K, Return_softmax>(params, bidb, bidh, m_block);
}
#endif
}
#endif
#if 1
template<typename Kernel_traits, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
inline __device__ void compute_sparse_attn_sla_fp8(const Params &params) {
const int m_block = blockIdx.x;
// The block index for the batch.
const int bidb = blockIdx.z;
// The block index for the head.
const int bidh = blockIdx.y;
#if defined(__gfx938__)
flash::sparse_attn_1rowblock_sla_fp8<Kernel_traits, Is_even_MN, Is_even_K, Return_softmax>(params, bidb, bidh, m_block);
#endif
}
#endif
} // namespace FLASH_NAMESPACE
\ No newline at end of file
/******************************************************************************
* Copyright (c) 2024, PAI, Alibaba Cloud.
******************************************************************************/
#pragma once
#include "flash_fwd_launch_template.h"
#include "flash_fwd_sparse_kernel.h"
#include "flash_sparse.h"
// namespace FLASH_NAMESPACE {
#define DEFINE_FLASH_FORWARD_SPARSE_KERNEL(kernelName, ...) \
template<typename Kernel_traits, __VA_ARGS__> \
__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params_sparse params)
DEFINE_FLASH_FORWARD_SPARSE_KERNEL(flash_fwd_sparse_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax) {
#if defined(ARCH_SUPPORTS_FLASH)
static_assert(!(Is_causal && Is_local)); // Enforce constraints
flash::compute_sparse_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Return_softmax>(params);
#else
FLASH_UNSUPPORTED_ARCH
#endif
}
#if 1
DEFINE_FLASH_FORWARD_SPARSE_KERNEL(flash_fwd_sparse_sla_kernel, bool Is_even_MN, bool Is_even_K, bool Return_softmax) {
#if defined(ARCH_SUPPORTS_FLASH)
flash::compute_sparse_attn_sla<Kernel_traits, Is_even_MN, Is_even_K, Return_softmax>(params);
#else
FLASH_UNSUPPORTED_ARCH
#endif
}
#endif
#if 1
DEFINE_FLASH_FORWARD_SPARSE_KERNEL(flash_fwd_sparse_sla_fp8_kernel, bool Is_even_MN, bool Is_even_K, bool Return_softmax) {
#if defined(ARCH_SUPPORTS_FLASH)
flash::compute_sparse_attn_sla_fp8<Kernel_traits, Is_even_MN, Is_even_K, Return_softmax>(params);
#else
FLASH_UNSUPPORTED_ARCH
#endif
}
#endif
template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
void run_flash_sparse_fwd(Flash_fwd_params_sparse &params, cudaStream_t stream) {
constexpr size_t smem_size = Kernel_traits::kSmemSize;
// printf("smem_size = %d\n", smem_size);
// Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
// https://github.com/kokkos/kokkos-kernels/issues/349
// https://github.com/HazyResearch/flash-attention/issues/21
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
dim3 grid(num_m_block, params.h, params.b);
const bool is_even_K = params.d == Kernel_traits::kHeadDim;
const bool return_softmax = params.p_ptr != nullptr;
EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
constexpr bool IsEvenMNConst = false;
constexpr bool Is_local = false;
auto kernel = &flash_fwd_sparse_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, ReturnSoftmaxConst && Is_dropout && !Is_softcap>;
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
});
});
}
#if 1
template<typename Kernel_traits>
void run_flash_sparse_sla_fwd(Flash_fwd_params_sparse &params, cudaStream_t stream) {
constexpr size_t smem_size = Kernel_traits::kSmemSize;
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
dim3 grid(num_m_block, params.h, params.b);
const bool is_even_K = params.d == Kernel_traits::kHeadDim;
const bool return_softmax = params.p_ptr != nullptr;
EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
constexpr bool IsEvenMNConst = false;
auto kernel = &flash_fwd_sparse_sla_kernel<Kernel_traits, IsEvenMNConst, IsEvenKConst, ReturnSoftmaxConst>;
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
}
#endif
#if 1
template<typename Kernel_traits>
void run_flash_sparse_sla_fwd_fp8(Flash_fwd_params_sparse &params, cudaStream_t stream) {
constexpr size_t smem_size = Kernel_traits::kSmemSize;
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
dim3 grid(num_m_block, params.h, params.b);
const bool is_even_K = params.d == Kernel_traits::kHeadDim;
const bool return_softmax = params.p_ptr != nullptr;
EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
constexpr bool IsEvenMNConst = false;
auto kernel = &flash_fwd_sparse_sla_fp8_kernel<Kernel_traits, IsEvenMNConst, IsEvenKConst, ReturnSoftmaxConst>;
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
}
#endif
template<typename T, bool Is_causal>
void run_mha_fwd_sparse_hdim32(Flash_fwd_params_sparse &params, cudaStream_t stream) {
constexpr static int Headdim = 32;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// run_flash_sparse_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
});
}
template<typename T, bool Is_causal>
void run_mha_fwd_sparse_hdim64(Flash_fwd_params_sparse &params, cudaStream_t stream) {
constexpr static int Headdim = 64;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// run_flash_sparse_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
});
}
template<typename T>
void run_mha_fwd_sparse_sla_hdim64(Flash_fwd_params_sparse &params, cudaStream_t stream) {
constexpr static int Headdim = 64;
if (get_device_name() == "gfx936" || get_device_name() == "gfx938") {
if (params.seqlen_q <= 2048)
run_flash_sparse_sla_fwd<Flash_fwd_kernel_16x64_prefetch_traits_dim64<Headdim, 64, 64, 4, T>>(params, stream);
else
run_flash_sparse_sla_fwd<Flash_fwd_kernel_16x64_prefetch_traits_dim64<Headdim, 128, 64, 4, T>>(params, stream);
}
}
template<typename T, bool Is_causal>
void run_mha_fwd_sparse_hdim96(Flash_fwd_params_sparse &params, cudaStream_t stream) {
constexpr static int Headdim = 96;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// run_flash_sparse_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
});
}
template<typename T, bool Is_causal>
void run_mha_fwd_sparse_hdim128(Flash_fwd_params_sparse &params, cudaStream_t stream) {
constexpr static int Headdim = 128;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
run_flash_sparse_fwd<Flash_fwd_kernel_16x64_prefetch_traits<Headdim, 64, 64, 4, T, 3>, Is_dropout, Is_causal>(params, stream);
});
}
template<typename T>
void run_mha_fwd_sparse_sla_hdim128(Flash_fwd_params_sparse &params, cudaStream_t stream) {
constexpr static int Headdim = 128;
if (get_device_name() == "gfx936" || get_device_name() == "gfx938") {
if (params.seqlen_q <= 2048)
run_flash_sparse_sla_fwd<Flash_fwd_kernel_16x64_prefetch_traits<Headdim, 64, 64, 4, T, 3>>(params, stream);
else
run_flash_sparse_sla_fwd<Flash_fwd_kernel_16x64_prefetch_traits<Headdim, 128, 64, 4, T, 3>>(params, stream);
}
}
template<typename T>
void run_mha_fwd_sparse_sla_hdim128_fp8(Flash_fwd_params_sparse &params, cudaStream_t stream) {
constexpr static int Headdim = 128;
static constexpr bool Is_FP8 = cute::is_same_v<T, cutlass::float_e4m3_t> || cute::is_same_v<T, cutlass::float_e5m2_t>;
using T_out = std::conditional_t<!Is_FP8, T, cutlass::bfloat16_t>;
if (get_device_name() == "gfx938") {
// int num_blocks_64 = params.h * params.b * ((params.seqlen_q + 64 - 1) / 64);//3
// int num_blocks_128 = params.h * params.b * ((params.seqlen_q + 128 - 1) / 128);//2
// if ((num_blocks_64 <= sm_count || (num_blocks_128 / sm_count == 1 && num_blocks_128 % sm_count > 1 && (num_blocks_64 + sm_count - 1) / sm_count <= 3) || force_blockm64) && !force_blockm128) {
if (params.seqlen_q <= 2048) {
run_flash_sparse_sla_fwd_fp8<Flash_fwd_kernel_16x64_prefetch_traits_fp8<Headdim, 64, 64, 4, T,T_out, 3>>(params, stream);
} else {
run_flash_sparse_sla_fwd_fp8<Flash_fwd_kernel_16x64_prefetch_traits_fp8<Headdim, 128, 64, 4, T,T_out, 3>>(params, stream);
}
} else {
printf("this device is not supoort fp8");
}
}
template<typename T, bool Is_causal>
void run_mha_fwd_sparse_hdim160(Flash_fwd_params_sparse &params, cudaStream_t stream) {
constexpr static int Headdim = 160;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// run_flash_sparse_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
});
}
template<typename T, bool Is_causal>
void run_mha_fwd_sparse_hdim192(Flash_fwd_params_sparse &params, cudaStream_t stream) {
constexpr static int Headdim = 192;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// run_flash_sparse_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
});
}
template<typename T, bool Is_causal>
void run_mha_fwd_sparse_hdim224(Flash_fwd_params_sparse &params, cudaStream_t stream) {
constexpr static int Headdim = 224;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// run_flash_sparse_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
});
}
template<typename T, bool Is_causal>
void run_mha_fwd_sparse_hdim256(Flash_fwd_params_sparse &params, cudaStream_t stream) {
constexpr static int Headdim = 256;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// run_flash_sparse_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
});
}
// } // namespace FLASH_NAMESPACE
\ No newline at end of file
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream);
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 128, false>(Flash_fwd_params &params, cudaStream_t stream);
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream);
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128, false>(Flash_fwd_params &params, cudaStream_t stream);
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
template void run_mha_fwd_splitkv_dispatch_fp8<cutlass::float_e4m3_t,cutlass::bfloat16_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream);
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
template void run_mha_fwd_splitkv_dispatch_fp8<cutlass::float_e5m2_t,cutlass::bfloat16_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream);
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
template void run_mha_fwd_splitkv_dispatch_fp8<cutlass::float_e5m2_t,cutlass::bfloat16_t, 128, false>(Flash_fwd_params &params, cudaStream_t stream);
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
template void run_mha_fwd_splitkv_dispatch_fp8<cutlass::float_e4m3_t,cutlass::half_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream);
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment