#include "interface.h" #include #include #include #include "common/mask.cuh" #include "common/utils.hpp" #include "fmha_cutlass_bwd_sm100.cuh" template void call_run_fmha_bwd([[maybe_unused]] Mask mask, [[maybe_unused]] Varlen is_varlen, [[maybe_unused]] Element in, [[maybe_unused]] ElementOut out, [[maybe_unused]] Mla mla, at::Tensor workspace_buffer, at::Tensor d_o, at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor o, at::Tensor lse, at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv, at::Tensor dq, at::Tensor dk, at::Tensor dv, float softmax_scale, int max_seqlen_q, int total_seqlen_kv) { static constexpr bool IsVarlen = std::is_same_v; static constexpr bool IsMla = std::is_same_v; using TileShape = std::conditional_t, Shape<_128, _128, _128, _128>>; run_fmha_bwd(workspace_buffer, d_o, q, k, v, o, lse, cumulative_seqlen_q, cumulative_seqlen_kv, dq, dk, dv, softmax_scale, max_seqlen_q, total_seqlen_kv); } void FMHACutlassSM100BwdRun(at::Tensor workspace_buffer, at::Tensor d_o, at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor o, at::Tensor lse, at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv, at::Tensor dq, at::Tensor dk, at::Tensor dv, int mask_mode_code, float softmax_scale, int max_seqlen_q, int max_seqlen_kv, bool is_varlen) { const c10::cuda::OptionalCUDAGuard device_guard(q.device()); int head_dim_qk = q.size(-1); int head_dim_vo = v.size(-1); MaskMode mask_mode = static_cast(mask_mode_code); auto scalar_type_in = q.scalar_type(); auto scalar_type_out = o.scalar_type(); if(scalar_type_in == at::ScalarType::BFloat16 && scalar_type_out == at::ScalarType::BFloat16) { using Element = cutlass::bfloat16_t; using ElementOut = cutlass::bfloat16_t; auto apply_config = [&](auto fn) { if (mask_mode == MaskMode::kCausal) { if(is_varlen) { fn(CausalForBackwardMask{}, cute::true_type{}, Element{}, ElementOut{}); } else { fn(CausalForBackwardMask{}, cute::false_type{}, Element{}, ElementOut{}); } } else { if(is_varlen) { fn(ResidualMaskForBackward{}, cute::true_type{}, Element{}, ElementOut{}); } else { fn(ResidualMaskForBackward{}, cute::false_type{}, Element{}, ElementOut{}); } } }; apply_config([&](auto mask, auto varlen, auto in, auto out) { if (head_dim_qk == 192 && head_dim_vo == 128) { call_run_fmha_bwd(mask, varlen, in, out, true_type{}, workspace_buffer, d_o, q, k, v, o, lse, cumulative_seqlen_q, cumulative_seqlen_kv, dq, dk, dv, softmax_scale, max_seqlen_q, max_seqlen_kv); } else if (head_dim_qk == 128 && head_dim_vo == 128) { call_run_fmha_bwd(mask, varlen, in, out, false_type{}, workspace_buffer, d_o, q, k, v, o, lse, cumulative_seqlen_q, cumulative_seqlen_kv, dq, dk, dv, softmax_scale, max_seqlen_q, max_seqlen_kv); } else { std::cout << "No kernel instantiated for head_dim_qk=" << head_dim_qk << " head_dim_vo=" << head_dim_vo << std::endl; } }); } else { FLASH_MLA_ASSERT(false); } }