/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/
// modify from: https://github.com/Dao-AILab/flash-attention

#include "flash.h"
#include "src/turbomind/models/llama/llama_kernels.h"
#include "static_switch.h"

#include <cuda_fp16.h>
#include <cutlass/numeric_types.h>

#include <cmath>
#include <cstring>
#include <memory>
#include <type_traits>

void run_mha_fwd(Flash_fwd_params& params, cudaStream_t stream)
{
    // Dispatch on element type (fp16 vs bf16) and head dimension, then launch
    // the matching flash-attention forward kernel.
    FP16_SWITCH(!params.is_bf16, [&] {
        FWD_HEADDIM_SWITCH(params.d, [&] { run_mha_fwd_<elem_type, kHeadDim>(params, stream); });
    });
}

namespace turbomind {

static constexpr int FMHA_VERSION = 2;

template<typename T>
class FlashAttentionOpImpl<T, FMHA_VERSION> {
public:
    using AttentionLayout = BaseAttentionLayout<T>;
    using Params          = BaseAttentionParams<T>;

public:
    FlashAttentionOpImpl(int batch_size, int head_num, int key_len, int seq_len, int size_per_head);
    ~FlashAttentionOpImpl();

    int get_workspace_size() const;

    void operator()(Params& params, cudaStream_t st) const;

private:
    class impl;
    std::unique_ptr<impl> pimpl;
};

template<typename T>
class FlashAttentionOpImpl<T, FMHA_VERSION>::impl {

private:
    // Map the runtime `half` type to cutlass::half_t; other element types pass through.
    using scalar_t = typename std::conditional_t<std::is_same<half, typename std::decay<T>::type>::value,
                                                 cutlass::half_t,
                                                 T>;
    using Params   = typename FlashAttentionOpImpl<T, FMHA_VERSION>::Params;

    int batch_size_;
    int head_num_;
    int key_len_;
    int seq_len_;
    int size_per_head_;

public:
    impl(int batch_size, int head_num, int key_len, int seq_len, int size_per_head):
        batch_size_(batch_size),
        head_num_(head_num),
        key_len_(key_len),
        seq_len_(seq_len),
        size_per_head_(size_per_head)
    {
    }

    ~impl() {}

    int get_workspace_size() const
    {
        return 0;
    }

    void operator()(Params& params, cudaStream_t st) const
    {
        const float qk_scale = static_cast<float>(1.f / sqrtf(size_per_head_ * 1.f));

        Flash_fwd_params fwd_params;
        memset(&fwd_params, 0, sizeof(fwd_params));

        fwd_params.q_ptr = reinterpret_cast<void*>(params.query);
        fwd_params.k_ptr = reinterpret_cast<void*>(params.key);
        fwd_params.v_ptr = reinterpret_cast<void*>(params.val);

        fwd_params.k_batched_ptr    = reinterpret_cast<void**>(params.layout_k.batch_seqs);
        fwd_params.v_batched_ptr    = reinterpret_cast<void**>(params.layout_v.batch_seqs);
        fwd_params.k_batched_offset = params.layout_k.batch_seqs_offset;
        fwd_params.v_batched_offset = params.layout_v.batch_seqs_offset;

        fwd_params.q_batch_stride = params.layout_q.stride_batch;
        fwd_params.k_batch_stride = params.layout_k.stride_batch;
        fwd_params.v_batch_stride = params.layout_v.stride_batch;
        fwd_params.q_row_stride   = params.layout_q.stride_seq;
        fwd_params.k_row_stride   = params.layout_k.stride_seq;
        fwd_params.v_row_stride   = params.layout_v.stride_seq;
        fwd_params.q_head_stride  = params.layout_q.stride_head;
        fwd_params.v_head_stride  = params.layout_v.stride_head;
        fwd_params.k_head_stride  = params.layout_k.stride_head;
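
        // Grouped-query attention bookkeeping: `h` is the number of query heads,
        // `h_k` the number of key/value heads, and `h_h_k_ratio` the number of
        // query heads that share one key/value head (params.group_size).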
        fwd_params.h           = head_num_;
        fwd_params.h_k         = head_num_ / params.group_size;
        fwd_params.h_h_k_ratio = params.group_size;

        fwd_params.o_ptr          = reinterpret_cast<void*>(params.attn_out);
        fwd_params.o_batch_stride = params.layout_o.stride_batch;
        fwd_params.o_row_stride   = params.layout_o.stride_seq;
        fwd_params.o_head_stride  = params.layout_o.stride_head;
        fwd_params.p_ptr          = nullptr;

        fwd_params.b                = batch_size_;
        fwd_params.seqlen_q         = seq_len_;
        fwd_params.seqlen_k         = key_len_;
        fwd_params.d                = size_per_head_;
        fwd_params.seqlen_q_rounded = 0;
        fwd_params.seqlen_k_rounded = 0;

        fwd_params.scale_softmax      = qk_scale;
        fwd_params.scale_softmax_log2 = qk_scale * M_LOG2E;

        fwd_params.cu_seqlens_q    = params.cu_seqlens_q;
        fwd_params.cu_seqlens_k    = params.cu_seqlens_k;
        fwd_params.actual_seqlen_q = params.actual_seqlen_q;
        fwd_params.actual_seqlen_k = params.actual_seqlen_k;

        fwd_params.blockmask = reinterpret_cast<void*>(params.mask);

#ifdef ENABLE_BF16
        fwd_params.is_bf16 = std::is_same<T, __nv_bfloat16>::value;
#else
        fwd_params.is_bf16 = false;
#endif
        fwd_params.is_causal = true;

        fwd_params.q_enable_seqlen = params.layout_q.use_seqlens;
        fwd_params.o_enable_seqlen = params.layout_o.use_seqlens;

        run_mha_fwd(fwd_params, st);
    }
};

template<typename T>
FlashAttentionOpImpl<T, FMHA_VERSION>::FlashAttentionOpImpl(
    int batch_size, int head_num, int key_len, int seq_len, int size_per_head):
    pimpl{std::make_unique<typename FlashAttentionOpImpl<T, FMHA_VERSION>::impl>(
        batch_size, head_num, key_len, seq_len, size_per_head)}
{
}

template<typename T>
FlashAttentionOpImpl<T, FMHA_VERSION>::~FlashAttentionOpImpl()
{
}

template<typename T>
int FlashAttentionOpImpl<T, FMHA_VERSION>::get_workspace_size() const
{
    return pimpl->get_workspace_size();
}

template<typename T>
void FlashAttentionOpImpl<T, FMHA_VERSION>::operator()(Params& params, cudaStream_t st) const
{
    pimpl->operator()(params, st);
}

template class FlashAttentionOpImpl<float, FMHA_VERSION>;
template class FlashAttentionOpImpl<half, FMHA_VERSION>;
#ifdef ENABLE_BF16
template class FlashAttentionOpImpl<__nv_bfloat16, FMHA_VERSION>;
#endif

}  // namespace turbomind
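
// ---------------------------------------------------------------------------
// Usage sketch (illustrative only, not part of the build). Only the
// BaseAttentionParams<T> fields referenced in this file (query, key, val,
// attn_out, mask, cu_seqlens_*, actual_seqlen_*, group_size, layout_*) are
// shown; variable names below are placeholders.
//
//   using Op = turbomind::FlashAttentionOpImpl<half, 2>;
//   Op op(batch_size, head_num, key_len, seq_len, size_per_head);
//
//   Op::Params attn_params{};
//   attn_params.query    = q;    // device pointers; the layout_* members
//   attn_params.key      = k;    // describe batch/seq/head strides
//   attn_params.val      = v;
//   attn_params.attn_out = out;
//   // ... fill layout_q/k/v/o, cu_seqlens, group_size, mask as required ...
//
//   op(attn_params, stream);     // launches the FMHA v2 forward pass
// ---------------------------------------------------------------------------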