Commit e2e0225c authored by zhanghj2's avatar zhanghj2
Browse files

空kernel可以编译通过

parent 48c6dc42
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
// common
#include "cutlass/cutlass.h"
#include "cutlass/kernel_hardware_info.hpp"
#include "cute/tensor.hpp"
#include "../device/fmha.hpp"
#include "../kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp"
#include "../kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp"
#include "../kernel/fmha_kernel_bwd_sum_OdO.hpp"
#include "../kernel/fmha_kernel_bwd_convert.hpp"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass::fmha::device {
////////////////////////////////////////////////////////////////////////////////
////////////////////////////// CUTLASS 3.x API /////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
template<
class ProblemShape,
class Element,
class ElementAccumulator,
class TileShape,
bool IsMla,
class Mask
>
class Sm100FmhaBwd {
public:
/// Argument structure: User API
struct Arguments {
// Q K D D_VO HB
ProblemShape problem_shape;
const Element* ptr_Q;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_Q;
const Element* ptr_K;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_K;
const Element* ptr_V;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_V;
const Element* ptr_O;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_O;
const ElementAccumulator* ptr_LSE;
cute::tuple<cute::_1, cute::tuple<int, int>> stride_LSE;
const Element* ptr_dO;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_dO;
Element* ptr_dQ;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_dQ;
Element* ptr_dK;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_dK;
Element* ptr_dV;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_dV;
ElementAccumulator softmax_scale;
cutlass::KernelHardwareInfo hw_info;
};
using OperationSumOdO = cutlass::fmha::device::FMHA<
cutlass::fmha::kernel::FmhaKernelBwdSumOdO<ProblemShape, Element, ElementAccumulator>
>;
using OperationConvert = cutlass::fmha::device::FMHA<
cutlass::fmha::kernel::FmhaKernelBwdConvert<ProblemShape, Element, ElementAccumulator>
>;
using OperationMha= cutlass::fmha::device::FMHA<
cutlass::fmha::kernel::Sm100FmhaBwdKernelTmaWarpSpecialized<
ProblemShape, Element, ElementAccumulator, TileShape, Mask
>
>;
using OperationMla = cutlass::fmha::device::FMHA<
cutlass::fmha::kernel::Sm100FmhaBwdMlaKernelTmaWarpSpecialized<
ProblemShape, Element, ElementAccumulator, TileShape, Mask
>
>;
using Operation = std::conditional_t<IsMla, OperationMla, OperationMha>;
using Kernel = typename Operation::Kernel;
struct Params {
OperationSumOdO op_sum_OdO;
Operation op;
OperationConvert op_convert;
ElementAccumulator* dQ_acc;
size_t dQ_acc_size;
};
private:
Params params_;
static typename OperationSumOdO::Arguments to_sum_OdO_arguments(
Arguments const& args,
ElementAccumulator* sum_odo = nullptr,
ElementAccumulator* scaled_lse = nullptr) {
using namespace cute;
auto [Q_, K, D, D_VO, HB] = args.problem_shape;
auto [H, B] = HB;
D = cutlass::round_up(D, 8); // Alignment
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
auto stride_sum_OdO = make_stride(_1{}, make_stride(Q, Q*H));
auto stride_scaled_lse = make_stride(_1{}, make_stride(Q, Q*H));
auto log2_e = log2f(expf(1.0f));
return typename OperationSumOdO::Arguments {
args.problem_shape,
args.ptr_O, args.stride_O,
args.ptr_dO, args.stride_dO,
sum_odo, stride_sum_OdO,
args.ptr_LSE, args.stride_LSE,
scaled_lse, stride_scaled_lse,
-1.0f, -log2_e
};
}
static typename OperationConvert::Arguments to_convert_arguments(Arguments const& args, ElementAccumulator* src = nullptr) {
using namespace cute;
auto [Q_, K, D, D_VO, HB] = args.problem_shape;
auto [H, B] = HB;
D = cutlass::round_up(D, 8); // Alignment
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
auto stride_src_dQ = make_stride(D, _1{}, make_stride(D*Q, D*Q*H));
return typename OperationConvert::Arguments {
args.problem_shape,
src, stride_src_dQ,
nullptr, stride_src_dQ,
nullptr, stride_src_dQ,
args.ptr_dQ, args.stride_dQ,
nullptr, args.stride_dK,
nullptr, args.stride_dV,
args.softmax_scale
};
}
static typename Operation::Arguments to_bwd_arguments(
Arguments const& args,
ElementAccumulator* sum_OdO = nullptr, cute::tuple<cute::_1, cute::tuple<int, int>> const& stride_sum_OdO = {},
ElementAccumulator* scaled_lse = nullptr, cute::tuple<cute::_1, cute::tuple<int, int>> const& stride_scaled_lse = {},
ElementAccumulator* dQ_acc = nullptr, cute::tuple<int, cute::_1, cute::tuple<int, int>> const& stride_dQ = {}) {
return typename Operation::Arguments{
args.problem_shape,
{ args.ptr_Q, args.stride_Q,
args.ptr_K, args.stride_K,
args.ptr_V, args.stride_V,
args.ptr_dO, args.stride_dO,
scaled_lse, stride_scaled_lse,
sum_OdO, stride_sum_OdO,
dQ_acc, stride_dQ,
args.softmax_scale },
{ args.ptr_dK, args.stride_dK,
args.ptr_dV, args.stride_dV },
args.hw_info
};
}
public:
/// Determines whether the GEMM can execute the given problem.
static Status
can_implement(Arguments const& args) {
Status status = Status::kSuccess;
status = OperationSumOdO::can_implement(to_sum_OdO_arguments(args));
if (status != Status::kSuccess) {
return status;
}
status = OperationConvert::can_implement(to_convert_arguments(args));
if (status != Status::kSuccess) {
return status;
}
status = Operation::can_implement(to_bwd_arguments(args));
if (status != Status::kSuccess) {
return status;
}
return status;
}
/// Gets the workspace size
static size_t
get_workspace_size(Arguments const& args) {
auto [Q_, K, D, D_VO, HB] = args.problem_shape;
auto [H, B] = HB;
D = cutlass::round_up(D, 8); // Alignment
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
size_t workspace_bytes = 0;
// OdO vector
workspace_bytes += sizeof(ElementAccumulator) * B*H*Q;
// scaled LSE vector
workspace_bytes += sizeof(ElementAccumulator) * B*H*Q;
// FP32 versions of outputs that are churned (start off with Q only)
workspace_bytes += sizeof(ElementAccumulator) * B*H*Q*D;
return workspace_bytes;
}
/// Initializes state from arguments.
Status
initialize_split(Arguments const& args, void* workspace_dQ, void* workspace_sum_OdO, void* workspace_scaled_lse, cudaStream_t stream = nullptr) {
CUTLASS_TRACE_HOST("Universal::initialize_split() - workspace_dQ="
<< workspace_dQ << ", workspace_sum_OdO=" << workspace_sum_OdO << "stream: " << (stream ? "non-null" : "null"));
auto [Q_, K, D, D_VO, HB] = args.problem_shape;
auto [H, B] = HB;
D = cutlass::round_up(D, 8); // Alignment
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
ElementAccumulator* sum_OdO = reinterpret_cast<ElementAccumulator*>(workspace_sum_OdO);
ElementAccumulator* scaled_lse = reinterpret_cast<ElementAccumulator*>(workspace_scaled_lse);
ElementAccumulator* dQ_acc = reinterpret_cast<ElementAccumulator*>(workspace_dQ);
params_.dQ_acc = dQ_acc;
params_.dQ_acc_size = sizeof(ElementAccumulator) * B*H*Q*D;
auto args_sum_OdO = to_sum_OdO_arguments(args, sum_OdO, scaled_lse);
auto args_convert = to_convert_arguments(args, dQ_acc);
params_.op_sum_OdO.initialize(args_sum_OdO, nullptr, stream);
params_.op_convert.initialize(args_convert, nullptr, stream);
auto args_bwd = to_bwd_arguments(
args, sum_OdO, args_sum_OdO.stride_sum_OdO,
scaled_lse, args_sum_OdO.stride_scaled_lse,
dQ_acc, args_convert.stride_src_dQ
);
params_.op.initialize(args_bwd, nullptr, stream);
return Status::kSuccess;
}
/// Initializes state from arguments.
Status
initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
CUTLASS_TRACE_HOST("Universal::initialize() - workspace "
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
auto [Q_, K, D, D_VO, HB] = args.problem_shape;
auto [H, B] = HB;
D = cutlass::round_up(D, 8); // Alignment
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
char* workspace_chr = reinterpret_cast<char*>(workspace);
ElementAccumulator* sum_OdO = reinterpret_cast<ElementAccumulator*>(workspace_chr);
workspace_chr += sizeof(ElementAccumulator) * B*H*Q;
ElementAccumulator* scaled_lse = reinterpret_cast<ElementAccumulator*>(workspace_chr);
workspace_chr += sizeof(ElementAccumulator) * B*H*Q;
ElementAccumulator* dQ_acc = reinterpret_cast<ElementAccumulator*>(workspace_chr);
return initialize_split(args, dQ_acc, sum_OdO, scaled_lse, stream);
}
/// Primary run() entry point API that is static allowing users to create and manage their own params.
/// Supplied params struct must be construct by calling Kernel::to_underling_arguments()
static Status
run(Params& params, cudaStream_t stream = nullptr) {
CUTLASS_TRACE_HOST("FmhaDeviceBwd::run()");
Status result = Status::kSuccess;
result = params.op_sum_OdO.run(stream);
if (result != Status::kSuccess) {
return result;
}
auto cuda_result = cudaMemsetAsync(params.dQ_acc, 0, params.dQ_acc_size, stream);
if (cuda_result != cudaSuccess) {
return Status::kErrorInternal;
}
result = params.op.run(stream);
if (result != Status::kSuccess) {
return result;
}
result = params.op_convert.run(stream);
if (result != Status::kSuccess) {
return result;
}
return Status::kSuccess;
}
//
// Non-static launch overloads that first create and set the internal params struct of this kernel handle.
//
/// Launches the kernel after first constructing Params internal state from supplied arguments.
Status
run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
Status status = initialize(args, workspace, stream);
if (Status::kSuccess == status) {
status = run(params_, stream);
}
return status;
}
/// Overload that allows a user to re-launch the same kernel without updating internal params struct.
Status
run(cudaStream_t stream = nullptr) {
return run(params_, stream);
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::fmha::device
////////////////////////////////////////////////////////////////////////////////
#include "interface.h"
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda_bf16.h>
#include "common/mask.cuh"
#include "common/utils.hpp"
#include "fmha_cutlass_bwd_sm100.cuh"
template<class Mask, class Varlen, class Element, class ElementOut, class Mla>
void call_run_fmha_bwd([[maybe_unused]] Mask mask, [[maybe_unused]] Varlen is_varlen,
[[maybe_unused]] Element in, [[maybe_unused]] ElementOut out, [[maybe_unused]] Mla mla,
at::Tensor workspace_buffer, at::Tensor d_o, at::Tensor q, at::Tensor k,
at::Tensor v, at::Tensor o, at::Tensor lse,
at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv,
at::Tensor dq, at::Tensor dk, at::Tensor dv,
float softmax_scale, int max_seqlen_q, int total_seqlen_kv) {
static constexpr bool IsVarlen = std::is_same_v<Varlen, true_type>;
static constexpr bool IsMla = std::is_same_v<Mla, true_type>;
using TileShape = std::conditional_t<IsMla, Shape<_64, _128, _192, _128>, Shape<_128, _128, _128, _128>>;
run_fmha_bwd<Element, IsVarlen, IsMla, TileShape, Mask>(workspace_buffer, d_o, q, k, v, o, lse,
cumulative_seqlen_q, cumulative_seqlen_kv,
dq, dk, dv,
softmax_scale, max_seqlen_q, total_seqlen_kv);
}
void FMHACutlassSM100BwdRun(at::Tensor workspace_buffer, at::Tensor d_o, at::Tensor q, at::Tensor k,
at::Tensor v, at::Tensor o, at::Tensor lse,
at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv,
at::Tensor dq, at::Tensor dk, at::Tensor dv,
int mask_mode_code, float softmax_scale, int max_seqlen_q, int max_seqlen_kv, bool is_varlen) {
const c10::cuda::OptionalCUDAGuard device_guard(q.device());
int head_dim_qk = q.size(-1);
int head_dim_vo = v.size(-1);
MaskMode mask_mode = static_cast<MaskMode>(mask_mode_code);
auto scalar_type_in = q.scalar_type();
auto scalar_type_out = o.scalar_type();
if(scalar_type_in == at::ScalarType::BFloat16 && scalar_type_out == at::ScalarType::BFloat16) {
using Element = cutlass::bfloat16_t;
using ElementOut = cutlass::bfloat16_t;
auto apply_config = [&](auto fn) {
if (mask_mode == MaskMode::kCausal) {
if(is_varlen) {
fn(CausalForBackwardMask<false>{}, cute::true_type{}, Element{}, ElementOut{});
} else {
fn(CausalForBackwardMask<false>{}, cute::false_type{}, Element{}, ElementOut{});
}
}
else {
if(is_varlen) {
fn(ResidualMaskForBackward{}, cute::true_type{}, Element{}, ElementOut{});
} else {
fn(ResidualMaskForBackward{}, cute::false_type{}, Element{}, ElementOut{});
}
}
};
apply_config([&](auto mask, auto varlen, auto in, auto out) {
if (head_dim_qk == 192 && head_dim_vo == 128) {
call_run_fmha_bwd(mask, varlen, in, out, true_type{}, workspace_buffer, d_o, q, k, v, o, lse,
cumulative_seqlen_q, cumulative_seqlen_kv,
dq, dk, dv,
softmax_scale, max_seqlen_q, max_seqlen_kv);
} else if (head_dim_qk == 128 && head_dim_vo == 128) {
call_run_fmha_bwd(mask, varlen, in, out, false_type{}, workspace_buffer, d_o, q, k, v, o, lse,
cumulative_seqlen_q, cumulative_seqlen_kv,
dq, dk, dv,
softmax_scale, max_seqlen_q, max_seqlen_kv); }
else {
std::cout << "No kernel instantiated for head_dim_qk=" << head_dim_qk << " head_dim_vo=" << head_dim_vo << std::endl;
}
});
} else {
FLASH_MLA_ASSERT(false);
}
}
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <iostream>
#include <random>
#include <regex>
#include <cute/tensor.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/kernel_hardware_info.h>
#include <cutlass/util/command_line.h>
#include <cutlass/util/distribution.h>
#include <cutlass/util/reference/device/tensor_fill.h>
#include "common/utils.hpp"
#include "collective/fmha_fusion.hpp"
#include "device/fmha_device_bwd.hpp"
using namespace cute;
using namespace cutlass::fmha::kernel;
using namespace cutlass::fmha::collective;
using namespace cutlass::fmha;
using namespace cutlass;
template<
class DType,
bool kIsVarlen,
bool kIsMla,
class TileShape,
class ActiveMask
>
struct BwdRunner {
using Element = DType;
using ElementAccumulator = float;
// Q K D D_VO (H B)
using ProblemShape = std::conditional_t<
kIsVarlen,
cute::tuple<VariableLength, VariableLength, int, int, cute::tuple<int, int>>,
cute::tuple<int, int, int, int, cute::tuple<int, int>>
>;
using Operation = cutlass::fmha::device::Sm100FmhaBwd<ProblemShape, Element, ElementAccumulator, TileShape, kIsMla, ActiveMask>;
using TensorStride = Stride<int, _1, Stride<int, int>>;
using StrideQ = TensorStride; // Seq DQK (H B)
using StrideK = TensorStride; // Seq DQK (H B)
using StrideV = TensorStride; // Seq DVO (H B)
using StrideO = TensorStride; // Seq DVO (H B)
using StrideLSE = Stride<_1, Stride<int, int>>; // Seq (H B)
// Backwards specific
using StrideDQ = TensorStride;
using StrideDK = TensorStride; // Seq DQK (H B)
using StrideDV = TensorStride; // Seq DVO (H B)
using StrideDO = TensorStride;
static void run(at::Tensor workspace_buffer, at::Tensor d_o, at::Tensor q, at::Tensor k,
at::Tensor v, at::Tensor o, at::Tensor lse,
at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv,
at::Tensor dq, at::Tensor dk, at::Tensor dv,
float softmax_scale, int max_seqlen_q, int max_seqlen_kv) {
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = 0;
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
ProblemShape problem_shape;
cute::tuple<int, int, int, int, cute::tuple<int, int>> tensor_shape;
int d = q.size(-1);
int d_vo = v.size(-1);
int batch_size = cumulative_seqlen_q.size(0) - 1;
int num_qo_heads = q.size(1);
int total_seqlen_q = q.size(0);
int total_seqlen_kv = k.size(0);
//varlen: q: [Q, H, D]
//fixedlen: q: [B, H, Q, D]
if constexpr (kIsVarlen) {
problem_shape = cute::make_tuple(
VariableLength{max_seqlen_q, static_cast<int*>(cumulative_seqlen_q.data_ptr()), total_seqlen_q},
VariableLength{max_seqlen_kv, static_cast<int*>(cumulative_seqlen_kv.data_ptr()), total_seqlen_kv},
d, d_vo, cute::make_tuple(num_qo_heads, batch_size));
tensor_shape = make_shape(total_seqlen_q, total_seqlen_kv, d, d_vo, make_shape(num_qo_heads, 1));
} else {
int q_len = total_seqlen_q / batch_size;
int kv_len = total_seqlen_kv / batch_size;
problem_shape = cute::make_tuple(q_len, kv_len, d, d_vo, cute::make_tuple(num_qo_heads, batch_size));
tensor_shape = problem_shape;
}
auto [Q, K, D, D_VO, HB] = tensor_shape;
auto [H, B] = HB;
int q_stride0 = q.stride(0), q_stride1 = q.stride(1), q_stride2 = q.stride(2);
int k_stride0 = k.stride(0), k_stride1 = k.stride(1), k_stride2 = k.stride(2);
int v_stride0 = v.stride(0), v_stride1 = v.stride(1), v_stride2 = v.stride(2);
int o_stride0 = o.stride(0), o_stride1 = o.stride(1), o_stride2 = o.stride(2);
int lse_stride0 = lse.stride(0), lse_stride1 = lse.stride(1);
int dq_stride0 = dq.stride(0), dq_stride1 = dq.stride(1), dq_stride2 = dq.stride(2);
int dk_stride0 = dk.stride(0), dk_stride1 = dk.stride(1), dk_stride2 = dk.stride(2);
int dv_stride0 = dv.stride(0), dv_stride1 = dv.stride(1), dv_stride2 = dv.stride(2);
int do_stride0 = d_o.stride(0), do_stride1 = d_o.stride(1), do_stride2 = d_o.stride(2);
TORCH_CHECK(q_stride2 == 1);
TORCH_CHECK(k_stride2 == 1);
TORCH_CHECK(v_stride2 == 1);
TORCH_CHECK(o_stride2 == 1);
TORCH_CHECK(lse_stride0 == 1);
TORCH_CHECK(dq_stride2 == 1);
TORCH_CHECK(dk_stride2 == 1);
TORCH_CHECK(dv_stride2 == 1);
TORCH_CHECK(do_stride2 == 1);
StrideQ stride_Q = make_stride(q_stride0, _1{}, make_stride(q_stride1, B == 1 ? 0 : q_stride0*Q));
StrideK stride_K = make_stride(k_stride0, _1{}, make_stride(k_stride1, B == 1 ? 0 : k_stride0*K));
StrideV stride_V = make_stride(v_stride0, _1{}, make_stride(v_stride1, B == 1 ? 0 : v_stride0*K));
StrideO stride_O = make_stride(o_stride0, _1{}, make_stride(o_stride1, B == 1 ? 0 : o_stride0*Q));
StrideLSE stride_LSE = make_stride(_1{}, make_stride(lse_stride1, B == 1 ? 0 : Q));
StrideDQ stride_dQ = make_stride(dq_stride0, _1{}, make_stride(dq_stride1, B == 1 ? 0 : dq_stride0*Q));
StrideDK stride_dK = make_stride(dk_stride0, _1{}, make_stride(dk_stride1, B == 1 ? 0 : dk_stride0*K));
StrideDV stride_dV = make_stride(dv_stride0, _1{}, make_stride(dv_stride1, B == 1 ? 0 : dv_stride0*K));
StrideDO stride_dO = make_stride(do_stride0, _1{}, make_stride(do_stride1, B == 1 ? 0 : do_stride0*Q));
typename Operation::Arguments arguments{
problem_shape,
(static_cast<Element*>(q.data_ptr())), stride_Q,
(static_cast<Element*>(k.data_ptr())), stride_K,
(static_cast<Element*>(v.data_ptr())), stride_V,
(static_cast<Element*>(o.data_ptr())), stride_O,
(static_cast<ElementAccumulator*>(lse.data_ptr())), stride_LSE,
(static_cast<Element*>(d_o.data_ptr())), stride_dO,
(static_cast<Element*>(dq.data_ptr())), stride_dQ,
(static_cast<Element*>(dk.data_ptr())), stride_dK,
(static_cast<Element*>(dv.data_ptr())), stride_dV,
static_cast<ElementAccumulator>(softmax_scale),
hw_info
};
Operation op;
uint8_t* workspace_ptr = static_cast<uint8_t*>(workspace_buffer.data_ptr());
CUTLASS_CHECK(op.can_implement(arguments));
CUTLASS_CHECK(op.initialize(arguments, workspace_ptr));
CUTLASS_CHECK(op.run(at::cuda::getCurrentCUDAStream()));
}
};
template <typename DType, bool kIsVarlen, bool kIsMla, typename TileShape, typename Mask>
void run_fmha_bwd(at::Tensor workspace_buffer, at::Tensor d_o, at::Tensor q, at::Tensor k,
at::Tensor v, at::Tensor o, at::Tensor lse,
at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv,
at::Tensor dq, at::Tensor dk, at::Tensor dv,
float softmax_scale, int max_seqlen_q, int total_seqlen_kv) {
BwdRunner<DType, kIsVarlen, kIsMla, TileShape, Mask>::run(workspace_buffer, d_o, q, k, v, o, lse,
cumulative_seqlen_q, cumulative_seqlen_kv,
dq, dk, dv,
softmax_scale, max_seqlen_q, total_seqlen_kv);
}
#include "interface.h"
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda_bf16.h>
#include "common/mask.cuh"
#include "common/utils.hpp"
#include "fmha_cutlass_fwd_sm100.cuh"
template <class Mask, class Varlen, class Element, class ElementOut, class Mla>
void call_run_fmha_fwd([[maybe_unused]] Mask mask, [[maybe_unused]] Varlen is_varlen,
[[maybe_unused]] Element in, [[maybe_unused]] ElementOut out,
[[maybe_unused]] Mla mla, at::Tensor workspace_buffer, at::Tensor q,
at::Tensor k, at::Tensor v, at::Tensor cumulative_seqlen_q,
at::Tensor cumulative_seqlen_kv, at::Tensor o, at::Tensor lse,
float softmax_scale, int max_seqlen_q, int max_seqlen_kv) {
static constexpr bool IsVarlen = std::is_same_v<Varlen, true_type>;
static constexpr bool IsMla = std::is_same_v<Mla, true_type>;
static constexpr bool IsCausalMask = std::is_same_v<Mask, CausalMask<false>>;
using Option =
std::conditional_t<IsCausalMask || (IsVarlen), Option<Tag::kIsPersistent, false_type>,
Option<Tag::kIsPersistent, true_type>>;
run_fmha_fwd<Element, ElementOut, IsVarlen, IsMla, Mask, Option>(
workspace_buffer, q, k, v, cumulative_seqlen_q, cumulative_seqlen_kv, o, lse,
softmax_scale, max_seqlen_q, max_seqlen_kv);
}
void FMHACutlassSM100FwdRun(at::Tensor workspace_buffer, at::Tensor q, at::Tensor k,
at::Tensor v, at::Tensor cumulative_seqlen_q,
at::Tensor cumulative_seqlen_kv, at::Tensor o, at::Tensor lse,
int mask_mode_code, float sm_scale, int max_seqlen_q,
int max_seqlen_kv, bool is_varlen) {
const c10::cuda::OptionalCUDAGuard device_guard(q.device());
CHECK(q.scalar_type() == k.scalar_type());
auto scalar_type_in = q.scalar_type();
auto scalar_type_out = o.scalar_type();
int head_dim_qk = q.size(-1);
int head_dim_vo = v.size(-1);
MaskMode mask_mode = static_cast<MaskMode>(mask_mode_code);
if (scalar_type_in == at::ScalarType::BFloat16 &&
scalar_type_out == at::ScalarType::BFloat16) {
using Element = cutlass::bfloat16_t;
using ElementOut = cutlass::bfloat16_t;
auto apply_config = [&](auto fn) {
if (mask_mode == MaskMode::kCausal) {
if (is_varlen) {
fn(CausalMask<false>{}, cute::true_type{}, Element{}, ElementOut{});
} else {
fn(CausalMask<false>{}, cute::false_type{}, Element{}, ElementOut{});
}
} else {
if (is_varlen) {
fn(ResidualMask{}, cute::true_type{}, Element{}, ElementOut{});
} else {
fn(ResidualMask{}, cute::false_type{}, Element{}, ElementOut{});
}
}
};
apply_config([&](auto mask, auto varlen, auto in, auto out) {
if (head_dim_qk == 192 && head_dim_vo == 128) {
call_run_fmha_fwd(mask, varlen, in, out, true_type{}, workspace_buffer, q, k, v,
cumulative_seqlen_q, cumulative_seqlen_kv, o, lse, sm_scale,
max_seqlen_q, max_seqlen_kv);
} else if (head_dim_qk == 128 && head_dim_vo == 128) {
call_run_fmha_fwd(mask, varlen, in, out, false_type{}, workspace_buffer, q, k, v,
cumulative_seqlen_q, cumulative_seqlen_kv, o, lse, sm_scale,
max_seqlen_q, max_seqlen_kv);
} else {
std::cout << "No kernel instantiated for head_dim_qk=" << head_dim_qk
<< " head_dim_vo=" << head_dim_vo << std::endl;
}
});
} else {
FLASH_MLA_ASSERT(false);
}
}
#pragma once
#include "collective/fmha_fusion.hpp"
#include "collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp"
#include "collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp"
#include "collective/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/kernel_hardware_info.h"
#include "device/fmha.hpp"
#include "kernel/fmha_causal_tile_scheduler.hpp"
#include "kernel/fmha_options.hpp"
#include "kernel/fmha_tile_scheduler.hpp"
#include "kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp"
#include <torch/library.h>
#include <c10/cuda/CUDAStream.h>
using namespace cute;
using namespace cutlass::fmha::collective;
using namespace cutlass::fmha::kernel;
using namespace cutlass::fmha::device;
struct FmhaOptions {
int b = 1;
int h = 1;
int h_k = 1;
int q = 256;
int k = 256;
int d = 128;
};
struct MlaOptions {
int b = 1;
int h = 1;
int h_k = 1;
int q = 256;
int k = 256;
int dl = 128; // headdim latent
int dr = 64; // headdim rope
};
template <bool kIsMla, bool kIsMaskTileSchedulerValid, bool kIsVarlen, class Element_,
class ElementOut_, class ActiveMask, class... KernelOptions>
struct FwdRunner {
using Element = Element_;
using ElementAccumulatorQK = float;
using ElementAccumulatorPV = float;
using ElementOut = ElementOut_;
using HeadDimLatent = _128;
using HeadDim = Shape<HeadDimLatent, _64>;
using TileShapeMla = Shape<_256, _128, HeadDim>;
using TileShapeFmha = Shape<_256, _128, _128>;
using TileShape = std::conditional_t<kIsMla, TileShapeMla, TileShapeFmha>;
using ProblemShapeRegular = std::conditional_t<
kIsMla,
cute::tuple<int, int, cute::tuple<int, int>, cute::tuple<cute::tuple<int, int>, int>>,
cute::tuple<int, int, int, cute::tuple<cute::tuple<int, int>, int>>>;
using ProblemShapeVarlen =
std::conditional_t<kIsMla,
cute::tuple<VariableLength, VariableLength, cute::tuple<int, int>,
cute::tuple<cute::tuple<int, int>, int>>,
cute::tuple<VariableLength, VariableLength, int,
cute::tuple<cute::tuple<int, int>, int>>>;
using ProblemShapeType =
std::conditional_t<kIsVarlen, ProblemShapeVarlen, ProblemShapeRegular>;
using StrideQ = cute::tuple<int, _1, cute::tuple<cute::tuple<int, int>, int>>;
using StrideK = cute::tuple<int, _1, cute::tuple<cute::tuple<_0, int>, int>>;
using StrideV = StrideK;
using StrideO = StrideQ;
using StrideLSE = cute::tuple<_1, cute::tuple<cute::tuple<int, int>, int>>;
static constexpr bool kIsPersistent =
find_option_t<Tag::kIsPersistent, true_type, KernelOptions...>::value;
using TileScheduler = std::conditional_t<
kIsPersistent,
std::conditional_t<std::is_same_v<ActiveMask, CausalMask<false>> ||
std::is_same_v<ActiveMask, CausalMask<true>>,
cutlass::fmha::kernel::CausalPersistentTileScheduler,
cutlass::fmha::kernel::PersistentTileScheduler>,
std::conditional_t<kIsMaskTileSchedulerValid,
cutlass::fmha::kernel::CausalIndividualTileScheduler,
cutlass::fmha::kernel::IndividualTileScheduler>>;
static constexpr bool IsOrderLoadEpilogue =
kIsPersistent && (sizeof(Element) == sizeof(ElementOut));
using OrderLoadEpilogue = std::conditional_t<IsOrderLoadEpilogue, true_type, false_type>;
using MainloopMla = cutlass::fmha::collective::Sm100MlaFwdMainloopTmaWarpspecialized<
Element, ElementAccumulatorQK, ElementAccumulatorPV, TileShapeMla, StrideQ, StrideK,
StrideV, ActiveMask, Shape<_2, _1, _1>, OrderLoadEpilogue>;
using OperationMla =
cutlass::fmha::device::FMHA<cutlass::fmha::kernel::Sm100FmhaFwdKernelTmaWarpspecialized<
ProblemShapeType, MainloopMla,
cutlass::fmha::collective::Sm100FmhaFwdEpilogueTmaWarpspecialized<
ElementOut, ElementAccumulatorPV, typename MainloopMla::TileShapePV, StrideO,
StrideLSE, OrderLoadEpilogue>,
TileScheduler, cutlass::fmha::kernel::Sm100MlaFwdCtxKernelWarpspecializedSchedule>>;
using MainloopFmha = cutlass::fmha::collective::Sm100FmhaFwdMainloopTmaWarpspecialized<
Element, ElementAccumulatorQK, ElementAccumulatorPV, TileShapeFmha, StrideQ, StrideK,
StrideV, ActiveMask>;
using OperationFmha =
cutlass::fmha::device::FMHA<cutlass::fmha::kernel::Sm100FmhaFwdKernelTmaWarpspecialized<
ProblemShapeType, MainloopFmha,
cutlass::fmha::collective::Sm100FmhaFwdEpilogueTmaWarpspecialized<
ElementOut, ElementAccumulatorPV, typename MainloopFmha::TileShapePV, StrideO,
StrideLSE>,
TileScheduler>>;
using Mainloop = std::conditional_t<kIsMla, MainloopMla, MainloopFmha>;
using Operation = std::conditional_t<kIsMla, OperationMla, OperationFmha>;
//
// Data members
//
/// Initialization
StrideQ stride_Q;
StrideK stride_K;
StrideV stride_V;
StrideO stride_O;
StrideLSE stride_LSE;
template <class ProblemShape>
auto initialize_varlen(const ProblemShape &problem_size, int max_seqlen_q, int max_seqlen_kv,
int total_seqlen_q, int total_seqlen_kv) {
int num_batches = get<3, 1>(problem_size);
ProblemShape problem_size_for_init = problem_size;
get<3, 1>(problem_size_for_init) = 1;
get<0>(problem_size_for_init) = total_seqlen_q;
get<1>(problem_size_for_init) = total_seqlen_kv;
ProblemShapeType problem_size_for_launch;
get<0>(problem_size_for_launch) = VariableLength{max_seqlen_q, nullptr, total_seqlen_q};
get<1>(problem_size_for_launch) = VariableLength{max_seqlen_kv, nullptr, total_seqlen_kv};
get<2>(problem_size_for_launch) = get<2>(problem_size);
get<3>(problem_size_for_launch) = get<3>(problem_size);
return cute::make_tuple(problem_size_for_init, problem_size_for_launch);
}
template <class Options>
static constexpr auto get_problem_shape(const Options &options) {
int h_r = options.h / options.h_k;
if constexpr (std::is_same_v<Options, MlaOptions>) {
return cute::make_tuple(options.q, options.k, cute::make_tuple(options.dl, options.dr),
cute::make_tuple(cute::make_tuple(h_r, options.h_k), options.b));
} else {
return cute::make_tuple(options.q, options.k, options.d,
cute::make_tuple(cute::make_tuple(h_r, options.h_k), options.b));
}
}
template <class Options>
ProblemShapeType initialize(const Options &options, int max_seqlen_q, int max_seqlen_kv,
int total_seqlen_q, int total_seqlen_kv,
void *cumulative_length_q, void *cumulative_length_kv) {
assert(options.h % options.h_k == 0);
auto problem_shape_in = get_problem_shape(options);
ProblemShapeType problem_shape;
decltype(problem_shape_in) problem_size;
if constexpr (kIsVarlen) {
auto [problem_shape_init, problem_shape_launch] = initialize_varlen(
problem_shape_in, max_seqlen_q, max_seqlen_kv, total_seqlen_q, total_seqlen_kv);
problem_shape = problem_shape_launch;
problem_size = problem_shape_init;
} else {
problem_size = problem_shape_in;
problem_shape = problem_shape_in;
}
auto get_head_dimension = [&]() {
if constexpr (rank_v<decltype(get<2>(problem_shape))> == 2) {
return cute::make_tuple(size<2, 0>(problem_shape) + size<2, 1>(problem_shape),
size<2, 0>(problem_shape));
} else {
return cute::make_tuple(size<2>(problem_size), size<2>(problem_size));
}
};
if constexpr (kIsVarlen) {
get<0>(problem_shape).cumulative_length = static_cast<int *>(cumulative_length_q);
get<1>(problem_shape).cumulative_length = static_cast<int *>(cumulative_length_kv);
}
return problem_shape;
}
auto get_arguments(const ProblemShapeType &problem_shape,
const cutlass::KernelHardwareInfo &hw_info, float scale_softmax,
void *q_ptr, void *k_ptr, void *v_ptr, void *o_ptr, void *lse_ptr,
void *cumulative_length_q, void *cumulative_length_kv) {
auto problem_shape_ = problem_shape;
typename Operation::Arguments arguments{
problem_shape_,
{static_cast<Element *>(q_ptr), stride_Q, static_cast<Element *>(k_ptr), stride_K,
static_cast<Element *>(v_ptr), stride_V, scale_softmax},
{static_cast<ElementOut *>(o_ptr), stride_O,
static_cast<ElementAccumulatorPV *>(lse_ptr), stride_LSE},
hw_info};
return arguments;
}
template <class Options>
void run(const Options &options, const cutlass::KernelHardwareInfo &hw_info, at::Tensor q,
at::Tensor k, at::Tensor v, at::Tensor o, at::Tensor lse, float scale_softmax,
at::Tensor workspace, at::Tensor cumulative_seqlen_q,
at::Tensor cumulative_seqlen_kv, int max_seqlen_q, int max_seqlen_kv) {
int total_seqlen_q = q.size(0);
int total_seqlen_kv = k.size(0);
ProblemShapeType problem_shape =
initialize(options, max_seqlen_q, max_seqlen_kv, total_seqlen_q, total_seqlen_kv,
cumulative_seqlen_q.data_ptr(), cumulative_seqlen_kv.data_ptr());
int SQ = size<0>(problem_shape);
int SK = size<1>(problem_shape);
int B = size<3, 1>(problem_shape);
int H = size<3, 0>(problem_shape);
int H_K = size<3, 0, 1>(problem_shape);
int H_Q = size<3, 0, 0>(problem_shape);
int q_stride0 = q.stride(0), q_stride1 = q.stride(1), q_stride2 = q.stride(2);
int k_stride0 = k.stride(0), k_stride1 = k.stride(1), k_stride2 = k.stride(2);
int v_stride0 = v.stride(0), v_stride1 = v.stride(1), v_stride2 = v.stride(2);
int o_stride0 = o.stride(0), o_stride1 = o.stride(1), o_stride2 = o.stride(2);
int lse_stride0 = lse.stride(0), lse_stride1 = lse.stride(1);
TORCH_CHECK(q_stride2 == 1);
TORCH_CHECK(k_stride2 == 1);
TORCH_CHECK(v_stride2 == 1);
TORCH_CHECK(o_stride2 == 1);
TORCH_CHECK(lse_stride0 == 1);
stride_Q = make_stride(q_stride0, _1{}, make_stride(make_stride(q_stride1, H_Q * q_stride1), SQ * q_stride0));
stride_O = make_stride(o_stride0, _1{}, make_stride(make_stride(o_stride1, H_Q * o_stride1), SQ * o_stride0));
stride_K = make_stride(k_stride0, _1{}, make_stride(make_stride(_0{}, k_stride1), SK * k_stride0));
stride_V = make_stride(v_stride0, _1{}, make_stride(make_stride(_0{}, v_stride1), SK * v_stride0));
stride_LSE = make_stride(_1{}, make_stride(make_stride(lse_stride1, lse_stride1 * H_Q), SQ));
if constexpr (kIsVarlen) {
get<2, 1>(stride_Q) = 0;
get<2, 1>(stride_K) = 0;
get<2, 1>(stride_V) = 0;
get<2, 1>(stride_O) = 0;
get<1, 1>(stride_LSE) = 0;
}
typename Operation::Arguments arguments =
get_arguments(problem_shape, hw_info, scale_softmax, q.data_ptr(), k.data_ptr(),
v.data_ptr(), o.data_ptr(), lse.data_ptr(),
cumulative_seqlen_q.data_ptr(), cumulative_seqlen_kv.data_ptr());
Operation op;
// size_t workspace_size = 0;
// workspace_size = Operation::get_workspace_size(arguments);
// todo: if use workspace, need check workspace size first.
// we don't use workspace in current version.
CUTLASS_CHECK(op.can_implement(arguments));
CUTLASS_CHECK(op.initialize(arguments, nullptr));
CUTLASS_CHECK(op.run(at::cuda::getCurrentCUDAStream()));
}
};
template <class DTypeIn, class DTypeOut, bool kIsVarlen, bool kIsMla, class ActiveMask,
class... KernelOptions>
void run_fmha_fwd(at::Tensor workspace, at::Tensor q, at::Tensor k, at::Tensor v,
at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv, at::Tensor o,
at::Tensor lse, float scale_softmax, int max_seqlen_q, int max_seqlen_kv) {
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = 0;
hw_info.sm_count =
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
auto get_options = [&]() {
if constexpr (kIsMla) {
MlaOptions options;
options.b = cumulative_seqlen_q.size(0) - 1;
options.h = q.size(1);
options.h_k = k.size(1);
options.q = q.size(0) / options.b;
options.k = k.size(0) / options.b;
options.dl = v.size(-1);
options.dr = q.size(-1) - v.size(-1);
return options;
} else {
FmhaOptions options;
options.b = cumulative_seqlen_q.size(0) - 1;
options.h = q.size(1);
options.h_k = k.size(1);
options.q = q.size(0) / options.b;
options.k = k.size(0) / options.b;
options.d = q.size(-1);
return options;
}
};
auto options = get_options();
if (options.h % cutlass::fmha::kernel::CausalIndividualTileScheduler::TileH == 0 &&
(std::is_same_v<ActiveMask, CausalMask<false>> || std::is_same_v<ActiveMask, CausalMask<true>>)) {
FwdRunner<kIsMla, true, kIsVarlen, DTypeIn, DTypeOut, ActiveMask, KernelOptions...> runner;
runner.run(options, hw_info, q, k, v, o, lse, scale_softmax, workspace, cumulative_seqlen_q,
cumulative_seqlen_kv, max_seqlen_q, max_seqlen_kv);
} else {
FwdRunner<kIsMla, false, kIsVarlen, DTypeIn, DTypeOut, ActiveMask, KernelOptions...> runner;
runner.run(options, hw_info, q, k, v, o, lse, scale_softmax, workspace, cumulative_seqlen_q,
cumulative_seqlen_kv, max_seqlen_q, max_seqlen_kv);
}
}
#pragma once
#include <ATen/Tensor.h>
void FMHACutlassSM100FwdRun(at::Tensor workspace_buffer, at::Tensor q, at::Tensor k, at::Tensor v,
at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv,
at::Tensor o, at::Tensor lse,
int mask_mode_code, float softmax_scale, int max_seqlen_q, int max_seqlen_kv, bool is_varlen);
void FMHACutlassSM100BwdRun(at::Tensor workspace_buffer, at::Tensor d_o, at::Tensor q, at::Tensor k,
at::Tensor v, at::Tensor o, at::Tensor lse,
at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv,
at::Tensor dq, at::Tensor dk, at::Tensor dv,
int mask_mode_code, float softmax_scale, int max_seqlen_q, int max_seqlen_kv, bool is_varlen);
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
namespace cutlass::fmha::kernel {
////////////////////////////////////////////////////////////////////////////////
// Swizzle Q tile and H tile to improve L2 cache hit rate,
// and launch the longest main loop first to keep most SMs busy.
struct CausalIndividualTileScheduler {
static constexpr int TileQ = 16;
static constexpr int TileH = 8;
static constexpr int TileSize = TileQ * TileH;
struct Params {
dim3 grid;
int tile_max_q;
FastDivmod divmod_tile_col;
FastDivmod divmod_tile_size;
FastDivmod divmod_tile_head;
};
bool valid_ = true;
Params params;
CUTLASS_DEVICE
CausalIndividualTileScheduler(Params const& params) : params(params) {}
template<class ProblemSize, class ClusterShape, class TileShape>
static Params to_underlying_arguments(
ProblemSize const& problem_size, KernelHardwareInfo hw_info,
ClusterShape const& cluster_shape, TileShape const& tile_shape) {
using namespace cute;
dim3 grid(size<3,0>(problem_size), round_up(ceil_div(size<0>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape)), size<3,1>(problem_size));
// gridDim.x must multiple of TileH
const int tile_col_count = grid.x / TileH;
const int tile_max_q = grid.y / TileQ * TileQ;
return Params{ grid , tile_max_q, tile_col_count, TileSize, TileH};
}
static dim3 get_grid_shape(Params const& params) {
return params.grid;
}
CUTLASS_DEVICE
bool is_valid() {
return valid_;
}
CUTLASS_DEVICE
auto get_block_coord() {
using namespace cute;
const int block_idx = blockIdx.y * gridDim.x + blockIdx.x;
int tile_idx, tile_tail;
params.divmod_tile_size(tile_idx, tile_tail, block_idx);
int tile_row_idx, tile_col_idx;
params.divmod_tile_col(tile_row_idx,tile_col_idx, tile_idx);
int row_offset_in_tail, col_offset_in_tail;
params.divmod_tile_head(row_offset_in_tail,col_offset_in_tail, tile_tail);
const int row_idx = tile_row_idx * TileQ + row_offset_in_tail;
const int col_idx = tile_col_idx * TileH + col_offset_in_tail;
// last q tile launch first
if(blockIdx.y >= params.tile_max_q) {
return make_coord(int(gridDim.y - 1 - blockIdx.y), _0{}, make_coord(int(blockIdx.x), int(blockIdx.z)));
}
return make_coord(int(gridDim.y) - 1 - row_idx, _0{}, make_coord(col_idx, int(blockIdx.z)));
}
CUTLASS_DEVICE
CausalIndividualTileScheduler& operator++() {
valid_ = false;
return *this;
}
};
////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
// Launch order: H Q B
struct CausalPersistentTileScheduler {
struct Params {
int num_blocks;
FastDivmod divmod_h;
FastDivmod divmod_m_block;
FastDivmod divmod_b;
KernelHardwareInfo hw_info;
};
int block_idx = 0;
Params params;
CUTLASS_DEVICE
CausalPersistentTileScheduler(Params const& params) : block_idx(blockIdx.x), params(params) {}
template<class ProblemSize, class ClusterShape, class TileShape>
static Params to_underlying_arguments(
ProblemSize const& problem_size, KernelHardwareInfo hw_info,
ClusterShape const& cluster_shape, TileShape const& tile_shape) {
using namespace cute;
// Get SM count if needed, otherwise use user supplied SM count
int sm_count = hw_info.sm_count;
if (sm_count <= 0) {
CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
" For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
}
CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
hw_info.sm_count = sm_count;
int num_m_blocks = cutlass::round_up(ceil_div(size<0>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape));
int num_blocks = num_m_blocks * size<3,0>(problem_size) * size<3,1>(problem_size);
return Params {
num_blocks,
{ size<3,0>(problem_size) }, { max(1, num_m_blocks) }, { size<3,1>(problem_size) },
hw_info
};
}
static dim3 get_grid_shape(Params const& params) {
dim3 grid(std::min(params.num_blocks, params.hw_info.sm_count), 1, 1);
return grid;
}
CUTLASS_DEVICE
bool is_valid() {
return block_idx < params.num_blocks;
}
CUTLASS_DEVICE
auto get_block_coord() {
using namespace cute;
int block_decode = block_idx;
int m_block, bidb, bidh;
params.divmod_h(block_decode, bidh, block_decode);
params.divmod_m_block(block_decode, m_block, block_decode);
params.divmod_b(block_decode, bidb, block_decode);
return make_coord(m_block, _0{}, make_coord(bidh, bidb));
}
CUTLASS_DEVICE
CausalPersistentTileScheduler& operator++() {
block_idx += gridDim.x;
return *this;
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::fmha::kernel
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cute/layout.hpp"
#include <kerutils/kerutils.cuh> // for KERUTILS_ENABLE_SM100A
namespace cutlass::fmha::kernel {
using namespace cute;
template<class ProblemShape, class Element, class ElementAcc>
struct FmhaKernelBwdConvert {
struct Arguments {
ProblemShape problem_shape;
const ElementAcc* ptr_src_dQ;
tuple<int, _1, tuple<int, int>> stride_src_dQ;
const ElementAcc* ptr_src_dK;
tuple<int, _1, tuple<int, int>> stride_src_dK;
const ElementAcc* ptr_src_dV;
tuple<int, _1, tuple<int, int>> stride_src_dV;
Element* ptr_dest_dQ;
tuple<int, _1, tuple<int, int>> stride_dest_dQ;
Element* ptr_dest_dK;
tuple<int, _1, tuple<int, int>> stride_dest_dK;
Element* ptr_dest_dV;
tuple<int, _1, tuple<int, int>> stride_dest_dV;
ElementAcc scale = 1.0;
};
using Params = Arguments;
using ClusterShape = Shape<_1, _1, _1>;
static constexpr int SharedStorageSize = 0;
static const int MinBlocksPerMultiprocessor = 1;
static const int MaxThreadsPerBlock = 128;
using ArchTag = cutlass::arch::Sm90;
static const int kBlockSeq = 8;
static size_t get_workspace_size(Arguments const& args) { return 0; }
static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) {
return cutlass::Status::kSuccess;
}
static const int kNumThreadsD = 16;
static const int kNumThreadsSeq = MaxThreadsPerBlock / kNumThreadsD;
static const int kElementsPerLoad = 4;
static const int kIterationsSeq = kBlockSeq / kNumThreadsSeq;
static bool can_implement(Arguments const& args) {
return get<2>(args.problem_shape) % kElementsPerLoad == 0 && get<3>(args.problem_shape) % kElementsPerLoad == 0;
}
static dim3 get_grid_shape(Params const& params) {
dim3 grid(size<4,0>(params.problem_shape), size<4,1>(params.problem_shape), ceil_div(std::max(size<0>(params.problem_shape), size<1>(params.problem_shape)), kBlockSeq));
return grid;
}
static dim3 get_block_shape() {
dim3 block(kNumThreadsD, kNumThreadsSeq, 1);
return block;
}
static Params to_underlying_arguments(Arguments const& args, void* workspace) {
return args;
}
template<class StrideSrc, class StrideDest, class Count>
CUTLASS_DEVICE void copy(Params const& params, const ElementAcc* ptr_src, StrideSrc const& stride_src, Element* ptr_dest, StrideDest const& stride_dest, Count const& count, int d_dim) {
auto ptr_src_bh = ptr_src + get<2,0>(stride_src) * blockIdx.x + get<2,1>(stride_src) * blockIdx.y;
auto ptr_dest_bh = ptr_dest + get<2,0>(stride_dest) * blockIdx.x + get<2,1>(stride_dest) * blockIdx.y;
int seqlen = count;
if constexpr (is_variable_length_v<decltype(count)>) {
int offset = count.cumulative_length[blockIdx.y];
ptr_dest_bh += offset * get<0>(stride_dest);
seqlen = count.cumulative_length[blockIdx.y + 1] - offset;
}
for (int idx_s_t = threadIdx.y; idx_s_t < kBlockSeq; idx_s_t += kNumThreadsSeq) {
int idx_s = idx_s_t + kBlockSeq * blockIdx.z;
if (idx_s >= seqlen) continue;
auto ptr_src_bhs = ptr_src_bh + idx_s * get<0>(stride_src);
auto ptr_dest_bhs = ptr_dest_bh + idx_s * get<0>(stride_dest);
for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < d_dim; idx_d += kElementsPerLoad * kNumThreadsD) {
ElementAcc value_src[kElementsPerLoad];
Element value_dest[kElementsPerLoad];
using VecSrc = uint_bit_t<sizeof_bits_v<ElementAcc> * kElementsPerLoad>;
using VecDest = uint_bit_t<sizeof_bits_v<Element> * kElementsPerLoad>;
*reinterpret_cast<VecSrc*>(value_src) = *reinterpret_cast<const VecSrc*>(&ptr_src_bhs[idx_d]);
for (int v = 0; v < kElementsPerLoad; v++) {
value_dest[v] = static_cast<Element>(params.scale * value_src[v]);
}
*reinterpret_cast<VecDest*>(&ptr_dest_bhs[idx_d]) = *reinterpret_cast<const VecDest*>(value_dest);
}
}
}
CUTLASS_DEVICE void operator()(const Params &params, char* smem) {
#if defined(KERUTILS_ENABLE_SM100A)
if (params.ptr_src_dQ != nullptr) {
copy(params, params.ptr_src_dQ, params.stride_src_dQ, params.ptr_dest_dQ, params.stride_dest_dQ, get<0>(params.problem_shape), get<2>(params.problem_shape));
}
if (params.ptr_src_dK != nullptr) {
copy(params, params.ptr_src_dK, params.stride_src_dK, params.ptr_dest_dK, params.stride_dest_dK, get<1>(params.problem_shape), get<2>(params.problem_shape));
}
if (params.ptr_src_dV != nullptr) {
copy(params, params.ptr_src_dV, params.stride_src_dV, params.ptr_dest_dV, params.stride_dest_dV, get<1>(params.problem_shape), get<3>(params.problem_shape));
}
#else
if (cute::thread0()) {
CUTE_INVALID_CONTROL_PATH("This kernel only supports sm100\n");
}
#endif
}
};
} // namespace cutlass::fmha::kernel
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cute/layout.hpp"
#include <kerutils/kerutils.cuh> // for KERUTILS_ENABLE_SM100A
namespace cutlass::fmha::kernel {
using namespace cute;
template<class ProblemShape, class Element, class ElementAcc>
struct FmhaKernelBwdSumOdO {
struct Arguments {
ProblemShape problem_shape;
const Element* ptr_O;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_O;
const Element* ptr_dO;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_dO;
ElementAcc* ptr_sum_OdO;
cute::tuple<cute::_1, cute::tuple<int, int>> stride_sum_OdO;
const ElementAcc* ptr_lse = nullptr;
cute::tuple<cute::_1, cute::tuple<int, int>> stride_lse;
ElementAcc* ptr_scaled_lse = nullptr;
cute::tuple<cute::_1, cute::tuple<int, int>> stride_scaled_lse;
ElementAcc sum_odo_scale = 1.0;
ElementAcc lse_scale = 1.0;
};
using Params = Arguments;
using ClusterShape = Shape<_1, _1, _1>;
static constexpr int SharedStorageSize = 0;
static const int MinBlocksPerMultiprocessor = 1;
static const int MaxThreadsPerBlock = 128;
using ArchTag = cutlass::arch::Sm100;
static size_t get_workspace_size(Arguments const& args) { return 0; }
static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) {
return cutlass::Status::kSuccess;
}
static const int kBlockQ = 16;
static const int kNumThreadsD = 8;
static const int kNumThreadsQ = MaxThreadsPerBlock / kNumThreadsD;
static const int kElementsPerLoad = 2;
static const int kIterationsQ = kBlockQ / kNumThreadsQ;
static bool can_implement(Arguments const& args) {
return get<2>(args.problem_shape) % kElementsPerLoad == 0 && get<3>(args.problem_shape) % kElementsPerLoad == 0;
}
static dim3 get_grid_shape(Params const& params) {
dim3 grid(ceil_div(size<0>(params.problem_shape), kBlockQ), size<4,0>(params.problem_shape), size<4,1>(params.problem_shape));
return grid;
}
static dim3 get_block_shape() {
dim3 block(kNumThreadsD, kNumThreadsQ, 1);
return block;
}
static Params to_underlying_arguments(Arguments const& args, void* workspace) {
return args;
}
CUTLASS_DEVICE void operator()(const Params &params, char* smem) {
#if defined(KERUTILS_ENABLE_SM100A)
auto ptr_O_bh = params.ptr_O + blockIdx.y * get<2,0>(params.stride_O) + blockIdx.z * get<2,1>(params.stride_O);
auto ptr_dO_bh = params.ptr_dO + blockIdx.y * get<2,0>(params.stride_dO) + blockIdx.z * get<2,1>(params.stride_dO);
auto ptr_sum_OdO_bh = params.ptr_sum_OdO + blockIdx.y * get<1,0>(params.stride_sum_OdO) + blockIdx.z * get<1,1>(params.stride_sum_OdO);
auto ptr_lse_bh = params.ptr_lse + blockIdx.y * get<1,0>(params.stride_lse) + blockIdx.z * get<1,1>(params.stride_lse);
auto ptr_scaled_lse_bh = params.ptr_scaled_lse + blockIdx.y * get<1,0>(params.stride_scaled_lse) + blockIdx.z * get<1,1>(params.stride_scaled_lse);
auto problem_q = get<0>(params.problem_shape);
int seqlen_q = problem_q;
if constexpr (is_variable_length_v<decltype(problem_q)>) {
int offset = problem_q.cumulative_length[blockIdx.z];
ptr_O_bh += offset * get<0>(params.stride_O);
ptr_dO_bh += offset * get<0>(params.stride_dO);
ptr_lse_bh += offset * get<0>(params.stride_lse);
seqlen_q = problem_q.cumulative_length[blockIdx.z + 1] - offset;
}
CUTLASS_PRAGMA_UNROLL
for (int idx_q_t = threadIdx.y; idx_q_t < kBlockQ; idx_q_t += kNumThreadsQ) {
int idx_q = idx_q_t + kBlockQ * blockIdx.x;
if (idx_q >= seqlen_q) continue;
ElementAcc acc = 0;
auto ptr_O_bhq = ptr_O_bh + idx_q * get<0>(params.stride_O);
auto ptr_dO_bhq = ptr_dO_bh + idx_q * get<0>(params.stride_dO);
auto ptr_sum_OdO_bhq = ptr_sum_OdO_bh + idx_q * get<0>(params.stride_sum_OdO);
auto ptr_lse_bhq = ptr_lse_bh + idx_q * get<0>(params.stride_lse);
auto ptr_scaled_lse_bhq = ptr_scaled_lse_bh + idx_q * get<0>(params.stride_scaled_lse);
for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<3>(params.problem_shape); idx_d += kElementsPerLoad * kNumThreadsD) {
Element value_O[kElementsPerLoad];
Element value_dO[kElementsPerLoad];
using Vec = uint_bit_t<sizeof_bits_v<Element> * kElementsPerLoad>;
*reinterpret_cast<Vec*>(value_O) = *reinterpret_cast<const Vec*>(&ptr_O_bhq[idx_d]);
*reinterpret_cast<Vec*>(value_dO) = *reinterpret_cast<const Vec*>(&ptr_dO_bhq[idx_d]);
for (int v = 0; v < kElementsPerLoad; v++) {
acc += ElementAcc(value_O[v]) * ElementAcc(value_dO[v]);
}
}
for (int i = 1; i < kNumThreadsD; i *= 2) {
acc += __shfl_xor_sync((uint32_t)-1, acc, i, kNumThreadsD);
}
if (threadIdx.x == 0) {
*ptr_sum_OdO_bhq = params.sum_odo_scale * acc;
if (params.ptr_scaled_lse) {
*ptr_scaled_lse_bhq = params.lse_scale * *ptr_lse_bhq;
}
}
}
#else
if (cute::thread0()) {
CUTE_INVALID_CONTROL_PATH("This kernel only supports sm100\n");
}
#endif
}
};
} // namespace cutlass::fmha::kernel
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
namespace cutlass::fmha::kernel {
template<auto kTag, typename Default, typename... Options>
struct find_option;
template<auto kTag, typename Default>
struct find_option<kTag, Default> {
using option_value = Default;
};
template<auto kTag, typename Default, typename Option, typename... Options>
struct find_option<kTag, Default, Option, Options...> :
std::conditional_t<
Option::tag == kTag,
Option,
find_option<kTag, Default, Options...>
>
{};
template<auto kTag, typename Default, typename... Options>
using find_option_t = typename find_option<kTag, Default, Options...>::option_value;
enum class Tag {
kIsPersistent,
kNumMmaWarpGroups,
kLoadsQSeparately,
kIsMainloopLocked,
kIsEpilogueLocked,
kStagesQ,
kStagesKV,
kEpilogueKind,
kBlocksPerSM,
kClusterM,
kAccQK
};
template<auto kTag, class Value>
struct Option {
static constexpr auto tag = kTag;
using option_value = Value;
};
} // namespace cutlass::fmha::kernel
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/kernel_hardware_info.h"
namespace cutlass::fmha::kernel {
////////////////////////////////////////////////////////////////////////////////
struct IndividualTileScheduler {
struct Params {
dim3 grid;
};
bool valid_ = true;
CUTLASS_DEVICE
IndividualTileScheduler(Params const&) {}
template<class ProblemSize, class ClusterShape, class TileShape>
static Params to_underlying_arguments(
ProblemSize const& problem_size, KernelHardwareInfo hw_info,
ClusterShape const& cluster_shape, TileShape const& tile_shape) {
using namespace cute;
dim3 grid(round_up(ceil_div(size<0>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape)), size<3,0>(problem_size), size<3,1>(problem_size));
return Params{ grid };
}
static dim3 get_grid_shape(Params const& params) {
return params.grid;
}
CUTLASS_DEVICE
bool is_valid() {
return valid_;
}
CUTLASS_DEVICE
auto get_block_coord() {
using namespace cute;
return make_coord(blockIdx.x, _0{}, make_coord(blockIdx.y, blockIdx.z));
}
CUTLASS_DEVICE
IndividualTileScheduler& operator++() {
valid_ = false;
return *this;
}
};
////////////////////////////////////////////////////////////////////////////////
struct PersistentTileScheduler {
struct Params {
int num_blocks;
FastDivmod divmod_m_block;
FastDivmod divmod_h;
FastDivmod divmod_b;
KernelHardwareInfo hw_info;
};
int block_idx = 0;
Params params;
CUTLASS_DEVICE
PersistentTileScheduler(Params const& params) : block_idx(blockIdx.x), params(params) {}
template<class ProblemSize, class ClusterShape, class TileShape>
static Params to_underlying_arguments(
ProblemSize const& problem_size, KernelHardwareInfo hw_info,
ClusterShape const& cluster_shape, TileShape const& tile_shape) {
using namespace cute;
// Get SM count if needed, otherwise use user supplied SM count
int sm_count = hw_info.sm_count;
if (sm_count <= 0) {
CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
" For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
}
CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
hw_info.sm_count = sm_count;
int num_m_blocks = cutlass::round_up(ceil_div(size<0>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape));
int num_blocks = num_m_blocks * size<3,0>(problem_size) * size<3,1>(problem_size);
return Params {
num_blocks,
{ max(1, num_m_blocks)}, { size<3,0>(problem_size) }, { size<3,1>(problem_size) },
hw_info
};
}
static dim3 get_grid_shape(Params const& params) {
dim3 grid(std::min(params.num_blocks, params.hw_info.sm_count), 1, 1);
return grid;
}
CUTLASS_DEVICE
bool is_valid() {
return block_idx < params.num_blocks;
}
CUTLASS_DEVICE
auto get_block_coord() {
using namespace cute;
int block_decode = block_idx;
int m_block, bidb, bidh;
params.divmod_m_block(block_decode, m_block, block_decode);
params.divmod_b(block_decode, bidb, block_decode);
params.divmod_h(block_decode, bidh, block_decode);
return make_coord(m_block, _0{}, make_coord(bidh, bidb));
}
CUTLASS_DEVICE
PersistentTileScheduler& operator++() {
block_idx += gridDim.x;
return *this;
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::fmha::kernel
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cute/arch/simd_sm100.hpp"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/memory_sm80.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include <kerutils/kerutils.cuh> // for KERUTILS_ENABLE_SM100A
#include "../collective/fmha_common.hpp"
#include <cmath>
namespace cutlass::fmha::kernel {
using namespace cutlass::fmha::collective;
using namespace cute;
template<
class ProblemShape,
class Element,
class ElementAcc,
class TileShape,
class Mask
>
struct Sm100FmhaBwdKernelTmaWarpSpecialized {
using TileShapeQ = decltype(get<0>(TileShape{}));
static_assert(std::is_same_v<TileShapeQ, _128>, "tile shape K must be 128");
using TileShapeK = decltype(get<1>(TileShape{}));
static_assert(std::is_same_v<TileShapeK, _128>, "tile shape K must be 128");
using TileShapeDQK = decltype(get<2>(TileShape{}));
using TileShapeDVO = decltype(get<2>(TileShape{}));
using TmemAllocator = cute::TMEM::Allocator1Sm;
struct TmemAllocation {
static constexpr uint32_t kDK = 0; // TileShapeK x TileShapeDQK x acc
static constexpr uint32_t kDV = kDK + TileShapeDQK{}; // TileShapeK x TileShapeDVO x acc
static constexpr uint32_t kDQ = kDV + TileShapeDVO{}; // TileShapeQ x TileShapeDQK x acc
static constexpr uint32_t kDP = kDQ; // TileShapeK x TileShapeQ x inp
static constexpr uint32_t kS = kDQ + max(TileShapeQ{}, TileShapeDQK{});
static constexpr uint32_t kP = kS;
static constexpr uint32_t kTotal = kS + TileShapeQ{};
};
static_assert(
static_cast<int>(TmemAllocation::kTotal) <= TmemAllocator::Sm100TmemCapacityColumns,
"using too much tmem"
);
enum class WarpRole {
Empty = 0x0, Load = 0x1, Mma = 0x2, Compute = 0x3, Reduce = 0x4
};
static constexpr unsigned long long kWarpAssignment = 0x12'3333'3333'4444ull;
static constexpr int kNumComputeWarps = 8;
static constexpr int kNumReduceWarps = 4;
CUTLASS_DEVICE WarpRole warp_idx_to_role(int warp_idx) {
return static_cast<WarpRole>((kWarpAssignment >> (4 * warp_idx)) & 0xF);
}
struct RegisterAllocation {
static constexpr int kWarpgroup0 = 160-8;
static constexpr int kWarpgroup1 = 128;
static constexpr int kWarpgroup2 = 96;
static constexpr int kReduce = kWarpgroup0;
static constexpr int kCompute = kWarpgroup1;
static constexpr int kMma = kWarpgroup2;
static constexpr int kEmpty = kWarpgroup2;
static constexpr int kLoad = kWarpgroup2;
static_assert(kWarpgroup0 + 2 * kWarpgroup1 + kWarpgroup2 <= 512);
};
using ArchTag = cutlass::arch::Sm100;
using ClusterShape = Shape<_1, _1, _1>;
using Schedule = cutlass::gemm::KernelTmaWarpSpecialized1SmSm100;
static constexpr int MinBlocksPerMultiprocessor = 1;
static constexpr int kNumWarps = kNumComputeWarps + kNumReduceWarps + 4;
static constexpr int MaxThreadsPerBlock = NumThreadsPerWarp * kNumWarps;
static constexpr int Alignment = 128 / sizeof_bits_v<Element>;
static constexpr int kStages = 2;
using TensorStrideContiguousK = Stride<int, _1, Stride<int, int>>;
using TensorStrideContiguousMN = Stride<_1, int, Stride<int, int>>;
// compute S
using CollectiveMmaKQ = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
Element, TensorStrideContiguousK, Alignment,
Element, TensorStrideContiguousK, Alignment,
ElementAcc,
Shape<TileShapeK, TileShapeQ, TileShapeDQK>,
ClusterShape, cutlass::gemm::collective::StageCount<kStages>,
Schedule>::CollectiveOp;
using TileShapeKQ = typename CollectiveMmaKQ::TileShape;
using TiledMmaKQ = typename CollectiveMmaKQ::TiledMma;
// compute dP
using CollectiveMmaVDO = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
Element, TensorStrideContiguousK, Alignment,
Element, TensorStrideContiguousK, Alignment,
ElementAcc,
Shape<TileShapeK, TileShapeQ, TileShapeDVO>,
ClusterShape, cutlass::gemm::collective::StageCount<kStages>,
Schedule>::CollectiveOp;
using TileShapeVDO = typename CollectiveMmaVDO::TileShape;
using TiledMmaVDO = typename CollectiveMmaVDO::TiledMma;
// compute dV
using CollectiveMmaPDO = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
// needs to match ordering of S calculation
Element, TensorStrideContiguousK, Alignment,
Element, TensorStrideContiguousMN, Alignment,
ElementAcc,
Shape<TileShapeK, TileShapeDVO, TileShapeQ>,
ClusterShape, cutlass::gemm::collective::StageCount<kStages>,
Schedule>::CollectiveOp;
using TileShapePDO = typename CollectiveMmaPDO::TileShape;
using TiledMmaPDO = decltype(to_tiled_mma_sm100_ts(typename CollectiveMmaPDO::TiledMma{}));
// compute dK
using CollectiveMmaDSQ = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
// somewhat arbitrary since we dump to smem, need to agree with the next one
Element, TensorStrideContiguousK , Alignment,
Element, TensorStrideContiguousMN, Alignment,
ElementAcc,
Shape<TileShapeK, TileShapeDQK, TileShapeQ>,
ClusterShape, cutlass::gemm::collective::StageCount<kStages>,
Schedule>::CollectiveOp;
using TileShapeDSQ = typename CollectiveMmaDSQ::TileShape;
using TiledMmaDSQ = typename CollectiveMmaDSQ::TiledMma;
// compute dQ
using CollectiveMmaDSK = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
// somewhat arbitrary since we dump to smem, need to agree with the previous one
Element, TensorStrideContiguousMN, Alignment,
Element, TensorStrideContiguousMN, Alignment,
ElementAcc,
Shape<TileShapeQ, TileShapeDQK, TileShapeK>,
ClusterShape, cutlass::gemm::collective::StageCount<kStages>,
Schedule>::CollectiveOp;
using TileShapeDSK = typename CollectiveMmaDSK::TileShape;
using TiledMmaDSK = typename CollectiveMmaDSK::TiledMma;
// pipelines are named Pipeline<Producer><Consumer><Resource>
static constexpr int kStagesComputeSmem = 1;
using PipelineLoadMmaQ = PipelineTmaUmmaAsync<2, ClusterShape>;
using PipelineLoadMmaDO = PipelineTmaUmmaAsync<1, ClusterShape>;
using PipelineLoadComputeLSE = PipelineAsync<1>;
using PipelineLoadComputeSumOdO = PipelineAsync<1>;
using PipelineMmaComputeS = PipelineUmmaAsync<1>;
using PipelineMmaComputeDP = PipelineUmmaAsync<1>;
using PipelineMmaReduceDQ = PipelineUmmaAsync<1>;
using PipelineComputeMmaP = PipelineUmmaConsumerAsync<1>;
using PipelineComputeMmaDS = PipelineUmmaConsumerAsync<kStagesComputeSmem>;
using PipelineMmaComputeDKDV = PipelineUmmaAsync<2>;
static constexpr int kStagesReduceTmaStore = 2;
using PipelineReduceTmaStore = PipelineTmaStore<kStagesReduceTmaStore>;
struct PipelineStorage {
alignas(16) typename PipelineLoadMmaQ::SharedStorage load_mma_q;
alignas(16) typename PipelineLoadMmaDO::SharedStorage load_mma_do;
alignas(16) typename PipelineLoadComputeLSE::SharedStorage load_compute_lse;
alignas(16) typename PipelineLoadComputeSumOdO::SharedStorage load_compute_sum_odo;
alignas(16) typename PipelineMmaComputeS::SharedStorage mma_compute_s;
alignas(16) typename PipelineMmaComputeDP::SharedStorage mma_compute_dp;
alignas(16) typename PipelineMmaReduceDQ::SharedStorage mma_reduce_dq;
alignas(16) typename PipelineComputeMmaP::SharedStorage compute_mma_p;
alignas(16) typename PipelineComputeMmaDS::SharedStorage compute_mma_ds;
alignas(16) typename PipelineMmaComputeDKDV::SharedStorage mma_compute_dkdv;
};
template<class Layout, class Stages = _1>
static CUTE_DEVICE constexpr auto restage(Layout const& layout, Stages stages = {}) {
return composition(layout, make_tuple(_, _, _, make_layout(stages)));
}
using SmemLayoutK = decltype(restage(typename CollectiveMmaKQ::SmemLayoutA{}));
using SmemLayoutV = decltype(restage(typename CollectiveMmaVDO::SmemLayoutA{}));
using SmemLayoutQ = decltype(restage(typename CollectiveMmaKQ::SmemLayoutB{}, _2{}));
using SmemLayoutDO = decltype(restage(typename CollectiveMmaVDO::SmemLayoutB{}, _1{}));
using SmemLayoutDS = decltype(restage(typename CollectiveMmaDSK::SmemLayoutA{}, Int<kStagesComputeSmem>{}));
using SmemLayoutLSE = Layout<Shape<TileShapeQ, _1>>;
using SmemLayoutSumOdO = Layout<Shape<TileShapeQ, _1>>;
using SmemLayoutQT = decltype(restage(typename CollectiveMmaDSQ::SmemLayoutB{}, _2{}));
using SmemLayoutKT = decltype(restage(typename CollectiveMmaDSK::SmemLayoutB{}));
using SmemLayoutDST = decltype(restage(typename CollectiveMmaDSQ::SmemLayoutA{}, Int<kStagesComputeSmem>{}));
using SmemLayoutDOT = decltype(restage(typename CollectiveMmaPDO::SmemLayoutB{}, _1{}));
using TileShapeDQ = _32;
using SmemAtomDQ = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<
cute::UMMA::Major::K, ElementAcc, TileShapeQ, TileShapeDQ
>());
using SmemShapeDQ = Shape<TileShapeQ, TileShapeDQ, Int<kStagesReduceTmaStore>>;
using SmemLayoutDQ = decltype(tile_to_shape(SmemAtomDQ{}, SmemShapeDQ{}, Step<_2, _1, _3>{}));
struct TensorStorage {
union {
alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutK>> smem_k;
alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutKT>> smem_k_t;
};
alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutV>> smem_v;
union {
alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutQT>> smem_q_t;
};
union {
alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutDO>> smem_do;
alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutDOT>> smem_do_t;
};
union {
alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutDS>> smem_ds;
alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutDST>> smem_ds_t;
};
alignas(1024) cute::array<ElementAcc, cute::cosize_v<SmemLayoutDQ>> smem_dq;
alignas(16) cute::array<ElementAcc, cute::cosize_v<SmemLayoutLSE>> smem_lse;
alignas(16) cute::array<ElementAcc, cute::cosize_v<SmemLayoutSumOdO>> smem_sum_odo;
};
static constexpr int kTransactionsBytesLoadQ = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutQ{})) * cute::sizeof_bits_v<Element>);
static constexpr int kTransactionsBytesLoadDO = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutDO{})) * cute::sizeof_bits_v<Element>);
static constexpr int kTransactionsBytesLoadK = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v<Element>);
static constexpr int kTransactionsBytesLoadV = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v<Element>);
struct SharedStorage {
TensorStorage tensors;
PipelineStorage pipelines;
uint32_t tmem_base_ptr;
};
// this is tight enough that it won't work with sizeof due to padding for alignment
static constexpr int SharedStorageSize = offsetof(SharedStorage, tmem_base_ptr) + sizeof(uint32_t);
static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "using too much smem");
using TensorStride = TensorStrideContiguousK; // S D (H B)
using RowTensorStride = Stride<_1, Stride<int, int>>; // S (H B)
struct MainloopArguments {
const Element* ptr_q;
TensorStride stride_q;
const Element* ptr_k;
TensorStride stride_k;
const Element* ptr_v;
TensorStride stride_v;
const Element* ptr_do;
TensorStride stride_do;
const ElementAcc* ptr_lse;
RowTensorStride stride_lse;
const ElementAcc* ptr_sum_odo;
RowTensorStride stride_sum_odo;
ElementAcc* ptr_dq_acc;
TensorStride stride_dq_acc;
ElementAcc softmax_scale = 1.0f / sqrtf(TileShapeDQK{});
};
using TMA_K = typename CollectiveMmaKQ::Params::TMA_A;
using TMA_V = typename CollectiveMmaVDO::Params::TMA_A;
using TMA_Q = typename CollectiveMmaKQ::Params::TMA_B;
using TMA_DO = typename CollectiveMmaVDO::Params::TMA_B;
using TMA_DQ = decltype(make_tma_copy(SM90_TMA_REDUCE_ADD{},
make_tensor((const ElementAcc*)nullptr, make_shape(1, 1, make_shape(1, 1)), TensorStride{}),
SmemLayoutDQ{}(_, _, _0{})
));
struct MainloopParams {
TMA_K tma_load_k;
TMA_V tma_load_v;
TMA_Q tma_load_q;
TMA_DO tma_load_do;
TMA_DQ tma_red_dq;
};
struct EpilogueArguments {
Element* ptr_dk;
TensorStride stride_dk;
Element* ptr_dv;
TensorStride stride_dv;
};
struct Arguments {
ProblemShape problem_shape;
MainloopArguments mainloop;
EpilogueArguments epilogue;
KernelHardwareInfo hw_info;
};
struct Params {
ProblemShape problem_shape;
MainloopArguments mainloop;
MainloopParams mainloop_params;
EpilogueArguments epilogue;
KernelHardwareInfo hw_info;
};
static bool can_implement(Arguments const& args) {
auto [Q, K, D, D_VO, HB] = args.problem_shape;
auto [H, B] = HB;
if (Q <= 0 || K <= 0 || D <= 0 || D_VO <= 0 || H <= 0 || B <= 0) {
return false;
}
if (D % Alignment != 0 || D_VO % Alignment != 0) {
return false;
}
return true;
}
static Status initialize_workspace(Arguments const&, void*, cudaStream_t) {
return Status::kSuccess;
}
static Params to_underlying_arguments(Arguments const& args, void*) {
auto [Q_, K_, D, D_VO, HB] = args.problem_shape;
int Q = Q_;
int K = K_;
if constexpr (is_variable_length_v<decltype(Q_)>) {
Q = Q_.total_length;
}
if constexpr (is_variable_length_v<decltype(K_)>) {
K = K_.total_length;
}
auto params_kq = CollectiveMmaKQ::to_underlying_arguments(
make_shape(K, Q, D, HB),
typename CollectiveMmaKQ::Arguments {
args.mainloop.ptr_k, args.mainloop.stride_k,
args.mainloop.ptr_q, args.mainloop.stride_q,
}, /*workspace=*/nullptr);
auto params_vdo = CollectiveMmaVDO::to_underlying_arguments(
make_shape(K, Q, D_VO, HB),
typename CollectiveMmaVDO::Arguments {
args.mainloop.ptr_v, args.mainloop.stride_v,
args.mainloop.ptr_do, args.mainloop.stride_do,
}, /*workspace=*/nullptr);
TMA_DQ tma_red_dq = make_tma_copy(
SM90_TMA_REDUCE_ADD{},
make_tensor(args.mainloop.ptr_dq_acc, make_shape(Q_, D, HB), args.mainloop.stride_dq_acc),
SmemLayoutDQ{}(_, _, _0{})
);
return Params{
args.problem_shape,
args.mainloop,
MainloopParams{
params_kq.tma_load_a,
params_vdo.tma_load_a,
params_kq.tma_load_b,
params_vdo.tma_load_b,
tma_red_dq
},
args.epilogue,
args.hw_info
};
}
template<class T>
static CUTLASS_DEVICE auto quantize(T const& input) {
constexpr int AlignmentS = 4;
auto output = make_tensor<Element>(shape(input));
auto input_vec = recast<Array<ElementAcc, AlignmentS>>(input);
auto output_vec = recast<Array<Element, AlignmentS>>(output);
cutlass::NumericArrayConverter<Element, ElementAcc, AlignmentS> epilogue_op;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(input_vec); i++) {
output_vec(i) = epilogue_op(input_vec(i));
}
return output;
}
template<class BlkCoord, class BlkOffset, class ProblemShape_>
CUTLASS_DEVICE void load(
BlkCoord const& blk_coord,
BlkOffset const& blk_offset,
ProblemShape_ const& problem_shape,
int iter_index,
int iter_count,
MainloopArguments const& mainloop_args,
MainloopParams const& mainloop_params,
TensorStorage& shared_tensors,
PipelineLoadMmaQ& pipeline_load_mma_q,
typename PipelineLoadMmaQ::PipelineState& pipeline_load_mma_q_producer_state,
PipelineLoadMmaDO& pipeline_load_mma_do,
typename PipelineLoadMmaDO::PipelineState& pipeline_load_mma_do_producer_state,
PipelineLoadComputeLSE& pipeline_load_compute_lse,
typename PipelineLoadComputeLSE::PipelineState& pipeline_load_compute_lse_producer_state,
PipelineLoadComputeSumOdO& pipeline_load_compute_sum_odo,
typename PipelineLoadComputeSumOdO::PipelineState& pipeline_load_compute_sum_odo_producer_state) {
auto [Q, K, D, D_VO, HB] = problem_shape;
using X = Underscore;
uint16_t mcast_mask = 0;
auto mK_in = mainloop_params.tma_load_k.get_tma_tensor(make_shape(K, D, HB));
auto mV_in = mainloop_params.tma_load_v.get_tma_tensor(make_shape(K, D_VO, HB));
auto mQ_in = mainloop_params.tma_load_q.get_tma_tensor(make_shape(Q, D, HB));
auto mDO_in = mainloop_params.tma_load_do.get_tma_tensor(make_shape(Q, D_VO, HB));
auto mK = domain_offset(select<1,2,4>(blk_offset), mK_in);
auto mV = domain_offset(select<1,3,4>(blk_offset), mV_in);
auto mQ = domain_offset(select<0,2,4>(blk_offset), mQ_in);
auto mDO = domain_offset(select<0,3,4>(blk_offset), mDO_in);
auto gK = local_tile(mK, TileShapeKQ{}, make_coord(_,_,_), Step<_1, X, _1>{});
auto gQ = local_tile(mQ, TileShapeKQ{}, make_coord(_,_,_), Step<X, _1, _1>{});
auto gV = local_tile(mV, TileShapeVDO{}, make_coord(_,_,_), Step<_1, X, _1>{});
auto gDO = local_tile(mDO, TileShapeVDO{}, make_coord(_,_,_), Step<X, _1, _1>{});
ThrMMA cta_mma_kq = TiledMmaKQ{}.get_slice(_0{});
ThrMMA cta_mma_vdo = TiledMmaVDO{}.get_slice(_0{});
auto tSTgK = cta_mma_kq.partition_A(gK);
auto tSTgQ = cta_mma_kq.partition_B(gQ);
auto tDPTgV = cta_mma_vdo.partition_A(gV);
auto tDPTgDO = cta_mma_vdo.partition_B(gDO);
auto sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{});
auto sK = make_tensor(make_smem_ptr(shared_tensors.smem_k.begin()), SmemLayoutK{});
auto sV = make_tensor(make_smem_ptr(shared_tensors.smem_v.begin()), SmemLayoutV{});
auto sDO = make_tensor(make_smem_ptr(shared_tensors.smem_do.begin()), SmemLayoutDO{});
auto [tKgK_mkl, tKsK] = tma_partition(
mainloop_params.tma_load_k, _0{}, make_layout(_1{}),
group_modes<0,3>(sK), group_modes<0,3>(tSTgK));
auto [tQgQ_mkl, tQsQ] = tma_partition(
mainloop_params.tma_load_q, _0{}, make_layout(_1{}),
group_modes<0,3>(sQ), group_modes<0,3>(tSTgQ));
auto [tVgV_mkl, tVsV] = tma_partition(
mainloop_params.tma_load_v, _0{}, make_layout(_1{}),
group_modes<0,3>(sV), group_modes<0,3>(tDPTgV));
auto [tDOgDO_mkl, tDOsDO] = tma_partition(
mainloop_params.tma_load_do, _0{}, make_layout(_1{}),
group_modes<0,3>(sDO), group_modes<0,3>(tDPTgDO));
// set up lse and sum_odo
auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord;
pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state);
auto tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state);
pipeline_load_mma_q.producer_expect_transaction(pipeline_load_mma_q_producer_state, kTransactionsBytesLoadK);
// load K
if (cute::elect_one_sync()) {
cute::copy(
mainloop_params.tma_load_k.with(*tma_barrier, mcast_mask),
tKgK_mkl(_, blk_coord_k, _0{}, blk_coord_batch),
tKsK(_, _0{})
);
}
// load Q
if (cute::elect_one_sync()) {
cute::copy(
mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask),
tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch),
tQsQ(_, pipeline_load_mma_q_producer_state.index())
);
}
++pipeline_load_mma_q_producer_state;
pipeline_load_compute_lse.producer_acquire(pipeline_load_compute_lse_producer_state);
// load LSE
// 32 threads loading 128 values of 32b each
// so 4*32b=128b
int thread_idx = threadIdx.x % NumThreadsPerWarp;
int smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * 4;
int gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4;
auto mLSE = make_tensor(mainloop_args.ptr_lse, make_shape(Q, HB), mainloop_args.stride_lse);
for (int i = 0; i < 4; i++) {
cutlass::arch::cp_async_zfill<4>(
shared_tensors.smem_lse.begin() + smem_idx + i,
&mLSE(gmem_idx + i, blk_coord_batch),
gmem_idx + i < Q
);
}
pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive);
++pipeline_load_compute_lse_producer_state;
pipeline_load_mma_do.producer_acquire(pipeline_load_mma_do_producer_state);
tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state);
pipeline_load_mma_do.producer_expect_transaction(pipeline_load_mma_do_producer_state, kTransactionsBytesLoadV);
// load V
if (cute::elect_one_sync()) {
cute::copy(
mainloop_params.tma_load_v.with(*tma_barrier, mcast_mask),
tVgV_mkl(_, blk_coord_k, _0{}, blk_coord_batch),
tVsV(_, _0{})
);
}
// load dO
if (cute::elect_one_sync()) {
cute::copy(
mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask),
tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch),
tDOsDO(_, pipeline_load_mma_do_producer_state.index())
);
}
++pipeline_load_mma_do_producer_state;
pipeline_load_compute_sum_odo.producer_acquire(pipeline_load_compute_sum_odo_producer_state);
// load sum_OdO
smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * 4;
gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4;
auto mSumOdO = make_tensor(mainloop_args.ptr_sum_odo, make_shape(Q, HB), mainloop_args.stride_sum_odo);
for (int i = 0; i < 4; i++) {
cutlass::arch::cp_async_zfill<4>(
shared_tensors.smem_sum_odo.begin() + smem_idx + i,
&mSumOdO(gmem_idx + i, blk_coord_batch),
gmem_idx + i < Q
);
}
pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive);
++pipeline_load_compute_sum_odo_producer_state;
iter_count -= 1;
iter_index += 1;
while (iter_count > 0) {
pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state);
tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state);
// load Q
if (cute::elect_one_sync()) {
cute::copy(
mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask),
tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch),
tQsQ(_, pipeline_load_mma_q_producer_state.index())
);
}
++pipeline_load_mma_q_producer_state;
pipeline_load_compute_lse.producer_acquire(pipeline_load_compute_lse_producer_state);
// load LSE
smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * 4;
gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4;
for (int i = 0; i < 4; i++) {
cutlass::arch::cp_async_zfill<4>(
shared_tensors.smem_lse.begin() + smem_idx + i,
&mLSE(gmem_idx + i, blk_coord_batch),
gmem_idx + i < Q
);
}
pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive);
++pipeline_load_compute_lse_producer_state;
pipeline_load_mma_do.producer_acquire(pipeline_load_mma_do_producer_state);
tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state);
// load dO
if (cute::elect_one_sync()) {
cute::copy(
mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask),
tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch),
tDOsDO(_, pipeline_load_mma_do_producer_state.index())
);
}
++pipeline_load_mma_do_producer_state;
pipeline_load_compute_sum_odo.producer_acquire(pipeline_load_compute_sum_odo_producer_state);
// load sum_OdO
smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * 4;
gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4;
for (int i = 0; i < 4; i++) {
cutlass::arch::cp_async_zfill<4>(
shared_tensors.smem_sum_odo.begin() + smem_idx + i,
&mSumOdO(gmem_idx + i, blk_coord_batch),
gmem_idx + i < Q
);
}
pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive);
++pipeline_load_compute_sum_odo_producer_state;
iter_count -= 1;
iter_index += 1;
}
}
template<class BlkCoord, class ProblemShape_>
CUTLASS_DEVICE void mma(
BlkCoord const& blk_coord,
ProblemShape_ const& problem_shape,
int iter_index,
int iter_count,
MainloopArguments const& mainloop_args,
TensorStorage& shared_tensors,
PipelineLoadMmaQ& pipeline_load_mma_q,
typename PipelineLoadMmaQ::PipelineState& pipeline_load_mma_q_consumer_state,
PipelineLoadMmaDO& pipeline_load_mma_do,
typename PipelineLoadMmaDO::PipelineState& pipeline_load_mma_do_consumer_state,
PipelineMmaComputeS& pipeline_mma_compute_s,
typename PipelineMmaComputeS::PipelineState& pipeline_mma_compute_s_producer_state,
PipelineMmaComputeDP& pipeline_mma_compute_dp,
typename PipelineMmaComputeDP::PipelineState& pipeline_mma_compute_dp_producer_state,
PipelineMmaReduceDQ& pipeline_mma_reduce_dq,
typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_producer_state,
PipelineComputeMmaP& pipeline_compute_mma_p,
typename PipelineComputeMmaP::PipelineState& pipeline_compute_mma_p_consumer_state,
PipelineComputeMmaDS& pipeline_compute_mma_ds,
typename PipelineComputeMmaDS::PipelineState& pipeline_compute_mma_ds_consumer_state,
PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv,
typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_producer_state) {
auto [Q, K, D, D_VO, HB] = problem_shape;
auto sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{});
auto sK = make_tensor(make_smem_ptr(shared_tensors.smem_k.begin()), SmemLayoutK{});
auto sV = make_tensor(make_smem_ptr(shared_tensors.smem_v.begin()), SmemLayoutV{});
auto sDO = make_tensor(make_smem_ptr(shared_tensors.smem_do.begin()), SmemLayoutDO{});
auto sQT = make_tensor(make_smem_ptr(shared_tensors.smem_q_t.begin()), SmemLayoutQT{});
auto sKT = make_tensor(make_smem_ptr(shared_tensors.smem_k_t.begin()), SmemLayoutKT{});
auto sDS = make_tensor(make_smem_ptr(shared_tensors.smem_ds.begin()), SmemLayoutDS{});
auto sDST = make_tensor(make_smem_ptr(shared_tensors.smem_ds_t.begin()), SmemLayoutDST{});
auto sP = make_tensor(make_smem_ptr((Element*) nullptr), typename CollectiveMmaPDO::SmemLayoutA{});
auto sDOT = make_tensor(make_smem_ptr(shared_tensors.smem_do_t.begin()), SmemLayoutDOT{});
Tensor tSTrK = TiledMmaKQ::make_fragment_A(sK);
Tensor tSTrQ = TiledMmaKQ::make_fragment_B(sQ);
Tensor tDPTrV = TiledMmaVDO::make_fragment_A(sV);
Tensor tDPTrDO = TiledMmaVDO::make_fragment_B(sDO);
Tensor tDQrDS = TiledMmaDSK::make_fragment_A(sDS);
Tensor tDQrKT = TiledMmaDSK::make_fragment_B(sKT);
Tensor tDKrDST = TiledMmaDSQ::make_fragment_A(sDST);
Tensor tDKrQT = TiledMmaDSQ::make_fragment_B(sQT);
Tensor tDVrP = TiledMmaPDO::make_fragment_A(sP)(_, _, _, _0{});
tDVrP.data() = TmemAllocation::kP;
Tensor tDVrDOT = TiledMmaPDO::make_fragment_B(sDOT);
TiledMmaKQ tiled_mma_kq;
TiledMmaVDO tiled_mma_vdo;
TiledMmaDSK tiled_mma_dsk;
TiledMmaDSQ tiled_mma_dsq;
TiledMmaPDO tiled_mma_pdo;
tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::Zero;
tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::Zero;
Tensor tSTtST = partition_fragment_C(tiled_mma_kq, select<0,1>(TileShapeKQ{}));
tSTtST.data() = TmemAllocation::kS;
Tensor tDPTtDPT = partition_fragment_C(tiled_mma_vdo, select<0,1>(TileShapeVDO{}));
tDPTtDPT.data() = TmemAllocation::kDP;
Tensor tDQtDQ = partition_fragment_C(tiled_mma_dsk, select<0,1>(TileShapeDSK{}));
tDQtDQ.data() = TmemAllocation::kDQ;
Tensor tDKtDK = partition_fragment_C(tiled_mma_dsq, select<0,1>(TileShapeDSQ{}));
tDKtDK.data() = TmemAllocation::kDK;
Tensor tDVtDV = partition_fragment_C(tiled_mma_pdo, select<0,1>(TileShapePDO{}));
tDVtDV.data() = TmemAllocation::kDV;
auto pipeline_load_mma_q_release_state = pipeline_load_mma_q_consumer_state;
pipeline_load_mma_q.consumer_wait(pipeline_load_mma_q_consumer_state);
pipeline_mma_compute_s.producer_acquire(pipeline_mma_compute_s_producer_state);
// S = Q*K
tiled_mma_kq.accumulate_ = UMMA::ScaleOut::Zero;
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tSTrQ); ++k_block) {
cute::gemm(tiled_mma_kq,
tSTrK(_,_,k_block,_0{}),
tSTrQ(_,_,k_block,pipeline_load_mma_q_consumer_state.index()),
tSTtST);
tiled_mma_kq.accumulate_ = UMMA::ScaleOut::One;
}
++pipeline_load_mma_q_consumer_state;
pipeline_mma_compute_s.producer_commit(pipeline_mma_compute_s_producer_state);
++pipeline_mma_compute_s_producer_state;
pipeline_load_mma_do.consumer_wait(pipeline_load_mma_do_consumer_state);
pipeline_mma_compute_dp.producer_acquire(pipeline_mma_compute_dp_producer_state);
pipeline_mma_reduce_dq.producer_acquire(pipeline_mma_reduce_dq_producer_state);
// dP = dO*V
tiled_mma_vdo.accumulate_ = UMMA::ScaleOut::Zero;
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tDPTrV); ++k_block) {
cute::gemm(tiled_mma_vdo,
tDPTrV(_,_,k_block,_0{}),
tDPTrDO(_,_,k_block,pipeline_load_mma_do_consumer_state.index()),
tDPTtDPT);
tiled_mma_vdo.accumulate_ = UMMA::ScaleOut::One;
}
pipeline_mma_compute_dp.producer_commit(pipeline_mma_compute_dp_producer_state);
++pipeline_mma_compute_dp_producer_state;
pipeline_compute_mma_p.consumer_wait(pipeline_compute_mma_p_consumer_state);
// dV = P*dO
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tDVrP); ++k_block) {
cute::gemm(tiled_mma_pdo,
tDVrP(_,_,k_block),
tDVrDOT(_,_,k_block,pipeline_load_mma_do_consumer_state.index()),
tDVtDV);
tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::One;
}
pipeline_compute_mma_p.consumer_release(pipeline_compute_mma_p_consumer_state);
++pipeline_compute_mma_p_consumer_state;
pipeline_load_mma_do.consumer_release(pipeline_load_mma_do_consumer_state);
++pipeline_load_mma_do_consumer_state;
iter_count -= 1;
// in tmem, S & P overlap
// and dP and dQ overlap
// so we need to acquire dQ and dP at the same time
while (iter_count > 0) {
pipeline_load_mma_q.consumer_wait(pipeline_load_mma_q_consumer_state);
pipeline_mma_compute_s.producer_acquire(pipeline_mma_compute_s_producer_state);
// S = Q*K
tiled_mma_kq.accumulate_ = UMMA::ScaleOut::Zero;
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tSTrQ); ++k_block) {
cute::gemm(tiled_mma_kq,
tSTrK(_,_,k_block,_0{}),
tSTrQ(_,_,k_block,pipeline_load_mma_q_consumer_state.index()),
tSTtST);
tiled_mma_kq.accumulate_ = UMMA::ScaleOut::One;
}
++pipeline_load_mma_q_consumer_state;
pipeline_mma_compute_s.producer_commit(pipeline_mma_compute_s_producer_state);
++pipeline_mma_compute_s_producer_state;
pipeline_compute_mma_ds.consumer_wait(pipeline_compute_mma_ds_consumer_state);
// we need to acquire dP here, because tmem dQ == tmem dP
pipeline_mma_compute_dp.producer_acquire(pipeline_mma_compute_dp_producer_state);
// dQ = dS*K
tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::Zero;
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tDQrDS); ++k_block) {
cute::gemm(tiled_mma_dsk,
tDQrDS(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()),
tDQrKT(_,_,k_block,_0{}),
tDQtDQ);
tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::One;
}
pipeline_mma_reduce_dq.producer_commit(pipeline_mma_reduce_dq_producer_state);
++pipeline_mma_reduce_dq_producer_state;
// dK = dS*Q
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tDKrDST); ++k_block) {
cute::gemm(tiled_mma_dsq,
tDKrDST(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()),
tDKrQT(_,_,k_block,pipeline_load_mma_q_release_state.index()),
tDKtDK);
tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::One;
}
pipeline_load_mma_q.consumer_release(pipeline_load_mma_q_release_state);
++pipeline_load_mma_q_release_state;
pipeline_compute_mma_ds.consumer_release(pipeline_compute_mma_ds_consumer_state);
++pipeline_compute_mma_ds_consumer_state;
// we grab dq here, because in tmem dq == dp
pipeline_mma_reduce_dq.producer_acquire(pipeline_mma_reduce_dq_producer_state);
pipeline_load_mma_do.consumer_wait(pipeline_load_mma_do_consumer_state);
// dP = dO*V
tiled_mma_vdo.accumulate_ = UMMA::ScaleOut::Zero;
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tDPTrV); ++k_block) {
cute::gemm(tiled_mma_vdo,
tDPTrV(_,_,k_block,_0{}),
tDPTrDO(_,_,k_block,pipeline_load_mma_do_consumer_state.index()),
tDPTtDPT);
tiled_mma_vdo.accumulate_ = UMMA::ScaleOut::One;
}
pipeline_mma_compute_dp.producer_commit(pipeline_mma_compute_dp_producer_state);
++pipeline_mma_compute_dp_producer_state;
pipeline_compute_mma_p.consumer_wait(pipeline_compute_mma_p_consumer_state);
// dV = P*dO
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tDVrP); ++k_block) {
cute::gemm(tiled_mma_pdo,
tDVrP(_,_,k_block),
tDVrDOT(_,_,k_block,pipeline_load_mma_do_consumer_state.index()),
tDVtDV);
tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::One;
}
pipeline_compute_mma_p.consumer_release(pipeline_compute_mma_p_consumer_state);
++pipeline_compute_mma_p_consumer_state;
pipeline_load_mma_do.consumer_release(pipeline_load_mma_do_consumer_state);
++pipeline_load_mma_do_consumer_state;
iter_count -= 1;
}
// signal to the epilogue that dV is ready
pipeline_mma_compute_dkdv.producer_acquire(pipeline_mma_compute_dkdv_producer_state);
pipeline_mma_compute_dkdv.producer_commit(pipeline_mma_compute_dkdv_producer_state);
++pipeline_mma_compute_dkdv_producer_state;
pipeline_mma_compute_dkdv.producer_acquire(pipeline_mma_compute_dkdv_producer_state);
pipeline_compute_mma_ds.consumer_wait(pipeline_compute_mma_ds_consumer_state);
// dK = dS*Q
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tDKrDST); ++k_block) {
cute::gemm(tiled_mma_dsq,
tDKrDST(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()),
tDKrQT(_,_,k_block,pipeline_load_mma_q_release_state.index()),
tDKtDK);
tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::One;
}
// signal to epilgue that dK is ready
pipeline_mma_compute_dkdv.producer_commit(pipeline_mma_compute_dkdv_producer_state);
++pipeline_mma_compute_dkdv_producer_state;
// we've already acquired mma_reduce_dq in the loop
// dQ = dS*K
tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::Zero;
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tDQrDS); ++k_block) {
cute::gemm(tiled_mma_dsk,
tDQrDS(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()),
tDQrKT(_,_,k_block,_0{}),
tDQtDQ);
tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::One;
}
pipeline_mma_reduce_dq.producer_commit(pipeline_mma_reduce_dq_producer_state);
++pipeline_mma_reduce_dq_producer_state;
pipeline_load_mma_q.consumer_release(pipeline_load_mma_q_release_state);
++pipeline_load_mma_q_release_state;
pipeline_compute_mma_ds.consumer_release(pipeline_compute_mma_ds_consumer_state);
++pipeline_compute_mma_ds_consumer_state;
}
template<class TensorG, class TensorR, class TensorC, class TensorShape>
CUTLASS_DEVICE void store(
TensorG gmem,
TensorR const& regs,
TensorC const& coord,
TensorShape const& tensor_shape) {
Tensor preds = cute::lazy::transform(coord, [&](auto const& c) { return elem_less(c, tensor_shape); });
auto copy_op = make_cotiled_copy(
Copy_Atom<UniversalCopy<uint128_t>, Element>{},
make_layout(make_shape(_1{}, Int<sizeof(uint128_t) / sizeof(Element)>{})),
regs.layout()
);
auto thr_copy = copy_op.get_slice(_0{});
Tensor quantized_regs = quantize(regs);
Tensor tCr = thr_copy.partition_S(quantized_regs);
Tensor tCg = thr_copy.partition_D(gmem);
Tensor tPc = thr_copy.partition_D(preds);
copy_if(copy_op, tPc, tCr, tCg);
}
template<class BlkCoord, class BlkOffset, class ProblemShape_>
CUTLASS_DEVICE void epilogue_clear(
BlkCoord const& blk_coord,
BlkOffset const& blk_offset,
ProblemShape_ const& problem_shape,
MainloopArguments const& mainloop_args,
EpilogueArguments const& epilogue_args) {
auto [Q, K, D, D_VO, HB] = problem_shape;
auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord;
auto mDK_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk);
auto mDK = domain_offset(select<1,2,4>(blk_offset), mDK_in);
auto gDK = local_tile(mDK, TileShapeDSQ{}, make_coord(_,_,_), Step<_1, _1, X>{})
(_, _, blk_coord_k, _0{}, blk_coord_batch);
Tensor cDK = domain_offset(
make_coord(get<1>(blk_coord) * TileShapeK{}, _0{}),
make_identity_tensor(take<0,2>(TileShapeDSQ{}))
);
auto mDV_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv);
auto mDV = domain_offset(select<1,3,4>(blk_offset), mDV_in);
auto gDV = local_tile(mDV, TileShapePDO{}, make_coord(_,_,_), Step<_1, _1, X>{})
(_, _, blk_coord_k, _0{}, blk_coord_batch);
Tensor cDV = domain_offset(
make_coord(blk_coord_k * TileShapeK{}, _0{}),
make_identity_tensor(take<0,2>(TileShapePDO{}))
);
for (int i = threadIdx.x; i < size(gDK); i += blockDim.x) {
if (elem_less(cDK(i), select<1,2>(problem_shape))) {
gDK(i) = Element(0);
}
}
for (int i = threadIdx.x; i < size(gDV); i += blockDim.x) {
if (elem_less(cDV(i), select<1,3>(problem_shape))) {
gDV(i) = Element(0);
}
}
}
template<class BlkCoord, class BlkOffset, class ProblemShape_>
CUTLASS_DEVICE void epilogue(
BlkCoord const& blk_coord,
BlkOffset const& blk_offset,
ProblemShape_ const& problem_shape,
MainloopArguments const& mainloop_args,
EpilogueArguments const& epilogue_args,
PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv,
typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) {
auto [Q, K, D, D_VO, HB] = problem_shape;
auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord;
auto load_op = SM100_TMEM_LOAD_32dp32b16x{};
auto tDKtDK = partition_fragment_C(TiledMmaDSQ{}, select<0,1>(TileShapeDSQ{}))(make_coord(_,_),_0{},_0{});
tDKtDK.data() = TmemAllocation::kDK;
auto mDK_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk);
auto mDK = domain_offset(select<1,2,4>(blk_offset), mDK_in);
auto gDK = local_tile(mDK, TileShapeDSQ{}, make_coord(_,_,_), Step<_1, _1, X>{})
(_, _, blk_coord_k, _0{}, blk_coord_batch);
Tensor cDK = domain_offset(
make_coord(get<1>(blk_coord) * TileShapeK{}, _0{}),
make_identity_tensor(take<0,2>(TileShapeDSQ{}))
);
constexpr int kNumWarpgroups = kNumComputeWarps / 4;
int dp_idx = threadIdx.x % 128;
int wg_idx = (threadIdx.x % (kNumComputeWarps * NumThreadsPerWarp)) / 128;
auto split_wg = [&](auto const& t) {
if constexpr (decltype(rank(t))::value == 3) {
auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), make_shape(Int<kNumWarpgroups>{}, size<2>(t) / Int<kNumWarpgroups>{}))));
return p(_, _, make_coord(wg_idx, _));
}
else {
auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), size<2>(t), make_shape(Int<kNumWarpgroups>{}, size<3>(t) / Int<kNumWarpgroups>{}))));
return p(_, _, _, make_coord(wg_idx, _));
}
};
auto tiled_t2r_dk = make_tmem_copy(load_op, tDKtDK);
auto thread_t2r_dk = tiled_t2r_dk.get_slice(dp_idx);
Tensor tTR_cDK = split_wg(thread_t2r_dk.partition_D(cDK));
Tensor tTR_gDK = split_wg(thread_t2r_dk.partition_D(gDK));
Tensor tTR_rDK = make_tensor<ElementAcc>(shape(tTR_cDK));
Tensor tTR_tDK = split_wg(thread_t2r_dk.partition_S(tDKtDK));
auto tDVtDV = partition_fragment_C(TiledMmaDSQ{}, select<0,1>(TileShapeDSQ{}))(make_coord(_,_),_0{},_0{});
tDVtDV.data() = TmemAllocation::kDV;
auto mDV_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv);
auto mDV = domain_offset(select<1,3,4>(blk_offset), mDV_in);
auto gDV = local_tile(mDV, TileShapePDO{}, make_coord(_,_,_), Step<_1, _1, X>{})
(_, _, blk_coord_k, _0{}, blk_coord_batch);
Tensor cDV = domain_offset(
make_coord(blk_coord_k * TileShapeK{}, _0{}),
make_identity_tensor(take<0,2>(TileShapePDO{}))
);
auto tiled_t2r_dv = make_tmem_copy(load_op, tDVtDV);
auto thread_t2r_dv = tiled_t2r_dv.get_slice(dp_idx);
Tensor tTR_cDV = split_wg(thread_t2r_dv.partition_D(cDV));
Tensor tTR_gDV = split_wg(thread_t2r_dv.partition_D(gDV));
Tensor tTR_rDV = make_tensor<ElementAcc>(shape(tTR_cDV));
Tensor tTR_tDV = split_wg(thread_t2r_dv.partition_S(tDVtDV));
pipeline_mma_compute_dkdv.consumer_wait(pipeline_mma_compute_dkdv_consumer_state);
// load tDVtDV
cute::copy(tiled_t2r_dv, tTR_tDV, tTR_rDV);
// store tDVgDV
store(tTR_gDV, tTR_rDV, tTR_cDV, select<1,3>(problem_shape));
cutlass::arch::fence_view_async_tmem_load();
pipeline_mma_compute_dkdv.consumer_release(pipeline_mma_compute_dkdv_consumer_state);
++pipeline_mma_compute_dkdv_consumer_state;
pipeline_mma_compute_dkdv.consumer_wait(pipeline_mma_compute_dkdv_consumer_state);
// load tDKtDK
cute::copy(tiled_t2r_dk, tTR_tDK, tTR_rDK);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(tTR_rDK); i++) {
tTR_rDK(i) = mainloop_args.softmax_scale * tTR_rDK(i);
}
// store tDKgDK
store(tTR_gDK, tTR_rDK, tTR_cDK, select<1,2>(problem_shape));
cutlass::arch::fence_view_async_tmem_load();
pipeline_mma_compute_dkdv.consumer_release(pipeline_mma_compute_dkdv_consumer_state);
++pipeline_mma_compute_dkdv_consumer_state;
}
template<class BlkCoord, class BlkOffset, class ProblemShape_>
CUTLASS_DEVICE void compute(
BlkCoord const& blk_coord,
BlkOffset const& blk_offset,
ProblemShape_ const& problem_shape,
int iter_index,
int iter_count,
MainloopArguments const& mainloop_args,
EpilogueArguments const& epilogue_args,
TensorStorage& shared_tensors,
PipelineLoadComputeLSE& pipeline_load_compute_lse,
typename PipelineLoadComputeLSE::PipelineState& pipeline_load_compute_lse_consumer_state,
PipelineLoadComputeSumOdO& pipeline_load_compute_sum_odo,
typename PipelineLoadComputeSumOdO::PipelineState& pipeline_load_compute_sum_odo_consumer_state,
PipelineMmaComputeS& pipeline_mma_compute_s,
typename PipelineMmaComputeS::PipelineState& pipeline_mma_compute_s_consumer_state,
PipelineMmaComputeDP& pipeline_mma_compute_dp,
typename PipelineMmaComputeDP::PipelineState& pipeline_mma_compute_dp_consumer_state,
PipelineComputeMmaP& pipeline_compute_mma_p,
typename PipelineComputeMmaP::PipelineState& pipeline_compute_mma_p_producer_state,
PipelineComputeMmaDS& pipeline_compute_mma_ds,
typename PipelineComputeMmaDS::PipelineState& pipeline_compute_mma_ds_producer_state,
PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv,
typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) {
auto [Q, K, D, D_VO, HB] = problem_shape;
// in tmem, S & P overlap
// and dP and dQ overlap
// there are two compute wg's that cooperatively compute softmax
// they are striped by this tmem atom, i.e. wg0 has 16 elems, then wg1 etc
auto load_op = SM100_TMEM_LOAD_32dp32b16x{};
auto store_op = []() {
if constexpr (sizeof(Element) == 1) {
return SM100_TMEM_STORE_32dp32b4x{};
}
else {
return SM100_TMEM_STORE_32dp32b8x{};
}
}();
Tensor tSTtST = partition_fragment_C(TiledMmaKQ{}, select<0,1>(TileShapeKQ{}))(make_coord(_,_),_0{},_0{});
tSTtST.data() = TmemAllocation::kS;
Tensor tDPTtDPT = partition_fragment_C(TiledMmaVDO{}, select<0,1>(TileShapeVDO{}))(make_coord(_,_),_0{},_0{});
tDPTtDPT.data() = TmemAllocation::kDP;
Tensor cST = make_identity_tensor(take<0,2>(TileShapeKQ{}));
Tensor cDPT = make_identity_tensor(take<0,2>(TileShapeVDO{}));
constexpr int kNumWarpgroups = kNumComputeWarps / 4;
int dp_idx = threadIdx.x % 128;
int wg_idx = (threadIdx.x % (kNumComputeWarps * NumThreadsPerWarp)) / 128;
auto tiled_t2r = make_tmem_copy(load_op, tSTtST);
auto thread_t2r = tiled_t2r.get_slice(dp_idx);
auto split_wg = [&](auto const& t) {
if constexpr (decltype(size<1>(t))::value > 1) {
if constexpr (decltype(rank(t))::value == 3) {
auto p = t.compose(make_layout(make_shape(size<0>(t), make_shape(Int<kNumWarpgroups>{}, size<1>(t) / Int<kNumWarpgroups>{}), size<2>(t))));
return p(_, make_coord(wg_idx, _), _);
}
else {
auto p = t.compose(make_layout(make_shape(size<0>(t), make_shape(Int<kNumWarpgroups>{}, size<1>(t) / Int<kNumWarpgroups>{}), size<2>(t), size<3>(t))));
return p(_, make_coord(wg_idx, _), _, _);
}
}
else {
if constexpr (decltype(rank(t))::value == 3) {
auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), make_shape(Int<kNumWarpgroups>{}, size<2>(t) / Int<kNumWarpgroups>{}))));
return p(_, _, make_coord(wg_idx, _));
}
else {
auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), size<2>(t), make_shape(Int<kNumWarpgroups>{}, size<3>(t) / Int<kNumWarpgroups>{}))));
return p(_, _, _, make_coord(wg_idx, _));
}
}
};
Tensor tTR_cST_p = thread_t2r.partition_D(cST);
Tensor tTR_cST = split_wg(tTR_cST_p);
Tensor tTR_rST = make_tensor<ElementAcc>(shape(tTR_cST));
Tensor tTR_tST = split_wg(thread_t2r.partition_S(tSTtST));
Tensor tTR_cDPT_p = thread_t2r.partition_D(cDPT);
Tensor tTR_cDPT = split_wg(tTR_cDPT_p);
Tensor tTR_rDPT = make_tensor<ElementAcc>(shape(tTR_cDPT));
Tensor tTR_tDPT = split_wg(thread_t2r.partition_S(tDPTtDPT));
Tensor sLSE = make_tensor(make_smem_ptr(shared_tensors.smem_lse.begin()), SmemLayoutLSE{});
Tensor sSumOdO = make_tensor(make_smem_ptr(shared_tensors.smem_sum_odo.begin()), SmemLayoutSumOdO{});
auto sP = make_tensor(make_smem_ptr((Element*) nullptr), typename CollectiveMmaPDO::SmemLayoutA{});
auto tDVrP = TiledMmaPDO::make_fragment_A(sP)(_, _, _, _0{});
auto tDVcST = TiledMmaPDO{}.get_slice(_0{}).partition_A(cST);
tDVrP.data() = TmemAllocation::kP;
auto tiled_r2t = make_tmem_copy(store_op, tDVrP);
auto thread_r2t = tiled_r2t.get_slice(dp_idx);
auto tRT_tP = split_wg(thread_r2t.partition_D(tDVrP));
auto tRT_cST_p = thread_r2t.partition_S(tDVcST);
auto tRT_cST = split_wg(tRT_cST_p);
bool is_residual_k = get<1>(blk_coord) * TileShapeK{} + TileShapeK{} >= get<1>(problem_shape);
int last_iter = iter_count - 1 + iter_index;
CUTLASS_PRAGMA_NO_UNROLL
while (iter_count > 0) {
// wait for S and P
pipeline_mma_compute_s.consumer_wait(pipeline_mma_compute_s_consumer_state);
pipeline_compute_mma_p.producer_acquire(pipeline_compute_mma_p_producer_state);
// wait for LSE
pipeline_load_compute_lse.consumer_wait(pipeline_load_compute_lse_consumer_state);
auto dispatch_bool = [](bool b, auto fn) {
if (b) {
fn(cute::true_type{});
}
else {
fn(cute::false_type{});
}
};
bool leading_causal_masking = false;
if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<true>, Mask>) {
leading_causal_masking = warp_uniform(iter_index == get<1>(blk_coord));
} else if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<false>, Mask>) {
int offset = get<1>(problem_shape) - get<0>(problem_shape);
int kv_left = get<1>(blk_coord) * TileShapeK{};
int kv_right = kv_left + TileShapeK{} - 1;
int q_left = iter_index * TileShapeQ{} + offset;
int q_right = q_left + TileShapeQ{} - 1;
leading_causal_masking = warp_uniform(!((q_left > kv_right) || (q_right < kv_left)));
}
bool trailing_residual_masking = false;
if constexpr (std::is_base_of_v<cutlass::fmha::collective::ResidualMaskForBackward, Mask>) {
trailing_residual_masking = warp_uniform((iter_index == last_iter) || is_residual_k);
}
dispatch_bool(leading_causal_masking || trailing_residual_masking, [&](auto is_masked_tile) {
// compute P = softmax(S, LSE)
cute::copy(tiled_t2r, tTR_tST, tTR_rST);
if constexpr (decltype(is_masked_tile)::value) {
Mask{}.apply_mask(tTR_rST, [&](int i) {
auto c_transpose = tTR_cST(i);
return make_coord(get<1>(c_transpose) + iter_index * TileShapeQ{}, get<0>(c_transpose) + get<1>(blk_coord) * TileShapeK{});
}, problem_shape);
}
ElementAcc log2_e = static_cast<ElementAcc>(M_LOG2E);
float2 softmax_scale_log2_e;
softmax_scale_log2_e.x = mainloop_args.softmax_scale * log2_e;
softmax_scale_log2_e.y = mainloop_args.softmax_scale * log2_e;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(tTR_rST); i += 2) {
float2 acc;
float2 lse;
float2 out;
acc.x = tTR_rST(i);
acc.y = tTR_rST(i + 1);
lse.x = sLSE(get<1>(tTR_cST(i)), pipeline_load_compute_lse_consumer_state.index());
lse.y = sLSE(get<1>(tTR_cST(i+1)), pipeline_load_compute_lse_consumer_state.index());
cute::fma(out, softmax_scale_log2_e, acc, lse);
tTR_rST(i) = ::exp2f(out.x);
tTR_rST(i+1) = ::exp2f(out.y);
}
auto tRT_rST = quantize(tTR_rST);
auto tRT_rST_reshaped = make_tensor(tRT_rST.data(), shape(tRT_cST));
cutlass::arch::fence_view_async_tmem_load();
cutlass::arch::NamedBarrier(
kNumComputeWarps * NumThreadsPerWarp,
cutlass::arch::ReservedNamedBarriers::TransformBarrier
).arrive_and_wait();
cute::copy(tiled_r2t, tRT_rST_reshaped, tRT_tP);
});
// notify for P
cutlass::arch::fence_view_async_tmem_store();
pipeline_compute_mma_p.producer_commit(pipeline_compute_mma_p_producer_state);
++pipeline_compute_mma_p_producer_state;
// release S
pipeline_mma_compute_s.consumer_release(pipeline_mma_compute_s_consumer_state);
++pipeline_mma_compute_s_consumer_state;
// release LSE
pipeline_load_compute_lse.consumer_release(pipeline_load_compute_lse_consumer_state);
++pipeline_load_compute_lse_consumer_state;
// wait for OdO
pipeline_load_compute_sum_odo.consumer_wait(pipeline_load_compute_sum_odo_consumer_state);
// wait for dP
pipeline_mma_compute_dp.consumer_wait(pipeline_mma_compute_dp_consumer_state);
// wait for dS
// in principle, we could defer waiting for dS, and move in the freeing of dP
// however, that would force us to keep dS in registers longer
pipeline_compute_mma_ds.producer_acquire(pipeline_compute_mma_ds_producer_state);
// compute dS = dsoftmax(P, dP, sum_OdO)
cute::copy(tiled_t2r, tTR_tDPT, tTR_rDPT);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(tTR_rDPT); i += 2) {
float2 st;
st.x = tTR_rST(i);
st.y = tTR_rST(i+1);
float2 dpt;
dpt.x = tTR_rDPT(i);
dpt.y = tTR_rDPT(i+1);
float2 odo;
odo.x = sSumOdO(get<1>(tTR_cDPT(i)), pipeline_load_compute_sum_odo_consumer_state.index());
odo.y = sSumOdO(get<1>(tTR_cDPT(i+1)), pipeline_load_compute_sum_odo_consumer_state.index());
float2 dif;
// sum odo is negated during preprocess
cute::add(dif, dpt, odo);
float2 out;
cute::mul(out, dif, st);
tTR_rDPT(i) = out.x;
tTR_rDPT(i+1) = out.y;
}
auto tTR_rDST = quantize(tTR_rDPT);
// release dP
cutlass::arch::fence_view_async_tmem_load();
pipeline_mma_compute_dp.consumer_release(pipeline_mma_compute_dp_consumer_state);
++pipeline_mma_compute_dp_consumer_state;
Tensor sDS = make_tensor(make_smem_ptr((Element*) shared_tensors.smem_ds.begin()), SmemLayoutDS{})
(_, _, _, pipeline_compute_mma_ds_producer_state.index());
auto thread_layout = make_ordered_layout(
make_shape(_128{}, _128{}),
make_stride(_1{}, _0{})
);
auto sDS_pi = as_position_independent_swizzle_tensor(sDS);
auto sDS_pi_slice_p = sDS_pi.compose(thread_layout)(dp_idx, _).compose(make_layout(shape(tTR_cDPT_p)));
auto sDS_pi_slice = split_wg(sDS_pi_slice_p);
copy_aligned(tTR_rDST, sDS_pi_slice);
// notify for dS
cutlass::arch::fence_view_async_shared();
pipeline_compute_mma_ds.producer_commit(pipeline_compute_mma_ds_producer_state);
++pipeline_compute_mma_ds_producer_state;
// release OdO
pipeline_load_compute_sum_odo.consumer_release(pipeline_load_compute_sum_odo_consumer_state);
++pipeline_load_compute_sum_odo_consumer_state;
iter_count -= 1;
iter_index += 1;
}
epilogue(
blk_coord, blk_offset, problem_shape, mainloop_args, epilogue_args,
pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_consumer_state
);
}
template<class BlkCoord, class ProblemShape_>
CUTLASS_DEVICE void reduce(
BlkCoord const& blk_coord,
ProblemShape_ const& problem_shape,
int iter_index,
int iter_count,
MainloopArguments const& mainloop_args,
MainloopParams const& mainloop_params,
TensorStorage& shared_tensors,
PipelineMmaReduceDQ& pipeline_mma_reduce_dq,
typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_consumer_state,
PipelineReduceTmaStore& pipeline_reduce_tma_store,
typename PipelineReduceTmaStore::PipelineState& pipeline_reduce_tma_store_producer_state) {
using X = Underscore;
auto [Q, K, D, D_VO, HB] = problem_shape;
auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord;
// must match TileShapeDQ
auto load_op = SM100_TMEM_LOAD_32dp32b32x{};
auto tDQtDQ = partition_fragment_C(TiledMmaDSK{}, select<0,1>(TileShapeDSK{}))(make_coord(_,_),_0{},_0{});
tDQtDQ.data() = TmemAllocation::kDQ;
Tensor mDQ = mainloop_params.tma_red_dq.get_tma_tensor(make_shape(Q, D, HB));
auto gDQ = local_tile(mDQ, TileShapeKQ{}, make_coord(_,_,_), Step<X, _1, _1>{})
(_, _, _, _0{}, blk_coord_batch);
Tensor cDQ = make_identity_tensor(take<0,2>(TileShapeDSK{}));
Tensor sDQ = make_tensor(make_smem_ptr(shared_tensors.smem_dq.begin()), SmemLayoutDQ{});
int thread_idx = threadIdx.x % (kNumComputeWarps * NumThreadsPerWarp);
auto tiled_t2r = make_tmem_copy(load_op, tDQtDQ);
auto thread_t2r = tiled_t2r.get_slice(thread_idx);
Tensor tTR_cDQ = thread_t2r.partition_D(cDQ);
Tensor tTR_gDQ = thread_t2r.partition_D(gDQ);
Tensor tTR_sDQ = thread_t2r.partition_D(sDQ);
Tensor tTR_tDQ = thread_t2r.partition_S(tDQtDQ);
auto block_tma = mainloop_params.tma_red_dq.get_slice(_0{});
Tensor tDQsDQ = block_tma.partition_S(sDQ);
Tensor tDQcDQ = block_tma.partition_S(cDQ);
Tensor tDQgDQ = block_tma.partition_D(gDQ);
int lane_predicate = (threadIdx.x % (kNumReduceWarps * NumThreadsPerWarp)) == 0;
while (iter_count > 0) {
pipeline_mma_reduce_dq.consumer_wait(pipeline_mma_reduce_dq_consumer_state);
Tensor tTR_rDQ = make_tensor<ElementAcc>(shape(tTR_cDQ));
// load dQ from tmem to rmem
cute::copy(tiled_t2r, tTR_tDQ, tTR_rDQ);
cutlass::arch::fence_view_async_tmem_load();
pipeline_mma_reduce_dq.consumer_release(pipeline_mma_reduce_dq_consumer_state);
++pipeline_mma_reduce_dq_consumer_state;
// we don't have enough smem to dump it all to smem, so we do it in stages
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<2>(tTR_cDQ); i++) {
if (lane_predicate) {
pipeline_reduce_tma_store.producer_acquire(pipeline_reduce_tma_store_producer_state);
}
// wait in all threads for the acquire to complete
cutlass::arch::NamedBarrier(
kNumReduceWarps * NumThreadsPerWarp,
cutlass::arch::ReservedNamedBarriers::TransposeBarrier
).arrive_and_wait();
cute::copy(tTR_rDQ(_, _, i), tTR_sDQ(_, _, _0{}, pipeline_reduce_tma_store_producer_state.index()));
// wait for the stores to all be visible to the TMA
cutlass::arch::fence_view_async_shared();
cutlass::arch::NamedBarrier(
kNumReduceWarps * NumThreadsPerWarp,
cutlass::arch::ReservedNamedBarriers::TransposeBarrier
).arrive_and_wait();
if (lane_predicate) {
// launch tma store
copy(mainloop_params.tma_red_dq, tDQsDQ(_,_,_0{}, pipeline_reduce_tma_store_producer_state.index()), tDQgDQ(_,_,i,iter_index));
pipeline_reduce_tma_store.producer_commit(pipeline_reduce_tma_store_producer_state);
}
++pipeline_reduce_tma_store_producer_state;
}
iter_count -= 1;
iter_index += 1;
}
}
CUTLASS_DEVICE void operator()(Params const& params, char* smem) {
#if defined(KERUTILS_ENABLE_SM100A)
int warp_idx = cutlass::canonical_warp_idx_sync();
auto role = warp_idx_to_role(warp_idx);
uint32_t lane_predicate = cute::elect_one_sync();
if (role == WarpRole::Load && lane_predicate) {
prefetch_tma_descriptor(params.mainloop_params.tma_load_q.get_tma_descriptor());
prefetch_tma_descriptor(params.mainloop_params.tma_load_k.get_tma_descriptor());
prefetch_tma_descriptor(params.mainloop_params.tma_load_v.get_tma_descriptor());
prefetch_tma_descriptor(params.mainloop_params.tma_load_do.get_tma_descriptor());
}
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem);
int initializing_warp = 0;
typename PipelineLoadMmaQ::Params pipeline_load_mma_q_params;
if (role == WarpRole::Load) {
pipeline_load_mma_q_params.role = PipelineLoadMmaQ::ThreadCategory::Producer;
}
if (role == WarpRole::Mma) {
pipeline_load_mma_q_params.role = PipelineLoadMmaQ::ThreadCategory::Consumer;
}
pipeline_load_mma_q_params.is_leader = lane_predicate && (role == WarpRole::Load);
// Also loads K in the first iteration
pipeline_load_mma_q_params.transaction_bytes = kTransactionsBytesLoadQ;
pipeline_load_mma_q_params.initializing_warp = initializing_warp++;
PipelineLoadMmaQ pipeline_load_mma_q(shared_storage.pipelines.load_mma_q, pipeline_load_mma_q_params,
ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});
typename PipelineLoadMmaDO::Params pipeline_load_mma_do_params;
if (role == WarpRole::Load) {
pipeline_load_mma_do_params.role = PipelineLoadMmaDO::ThreadCategory::Producer;
}
if (role == WarpRole::Mma) {
pipeline_load_mma_do_params.role = PipelineLoadMmaDO::ThreadCategory::Consumer;
}
pipeline_load_mma_do_params.is_leader = lane_predicate && (role == WarpRole::Load);
// Also loads V in the first iteration
pipeline_load_mma_do_params.transaction_bytes = kTransactionsBytesLoadDO;
pipeline_load_mma_do_params.initializing_warp = initializing_warp++;
PipelineLoadMmaDO pipeline_load_mma_do(shared_storage.pipelines.load_mma_do, pipeline_load_mma_do_params,
ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});
typename PipelineLoadComputeLSE::Params pipeline_load_compute_lse_params;
if (role == WarpRole::Load) {
pipeline_load_compute_lse_params.role = PipelineLoadComputeLSE::ThreadCategory::Producer;
}
if (role == WarpRole::Compute) {
pipeline_load_compute_lse_params.role = PipelineLoadComputeLSE::ThreadCategory::Consumer;
}
pipeline_load_compute_lse_params.producer_arv_count = NumThreadsPerWarp;
pipeline_load_compute_lse_params.consumer_arv_count = kNumComputeWarps * NumThreadsPerWarp;
pipeline_load_compute_lse_params.initializing_warp = initializing_warp++;
PipelineLoadComputeLSE pipeline_load_compute_lse(
shared_storage.pipelines.load_compute_lse,
pipeline_load_compute_lse_params,
/*barrier init*/ cute::true_type{});
typename PipelineLoadComputeSumOdO::Params pipeline_load_compute_sum_odo_params;
if (role == WarpRole::Load) {
pipeline_load_compute_sum_odo_params.role = PipelineLoadComputeSumOdO::ThreadCategory::Producer;
}
if (role == WarpRole::Compute) {
pipeline_load_compute_sum_odo_params.role = PipelineLoadComputeSumOdO::ThreadCategory::Consumer;
}
pipeline_load_compute_sum_odo_params.producer_arv_count = NumThreadsPerWarp;
pipeline_load_compute_sum_odo_params.consumer_arv_count = kNumComputeWarps * NumThreadsPerWarp;
pipeline_load_compute_sum_odo_params.initializing_warp = initializing_warp++;
PipelineLoadComputeSumOdO pipeline_load_compute_sum_odo(
shared_storage.pipelines.load_compute_sum_odo,
pipeline_load_compute_sum_odo_params,
/*barrier init*/ cute::true_type{});
typename PipelineMmaComputeS::Params pipeline_mma_compute_s_params;
if (role == WarpRole::Mma) {
pipeline_mma_compute_s_params.role = PipelineMmaComputeS::ThreadCategory::Producer;
}
if (role == WarpRole::Compute) {
pipeline_mma_compute_s_params.role = PipelineMmaComputeS::ThreadCategory::Consumer;
}
pipeline_mma_compute_s_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp;
pipeline_mma_compute_s_params.initializing_warp = initializing_warp++;
PipelineMmaComputeS pipeline_mma_compute_s(
shared_storage.pipelines.mma_compute_s,
pipeline_mma_compute_s_params,
ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});
typename PipelineMmaComputeDP::Params pipeline_mma_compute_dp_params;
if (role == WarpRole::Mma) {
pipeline_mma_compute_dp_params.role = PipelineMmaComputeDP::ThreadCategory::Producer;
}
if (role == WarpRole::Compute) {
pipeline_mma_compute_dp_params.role = PipelineMmaComputeDP::ThreadCategory::Consumer;
}
pipeline_mma_compute_dp_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp;
pipeline_mma_compute_dp_params.initializing_warp = initializing_warp++;
PipelineMmaComputeDP pipeline_mma_compute_dp(
shared_storage.pipelines.mma_compute_dp,
pipeline_mma_compute_dp_params,
ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});
typename PipelineMmaReduceDQ::Params pipeline_mma_reduce_dq_params;
if (role == WarpRole::Mma) {
pipeline_mma_reduce_dq_params.role = PipelineMmaReduceDQ::ThreadCategory::Producer;
}
if (role == WarpRole::Reduce) {
pipeline_mma_reduce_dq_params.role = PipelineMmaReduceDQ::ThreadCategory::Consumer;
}
pipeline_mma_reduce_dq_params.consumer_arv_count = kNumReduceWarps * cutlass::NumThreadsPerWarp;
pipeline_mma_reduce_dq_params.initializing_warp = initializing_warp++;
PipelineMmaReduceDQ pipeline_mma_reduce_dq(
shared_storage.pipelines.mma_reduce_dq,
pipeline_mma_reduce_dq_params,
ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});
typename PipelineComputeMmaP::Params pipeline_compute_mma_p_params;
if (role == WarpRole::Mma) {
pipeline_compute_mma_p_params.role = PipelineComputeMmaP::ThreadCategory::Consumer;
}
if (role == WarpRole::Compute) {
pipeline_compute_mma_p_params.role = PipelineComputeMmaP::ThreadCategory::Producer;
}
pipeline_compute_mma_p_params.producer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp;
pipeline_compute_mma_p_params.consumer_arv_count = 1;
pipeline_compute_mma_p_params.initializing_warp = initializing_warp++;
PipelineComputeMmaP pipeline_compute_mma_p(
shared_storage.pipelines.compute_mma_p,
pipeline_compute_mma_p_params,
ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});
typename PipelineComputeMmaDS::Params pipeline_compute_mma_ds_params;
if (role == WarpRole::Mma) {
pipeline_compute_mma_ds_params.role = PipelineComputeMmaDS::ThreadCategory::Consumer;
}
if (role == WarpRole::Compute) {
pipeline_compute_mma_ds_params.role = PipelineComputeMmaDS::ThreadCategory::Producer;
}
pipeline_compute_mma_ds_params.producer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp;
pipeline_compute_mma_ds_params.consumer_arv_count = 1;
pipeline_compute_mma_ds_params.initializing_warp = initializing_warp++;
PipelineComputeMmaDS pipeline_compute_mma_ds(
shared_storage.pipelines.compute_mma_ds,
pipeline_compute_mma_ds_params,
ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});
typename PipelineMmaComputeDKDV::Params pipeline_mma_compute_dkdv_params;
if (role == WarpRole::Mma) {
pipeline_mma_compute_dkdv_params.role = PipelineMmaComputeDKDV::ThreadCategory::Producer;
}
if (role == WarpRole::Compute) {
pipeline_mma_compute_dkdv_params.role = PipelineMmaComputeDKDV::ThreadCategory::Consumer;
}
pipeline_mma_compute_dkdv_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp;
pipeline_mma_compute_dkdv_params.initializing_warp = initializing_warp++;
PipelineMmaComputeDKDV pipeline_mma_compute_dkdv(
shared_storage.pipelines.mma_compute_dkdv,
pipeline_mma_compute_dkdv_params,
ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});
PipelineReduceTmaStore pipeline_reduce_tma_store;
TmemAllocator tmem_allocator;
pipeline_init_arrive_relaxed(size(ClusterShape{}));
pipeline_load_mma_q.init_masks(ClusterShape{});
pipeline_load_mma_do.init_masks(ClusterShape{});
pipeline_mma_compute_s.init_masks(ClusterShape{});
pipeline_mma_compute_dp.init_masks(ClusterShape{});
pipeline_mma_reduce_dq.init_masks(ClusterShape{});
pipeline_compute_mma_p.init_masks(ClusterShape{});
pipeline_compute_mma_ds.init_masks(ClusterShape{});
pipeline_mma_compute_dkdv.init_masks(ClusterShape{});
typename decltype(pipeline_load_mma_q)::PipelineState pipeline_load_mma_q_consumer_state;
typename decltype(pipeline_load_mma_do)::PipelineState pipeline_load_mma_do_consumer_state;
typename decltype(pipeline_load_compute_lse)::PipelineState pipeline_load_compute_lse_consumer_state;
typename decltype(pipeline_load_compute_sum_odo)::PipelineState pipeline_load_compute_sum_odo_consumer_state;
typename decltype(pipeline_mma_compute_s)::PipelineState pipeline_mma_compute_s_consumer_state;
typename decltype(pipeline_mma_compute_dp)::PipelineState pipeline_mma_compute_dp_consumer_state;
typename decltype(pipeline_mma_reduce_dq)::PipelineState pipeline_mma_reduce_dq_consumer_state;
typename decltype(pipeline_compute_mma_p)::PipelineState pipeline_compute_mma_p_consumer_state;
typename decltype(pipeline_compute_mma_ds)::PipelineState pipeline_compute_mma_ds_consumer_state;
typename decltype(pipeline_mma_compute_dkdv)::PipelineState pipeline_mma_compute_dkdv_consumer_state;
auto pipeline_load_mma_q_producer_state = make_producer_start_state<decltype(pipeline_load_mma_q)>();
auto pipeline_load_mma_do_producer_state = make_producer_start_state<decltype(pipeline_load_mma_do)>();
auto pipeline_load_compute_lse_producer_state = make_producer_start_state<decltype(pipeline_load_compute_lse)>();
auto pipeline_load_compute_sum_odo_producer_state = make_producer_start_state<decltype(pipeline_load_compute_sum_odo)>();
auto pipeline_mma_compute_s_producer_state = make_producer_start_state<decltype(pipeline_mma_compute_s)>();
auto pipeline_mma_compute_dp_producer_state = make_producer_start_state<decltype(pipeline_mma_compute_dp)>();
auto pipeline_mma_reduce_dq_producer_state = make_producer_start_state<decltype(pipeline_mma_reduce_dq)>();
auto pipeline_compute_mma_p_producer_state = make_producer_start_state<decltype(pipeline_compute_mma_p)>();
auto pipeline_compute_mma_ds_producer_state = make_producer_start_state<decltype(pipeline_compute_mma_ds)>();
auto pipeline_mma_compute_dkdv_producer_state = make_producer_start_state<decltype(pipeline_mma_compute_dkdv)>();
auto pipeline_reduce_tma_store_producer_state = make_producer_start_state<decltype(pipeline_reduce_tma_store)>();
pipeline_init_wait(size(ClusterShape{}));
auto blk_coord = make_coord(_0{}, blockIdx.x, _0{}, _0{}, make_coord(blockIdx.y, blockIdx.z));
auto [problem_shape, blk_offset] = apply_variable_length_offset(
params.problem_shape,
blk_coord
);
int iter_count = ceil_div(get<0>(problem_shape), TileShapeQ{});
int iter_start = 0;
if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<true>, Mask>) {
iter_start = (get<1>(blk_coord) * TileShapeK{}) / TileShapeQ{};
} else if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<false>, Mask>) {
int offset = get<1>(problem_shape) - get<0>(problem_shape);
iter_start = max(0, (int(get<1>(blk_coord) * TileShapeK{}) - offset) / (int)TileShapeQ{});
}
if (get<1>(blk_coord) * TileShapeK{} >= get<1>(problem_shape)) {
return;
}
iter_count -= iter_start;
if (iter_count <= 0) {
epilogue_clear(
blk_coord,
blk_offset,
problem_shape,
params.mainloop,
params.epilogue
);
return;
}
if (role == WarpRole::Load) {
warpgroup_reg_set<RegisterAllocation::kLoad>();
load(
blk_coord,
blk_offset,
problem_shape,
iter_start,
iter_count,
params.mainloop,
params.mainloop_params,
shared_storage.tensors,
pipeline_load_mma_q, pipeline_load_mma_q_producer_state,
pipeline_load_mma_do, pipeline_load_mma_do_producer_state,
pipeline_load_compute_lse, pipeline_load_compute_lse_producer_state,
pipeline_load_compute_sum_odo, pipeline_load_compute_sum_odo_producer_state
);
}
else if (role == WarpRole::Mma) {
warpgroup_reg_set<RegisterAllocation::kMma>();
tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr);
__syncwarp();
mma(
blk_coord,
problem_shape,
iter_start,
iter_count,
params.mainloop,
shared_storage.tensors,
pipeline_load_mma_q, pipeline_load_mma_q_consumer_state,
pipeline_load_mma_do, pipeline_load_mma_do_consumer_state,
pipeline_mma_compute_s, pipeline_mma_compute_s_producer_state,
pipeline_mma_compute_dp, pipeline_mma_compute_dp_producer_state,
pipeline_mma_reduce_dq, pipeline_mma_reduce_dq_producer_state,
pipeline_compute_mma_p, pipeline_compute_mma_p_consumer_state,
pipeline_compute_mma_ds, pipeline_compute_mma_ds_consumer_state,
pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_producer_state
);
}
else if (role == WarpRole::Compute) {
warpgroup_reg_set<RegisterAllocation::kCompute>();
compute(
blk_coord,
blk_offset,
problem_shape,
iter_start,
iter_count,
params.mainloop,
params.epilogue,
shared_storage.tensors,
pipeline_load_compute_lse, pipeline_load_compute_lse_consumer_state,
pipeline_load_compute_sum_odo, pipeline_load_compute_sum_odo_consumer_state,
pipeline_mma_compute_s, pipeline_mma_compute_s_consumer_state,
pipeline_mma_compute_dp, pipeline_mma_compute_dp_consumer_state,
pipeline_compute_mma_p, pipeline_compute_mma_p_producer_state,
pipeline_compute_mma_ds, pipeline_compute_mma_ds_producer_state,
pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_consumer_state
);
cutlass::arch::NamedBarrier(
kNumComputeWarps * NumThreadsPerWarp,
cutlass::arch::ReservedNamedBarriers::EpilogueBarrier
).arrive_and_wait();
if (warp_idx % kNumComputeWarps == 0) {
uint32_t free_stage_ptr = shared_storage.tmem_base_ptr;
tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns);
}
}
else if (role == WarpRole::Reduce) {
warpgroup_reg_set<RegisterAllocation::kReduce>();
reduce(
blk_coord,
problem_shape,
iter_start,
iter_count,
params.mainloop,
params.mainloop_params,
shared_storage.tensors,
pipeline_mma_reduce_dq, pipeline_mma_reduce_dq_consumer_state,
pipeline_reduce_tma_store, pipeline_reduce_tma_store_producer_state
);
pipeline_reduce_tma_store.producer_tail(pipeline_reduce_tma_store_producer_state);
}
else {
warpgroup_reg_set<RegisterAllocation::kEmpty>();
/* no-op */
}
#else
if (cute::thread0()) {
CUTE_INVALID_CONTROL_PATH("This kernel only supports sm100\n");
}
#endif
}
static dim3 get_block_shape() {
dim3 block(MaxThreadsPerBlock, 1, 1);
return block;
}
static dim3 get_grid_shape(Params const& params) {
auto [Q, K, D, D_VO, HB] = params.problem_shape;
auto [H, B] = HB;
dim3 grid(ceil_div(K, TileShapeK{}), H, B);
return grid;
}
};
} // namespace cutlass::fmha::kernel
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cute/arch/simd_sm100.hpp"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/memory_sm80.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include <kerutils/kerutils.cuh> // for KERUTILS_ENABLE_SM100A
#include "../collective/fmha_common.hpp"
#include <cmath>
namespace cutlass::fmha::kernel {
using namespace cutlass::fmha::collective;
using namespace cute;
template<
class ProblemShape,
class Element,
class ElementAcc,
class TileShape,
class Mask
>
struct Sm100FmhaBwdMlaKernelTmaWarpSpecialized {
using TileShapeQ = decltype(get<0>(TileShape{}));
using TileShapeK = decltype(get<1>(TileShape{}));
using TileShapeDQK = decltype(get<2>(TileShape{}));
using TileShapeDVO = decltype(get<3>(TileShape{}));
using TmemAllocator = cute::TMEM::Allocator1Sm;
struct TmemAllocation {
static constexpr uint32_t kDK = 0; // TileShapeK x TileShapeDQK x acc
static constexpr uint32_t kDV = kDK + TileShapeDQK{}; // TileShapeK x TileShapeDVO x acc
static constexpr uint32_t kDQ = kDV + TileShapeDVO{}; // TileShapeQ x TileShapeDQK x acc
static constexpr uint32_t kDP = kDQ; // TileShapeK x TileShapeQ x inp
static constexpr uint32_t kS = kDQ + 65536 * 16;
static constexpr uint32_t kP = kS;
static constexpr uint32_t kTotal = kDQ + TileShapeDQK{};
};
static_assert(
static_cast<int>(TmemAllocation::kTotal) <= TmemAllocator::Sm100TmemCapacityColumns,
"using too much tmem"
);
enum class WarpRole {
Empty = 0x0, Load = 0x1, Mma = 0x2, Compute = 0x3, Reduce = 0x4
};
static constexpr unsigned long long kWarpAssignment = 0x12'3333'3333'4444ull;
static constexpr int kNumComputeWarps = 8;
static constexpr int kNumReduceWarps = 4;
static constexpr int kLoadPerThread = TileShapeQ{} / NumThreadsPerWarp;
static_assert(TileShapeQ{} % NumThreadsPerWarp == 0, "TileShapeQ must be divisible by NumThreadsPerWarp");
CUTLASS_DEVICE WarpRole warp_idx_to_role(int warp_idx) {
return static_cast<WarpRole>((kWarpAssignment >> (4 * warp_idx)) & 0xF);
}
struct RegisterAllocation {
static constexpr int kWarpgroup0 = 160-8;
static constexpr int kWarpgroup1 = 128;
static constexpr int kWarpgroup2 = 96;
static constexpr int kReduce = kWarpgroup0;
static constexpr int kCompute = kWarpgroup1;
static constexpr int kMma = kWarpgroup2;
static constexpr int kEmpty = kWarpgroup2;
static constexpr int kLoad = kWarpgroup2;
static_assert(kWarpgroup0 + 2 * kWarpgroup1 + kWarpgroup2 <= 512);
};
using ArchTag = cutlass::arch::Sm100;
using ClusterShape = Shape<_1, _1, _1>;
using Schedule = cutlass::gemm::KernelTmaWarpSpecialized1SmSm100;
static constexpr int MinBlocksPerMultiprocessor = 1;
static constexpr int kNumWarps = kNumComputeWarps + kNumReduceWarps + 4;
static constexpr int MaxThreadsPerBlock = NumThreadsPerWarp * kNumWarps;
static constexpr int Alignment = 128 / sizeof_bits_v<Element>;
static constexpr int kStages = 2;
using TensorStrideContiguousK = Stride<int, _1, Stride<int, int>>;
using TensorStrideContiguousMN = Stride<_1, int, Stride<int, int>>;
// compute S
using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
Element, TensorStrideContiguousK, Alignment,
Element, TensorStrideContiguousK, Alignment,
ElementAcc,
Shape<TileShapeQ, TileShapeK, TileShapeDQK>,
ClusterShape, cutlass::gemm::collective::StageCount<kStages>,
Schedule>::CollectiveOp;
using TileShapeQK = typename CollectiveMmaQK::TileShape;
using TiledMmaQK = typename CollectiveMmaQK::TiledMma;
// compute dP
using CollectiveMmaDOV = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
Element, TensorStrideContiguousK, Alignment,
Element, TensorStrideContiguousK, Alignment,
ElementAcc,
Shape<TileShapeQ, TileShapeK, TileShapeDVO>,
ClusterShape, cutlass::gemm::collective::StageCount<kStages>,
Schedule>::CollectiveOp;
using TileShapeDOV = typename CollectiveMmaDOV::TileShape;
using TiledMmaDOV = typename CollectiveMmaDOV::TiledMma;
// compute dV
using CollectiveMmaPDO = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
// needs to match ordering of S calculation
Element, TensorStrideContiguousK, Alignment,
Element, TensorStrideContiguousMN, Alignment,
ElementAcc,
Shape<TileShapeK, TileShapeDVO, TileShapeQ>,
ClusterShape, cutlass::gemm::collective::StageCount<kStages>,
Schedule>::CollectiveOp;
using TileShapePDO = typename CollectiveMmaPDO::TileShape;
using TiledMmaPDO = typename CollectiveMmaPDO::TiledMma;
// compute dK
using CollectiveMmaDSQ = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
// somewhat arbitrary since we dump to smem, need to agree with the next one
Element, TensorStrideContiguousK , Alignment,
Element, TensorStrideContiguousMN, Alignment,
ElementAcc,
Shape<TileShapeK, TileShapeDQK, TileShapeQ>,
ClusterShape, cutlass::gemm::collective::StageCount<kStages>,
Schedule>::CollectiveOp;
using TileShapeDSQ = typename CollectiveMmaDSQ::TileShape;
using TiledMmaDSQ = typename CollectiveMmaDSQ::TiledMma;
// compute dQ
using CollectiveMmaDSK = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
// somewhat arbitrary since we dump to smem, need to agree with the previous one
Element, TensorStrideContiguousMN, Alignment,
Element, TensorStrideContiguousMN, Alignment,
ElementAcc,
Shape<TileShapeQ, TileShapeDQK, TileShapeK>,
ClusterShape, cutlass::gemm::collective::StageCount<kStages>,
Schedule>::CollectiveOp;
using TileShapeDSK = typename CollectiveMmaDSK::TileShape;
using TiledMmaDSK = typename CollectiveMmaDSK::TiledMma;
// pipelines are named Pipeline<Producer><Consumer><Resource>
static constexpr int kStagesComputeSmem = 1;
using PipelineLoadMmaQ = PipelineTmaUmmaAsync<2, ClusterShape>;
using PipelineLoadMmaDO = PipelineTmaUmmaAsync<1, ClusterShape>;
using PipelineLoadComputeLSE = PipelineAsync<1>;
using PipelineLoadComputeSumOdO = PipelineAsync<1>;
using PipelineMmaComputeS = PipelineUmmaAsync<1>;
using PipelineMmaComputeDP = PipelineUmmaAsync<1>;
using PipelineMmaReduceDQ = PipelineUmmaAsync<1>;
using PipelineComputeMmaP = PipelineUmmaConsumerAsync<1>;
using PipelineComputeMmaDS = PipelineUmmaConsumerAsync<kStagesComputeSmem>;
using PipelineMmaComputeDKDV = PipelineUmmaAsync<2>;
static constexpr int kStagesReduceTmaStore = 2;
using PipelineReduceTmaStore = PipelineTmaStore<kStagesReduceTmaStore>;
struct PipelineStorage {
alignas(16) typename PipelineLoadMmaQ::SharedStorage load_mma_q;
alignas(16) typename PipelineLoadMmaDO::SharedStorage load_mma_do;
alignas(16) typename PipelineLoadComputeLSE::SharedStorage load_compute_lse;
alignas(16) typename PipelineLoadComputeSumOdO::SharedStorage load_compute_sum_odo;
alignas(16) typename PipelineMmaComputeS::SharedStorage mma_compute_s;
alignas(16) typename PipelineMmaComputeDP::SharedStorage mma_compute_dp;
alignas(16) typename PipelineMmaReduceDQ::SharedStorage mma_reduce_dq;
alignas(16) typename PipelineComputeMmaP::SharedStorage compute_mma_p;
alignas(16) typename PipelineComputeMmaDS::SharedStorage compute_mma_ds;
alignas(16) typename PipelineMmaComputeDKDV::SharedStorage mma_compute_dkdv;
};
template<class Layout, class Stages = _1>
static CUTE_DEVICE constexpr auto restage(Layout const& layout, Stages stages = {}) {
return composition(layout, make_tuple(_, _, _, make_layout(stages)));
}
using SmemLayoutK = decltype(restage(typename CollectiveMmaQK::SmemLayoutB{}));
using SmemLayoutV = decltype(restage(typename CollectiveMmaDOV::SmemLayoutB{}));
using SmemLayoutQ = decltype(restage(typename CollectiveMmaQK::SmemLayoutA{}, _2{}));
using SmemLayoutDO = decltype(restage(typename CollectiveMmaDOV::SmemLayoutA{}, _1{}));
using SmemLayoutDS = decltype(restage(typename CollectiveMmaDSK::SmemLayoutA{}, Int<kStagesComputeSmem>{}));
using SmemLayoutLSE = Layout<Shape<TileShapeQ, _1>>;
using SmemLayoutSumOdO = Layout<Shape<TileShapeQ, _1>>;
using SmemLayoutQT = decltype(restage(typename CollectiveMmaDSQ::SmemLayoutB{}, _2{}));
using SmemLayoutKT = decltype(restage(typename CollectiveMmaDSK::SmemLayoutB{}));
using SmemLayoutDST = decltype(restage(typename CollectiveMmaDSQ::SmemLayoutA{}, Int<kStagesComputeSmem>{}));
using SmemLayoutDOT = decltype(restage(typename CollectiveMmaPDO::SmemLayoutB{}, _1{}));
using SmemLayoutP = decltype(restage(typename CollectiveMmaPDO::SmemLayoutA{}, _1{}));
using SmemLayoutPT = decltype(restage(typename CollectiveMmaDSK::SmemLayoutA{}, _1{}));
using TileShapeDQ = _32;
using SmemAtomDQ = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<
cute::UMMA::Major::K, ElementAcc, TileShapeQ, TileShapeDQ
>());
using SmemShapeDQ = Shape<TileShapeQ, TileShapeDQ, Int<kStagesReduceTmaStore>>;
using SmemLayoutDQ = decltype(tile_to_shape(SmemAtomDQ{}, SmemShapeDQ{}, Step<_2, _1, _3>{}));
struct TensorStorage {
union {
alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutK>> smem_k;
alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutKT>> smem_k_t;
};
alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutV>> smem_v;
union {
alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutQT>> smem_q_t;
};
union {
alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutDO>> smem_do;
alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutDOT>> smem_do_t;
};
union {
alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutDS>> smem_ds;
alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutDST>> smem_ds_t;
};
union{
alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutP>> smem_p;
alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutPT>> smem_p_t;
};
alignas(1024) cute::array<ElementAcc, cute::cosize_v<SmemLayoutDQ>> smem_dq;
alignas(16) cute::array<ElementAcc, cute::cosize_v<SmemLayoutLSE>> smem_lse;
alignas(16) cute::array<ElementAcc, cute::cosize_v<SmemLayoutSumOdO>> smem_sum_odo;
};
static constexpr int kTransactionsBytesLoadQ = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutQ{})) * cute::sizeof_bits_v<Element>);
static constexpr int kTransactionsBytesLoadDO = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutDO{})) * cute::sizeof_bits_v<Element>);
static constexpr int kTransactionsBytesLoadK = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v<Element>);
static constexpr int kTransactionsBytesLoadV = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v<Element>);
struct SharedStorage {
TensorStorage tensors;
PipelineStorage pipelines;
uint32_t tmem_base_ptr;
};
// this is tight enough that it won't work with sizeof due to padding for alignment
static constexpr int SharedStorageSize = offsetof(SharedStorage, tmem_base_ptr) + sizeof(uint32_t);
static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "using too much smem");
using TensorStride = TensorStrideContiguousK; // S D (H B)
using RowTensorStride = Stride<_1, Stride<int, int>>; // S (H B)
struct MainloopArguments {
const Element* ptr_q;
TensorStride stride_q;
const Element* ptr_k;
TensorStride stride_k;
const Element* ptr_v;
TensorStride stride_v;
const Element* ptr_do;
TensorStride stride_do;
const ElementAcc* ptr_lse;
RowTensorStride stride_lse;
const ElementAcc* ptr_sum_odo;
RowTensorStride stride_sum_odo;
ElementAcc* ptr_dq_acc;
TensorStride stride_dq_acc;
ElementAcc softmax_scale = 1.0f / sqrtf(TileShapeDQK{});
};
using TMA_K = typename CollectiveMmaQK::Params::TMA_B;
using TMA_V = typename CollectiveMmaDOV::Params::TMA_B;
using TMA_Q = typename CollectiveMmaQK::Params::TMA_A;
using TMA_DO = typename CollectiveMmaDOV::Params::TMA_A;
using TMA_DQ = decltype(make_tma_copy(SM90_TMA_REDUCE_ADD{},
make_tensor((const ElementAcc*)nullptr, make_shape(1, 1, make_shape(1, 1)), TensorStride{}),
SmemLayoutDQ{}(_, _, _0{})
));
struct MainloopParams {
TMA_K tma_load_k;
TMA_V tma_load_v;
TMA_Q tma_load_q;
TMA_DO tma_load_do;
TMA_DQ tma_red_dq;
};
struct EpilogueArguments {
Element* ptr_dk;
TensorStride stride_dk;
Element* ptr_dv;
TensorStride stride_dv;
};
struct Arguments {
ProblemShape problem_shape;
MainloopArguments mainloop;
EpilogueArguments epilogue;
KernelHardwareInfo hw_info;
};
struct Params {
ProblemShape problem_shape;
MainloopArguments mainloop;
MainloopParams mainloop_params;
EpilogueArguments epilogue;
KernelHardwareInfo hw_info;
};
static bool can_implement(Arguments const& args) {
auto [Q, K, D, D_VO, HB] = args.problem_shape;
auto [H, B] = HB;
if (Q <= 0 || K <= 0 || D <= 0 || H <= 0 || B <= 0 || D_VO <= 0) {
return false;
}
if (D % Alignment != 0 || D_VO % Alignment != 0) {
return false;
}
return true;
}
static Status initialize_workspace(Arguments const&, void*, cudaStream_t) {
return Status::kSuccess;
}
static Params to_underlying_arguments(Arguments const& args, void*) {
auto [Q_, K_, D, D_VO, HB] = args.problem_shape;
int Q = Q_;
int K = K_;
if constexpr (is_variable_length_v<decltype(Q_)>) {
Q = Q_.total_length;
}
if constexpr (is_variable_length_v<decltype(K_)>) {
K = K_.total_length;
}
auto params_kq = CollectiveMmaQK::to_underlying_arguments(
make_shape(Q, K, D, HB),
typename CollectiveMmaQK::Arguments {
args.mainloop.ptr_q, args.mainloop.stride_q,
args.mainloop.ptr_k, args.mainloop.stride_k,
}, /*workspace=*/nullptr);
auto params_vdo = CollectiveMmaDOV::to_underlying_arguments(
make_shape(Q, K, D_VO, HB),
typename CollectiveMmaDOV::Arguments {
args.mainloop.ptr_do, args.mainloop.stride_do,
args.mainloop.ptr_v, args.mainloop.stride_v,
}, /*workspace=*/nullptr);
TMA_DQ tma_red_dq = make_tma_copy(
SM90_TMA_REDUCE_ADD{},
make_tensor(args.mainloop.ptr_dq_acc, make_shape(Q_, D, HB), args.mainloop.stride_dq_acc),
SmemLayoutDQ{}(_, _, _0{})
);
return Params{
args.problem_shape,
args.mainloop,
MainloopParams{
params_kq.tma_load_b,
params_vdo.tma_load_b,
params_kq.tma_load_a,
params_vdo.tma_load_a,
tma_red_dq
},
args.epilogue,
args.hw_info
};
}
template<class T>
static CUTLASS_DEVICE auto quantize(T const& input) {
constexpr int AlignmentS = 4;
auto output = make_tensor<Element>(shape(input));
auto input_vec = recast<Array<ElementAcc, AlignmentS>>(input);
auto output_vec = recast<Array<Element, AlignmentS>>(output);
cutlass::NumericArrayConverter<Element, ElementAcc, AlignmentS> epilogue_op;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(input_vec); i++) {
output_vec(i) = epilogue_op(input_vec(i));
}
return output;
}
template<class BlkCoord, class BlkOffset, class ProblemShape_>
CUTLASS_DEVICE void load(
BlkCoord const& blk_coord,
BlkOffset const& blk_offset,
ProblemShape_ const& problem_shape,
int iter_index,
int iter_count,
MainloopArguments const& mainloop_args,
MainloopParams const& mainloop_params,
TensorStorage& shared_tensors,
PipelineLoadMmaQ& pipeline_load_mma_q,
typename PipelineLoadMmaQ::PipelineState& pipeline_load_mma_q_producer_state,
PipelineLoadMmaDO& pipeline_load_mma_do,
typename PipelineLoadMmaDO::PipelineState& pipeline_load_mma_do_producer_state,
PipelineLoadComputeLSE& pipeline_load_compute_lse,
typename PipelineLoadComputeLSE::PipelineState& pipeline_load_compute_lse_producer_state,
PipelineLoadComputeSumOdO& pipeline_load_compute_sum_odo,
typename PipelineLoadComputeSumOdO::PipelineState& pipeline_load_compute_sum_odo_producer_state) {
auto [Q, K, D, D_VO, HB] = problem_shape;
using X = Underscore;
uint16_t mcast_mask = 0;
auto mK_in = mainloop_params.tma_load_k.get_tma_tensor(make_shape(K, D, HB));
auto mV_in = mainloop_params.tma_load_v.get_tma_tensor(make_shape(K, D_VO, HB));
auto mQ_in = mainloop_params.tma_load_q.get_tma_tensor(make_shape(Q, D, HB));
auto mDO_in = mainloop_params.tma_load_do.get_tma_tensor(make_shape(Q, D_VO, HB));
auto mK = domain_offset(select<1,2,4>(blk_offset), mK_in);
auto mV = domain_offset(select<1,3,4>(blk_offset), mV_in);
auto mQ = domain_offset(select<0,2,4>(blk_offset), mQ_in);
auto mDO = domain_offset(select<0,3,4>(blk_offset), mDO_in);
auto gK = local_tile(mK, TileShapeQK{}, make_coord(_,_,_), Step<X, _1, _1>{});
auto gQ = local_tile(mQ, TileShapeQK{}, make_coord(_,_,_), Step<_1, X, _1>{});
auto gV = local_tile(mV, TileShapeDOV{}, make_coord(_,_,_), Step<X, _1, _1>{});
auto gDO = local_tile(mDO, TileShapeDOV{}, make_coord(_,_,_), Step<_1, X, _1>{});
ThrMMA cta_mma_kq = TiledMmaQK{}.get_slice(_0{});
ThrMMA cta_mma_vdo = TiledMmaDOV{}.get_slice(_0{});
auto tSTgK = cta_mma_kq.partition_B(gK);
auto tSTgQ = cta_mma_kq.partition_A(gQ);
auto tDPTgV = cta_mma_vdo.partition_B(gV);
auto tDPTgDO = cta_mma_vdo.partition_A(gDO);
auto sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{});
auto sK = make_tensor(make_smem_ptr(shared_tensors.smem_k.begin()), SmemLayoutK{});
auto sV = make_tensor(make_smem_ptr(shared_tensors.smem_v.begin()), SmemLayoutV{});
auto sDO = make_tensor(make_smem_ptr(shared_tensors.smem_do.begin()), SmemLayoutDO{});
auto [tKgK_mkl, tKsK] = tma_partition(
mainloop_params.tma_load_k, _0{}, make_layout(_1{}),
group_modes<0,3>(sK), group_modes<0,3>(tSTgK));
auto [tQgQ_mkl, tQsQ] = tma_partition(
mainloop_params.tma_load_q, _0{}, make_layout(_1{}),
group_modes<0,3>(sQ), group_modes<0,3>(tSTgQ));
auto [tVgV_mkl, tVsV] = tma_partition(
mainloop_params.tma_load_v, _0{}, make_layout(_1{}),
group_modes<0,3>(sV), group_modes<0,3>(tDPTgV));
auto [tDOgDO_mkl, tDOsDO] = tma_partition(
mainloop_params.tma_load_do, _0{}, make_layout(_1{}),
group_modes<0,3>(sDO), group_modes<0,3>(tDPTgDO));
// set up lse and sum_odo
auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord;
pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state);
auto tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state);
pipeline_load_mma_q.producer_expect_transaction(pipeline_load_mma_q_producer_state, kTransactionsBytesLoadK);
// load K
if (cute::elect_one_sync()) {
cute::copy(
mainloop_params.tma_load_k.with(*tma_barrier, mcast_mask),
tKgK_mkl(_, blk_coord_k, _0{}, blk_coord_batch),
tKsK(_, _0{})
);
}
// load Q
if (cute::elect_one_sync()) {
cute::copy(
mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask),
tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch),
tQsQ(_, pipeline_load_mma_q_producer_state.index())
);
}
++pipeline_load_mma_q_producer_state;
pipeline_load_compute_lse.producer_acquire(pipeline_load_compute_lse_producer_state);
// load LSE
// 32 threads loading kLoadPerThread * 32 values of 32b each
int thread_idx = threadIdx.x % NumThreadsPerWarp;
int smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * kLoadPerThread;
int gmem_idx = TileShapeQ{} * iter_index + thread_idx * kLoadPerThread;
auto mLSE = make_tensor(mainloop_args.ptr_lse, make_shape(Q, HB), mainloop_args.stride_lse);
for (int i = 0; i < kLoadPerThread; i++) {
cutlass::arch::cp_async_zfill<4>(
shared_tensors.smem_lse.begin() + smem_idx + i,
&mLSE(gmem_idx + i, blk_coord_batch),
gmem_idx + i < Q
);
}
pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive);
++pipeline_load_compute_lse_producer_state;
pipeline_load_mma_do.producer_acquire(pipeline_load_mma_do_producer_state);
tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state);
pipeline_load_mma_do.producer_expect_transaction(pipeline_load_mma_do_producer_state, kTransactionsBytesLoadV);
// load V
if (cute::elect_one_sync()) {
cute::copy(
mainloop_params.tma_load_v.with(*tma_barrier, mcast_mask),
tVgV_mkl(_, blk_coord_k, _0{}, blk_coord_batch),
tVsV(_, _0{})
);
}
// load dO
if (cute::elect_one_sync()) {
cute::copy(
mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask),
tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch),
tDOsDO(_, pipeline_load_mma_do_producer_state.index())
);
}
++pipeline_load_mma_do_producer_state;
pipeline_load_compute_sum_odo.producer_acquire(pipeline_load_compute_sum_odo_producer_state);
// load sum_OdO
smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * kLoadPerThread;
gmem_idx = TileShapeQ{} * iter_index + thread_idx * kLoadPerThread;
auto mSumOdO = make_tensor(mainloop_args.ptr_sum_odo, make_shape(Q, HB), mainloop_args.stride_sum_odo);
for (int i = 0; i < kLoadPerThread; i++) {
cutlass::arch::cp_async_zfill<4>(
shared_tensors.smem_sum_odo.begin() + smem_idx + i,
&mSumOdO(gmem_idx + i, blk_coord_batch),
gmem_idx + i < Q
);
}
pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive);
++pipeline_load_compute_sum_odo_producer_state;
iter_count -= 1;
iter_index += 1;
while (iter_count > 0) {
pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state);
tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state);
// load Q
if (cute::elect_one_sync()) {
cute::copy(
mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask),
tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch),
tQsQ(_, pipeline_load_mma_q_producer_state.index())
);
}
++pipeline_load_mma_q_producer_state;
pipeline_load_compute_lse.producer_acquire(pipeline_load_compute_lse_producer_state);
// load LSE
smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * kLoadPerThread;
gmem_idx = TileShapeQ{} * iter_index + thread_idx * kLoadPerThread;
for (int i = 0; i < kLoadPerThread; i++) {
cutlass::arch::cp_async_zfill<4>(
shared_tensors.smem_lse.begin() + smem_idx + i,
&mLSE(gmem_idx + i, blk_coord_batch),
gmem_idx + i < Q
);
}
pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive);
++pipeline_load_compute_lse_producer_state;
pipeline_load_mma_do.producer_acquire(pipeline_load_mma_do_producer_state);
tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state);
// load dO
if (cute::elect_one_sync()) {
cute::copy(
mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask),
tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch),
tDOsDO(_, pipeline_load_mma_do_producer_state.index())
);
}
++pipeline_load_mma_do_producer_state;
pipeline_load_compute_sum_odo.producer_acquire(pipeline_load_compute_sum_odo_producer_state);
// load sum_OdO
smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * kLoadPerThread;
gmem_idx = TileShapeQ{} * iter_index + thread_idx * kLoadPerThread;
for (int i = 0; i < kLoadPerThread; i++) {
cutlass::arch::cp_async_zfill<4>(
shared_tensors.smem_sum_odo.begin() + smem_idx + i,
&mSumOdO(gmem_idx + i, blk_coord_batch),
gmem_idx + i < Q
);
}
pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive);
++pipeline_load_compute_sum_odo_producer_state;
iter_count -= 1;
iter_index += 1;
}
}
template<class BlkCoord, class ProblemShape_>
CUTLASS_DEVICE void mma(
BlkCoord const& blk_coord,
ProblemShape_ const& problem_shape,
int iter_index,
int iter_count,
MainloopArguments const& mainloop_args,
TensorStorage& shared_tensors,
PipelineLoadMmaQ& pipeline_load_mma_q,
typename PipelineLoadMmaQ::PipelineState& pipeline_load_mma_q_consumer_state,
PipelineLoadMmaDO& pipeline_load_mma_do,
typename PipelineLoadMmaDO::PipelineState& pipeline_load_mma_do_consumer_state,
PipelineMmaComputeS& pipeline_mma_compute_s,
typename PipelineMmaComputeS::PipelineState& pipeline_mma_compute_s_producer_state,
PipelineMmaComputeDP& pipeline_mma_compute_dp,
typename PipelineMmaComputeDP::PipelineState& pipeline_mma_compute_dp_producer_state,
PipelineMmaReduceDQ& pipeline_mma_reduce_dq,
typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_producer_state,
PipelineComputeMmaP& pipeline_compute_mma_p,
typename PipelineComputeMmaP::PipelineState& pipeline_compute_mma_p_consumer_state,
PipelineComputeMmaDS& pipeline_compute_mma_ds,
typename PipelineComputeMmaDS::PipelineState& pipeline_compute_mma_ds_consumer_state,
PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv,
typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_producer_state) {
auto [Q, K, D, D_VO, HB] = problem_shape;
auto sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{});
auto sK = make_tensor(make_smem_ptr(shared_tensors.smem_k.begin()), SmemLayoutK{});
auto sV = make_tensor(make_smem_ptr(shared_tensors.smem_v.begin()), SmemLayoutV{});
auto sDO = make_tensor(make_smem_ptr(shared_tensors.smem_do.begin()), SmemLayoutDO{});
auto sQT = make_tensor(make_smem_ptr(shared_tensors.smem_q_t.begin()), SmemLayoutQT{});
auto sKT = make_tensor(make_smem_ptr(shared_tensors.smem_k_t.begin()), SmemLayoutKT{});
auto sDS = make_tensor(make_smem_ptr(shared_tensors.smem_ds.begin()), SmemLayoutDS{});
auto sDST = make_tensor(make_smem_ptr(shared_tensors.smem_ds_t.begin()), SmemLayoutDST{});
auto sP = make_tensor(make_smem_ptr(shared_tensors.smem_p.begin()), SmemLayoutP{});
auto sDOT = make_tensor(make_smem_ptr(shared_tensors.smem_do_t.begin()), SmemLayoutDOT{});
Tensor tSTrK = TiledMmaQK::make_fragment_B(sK);
Tensor tSTrQ = TiledMmaQK::make_fragment_A(sQ);
Tensor tDPTrV = TiledMmaDOV::make_fragment_B(sV);
Tensor tDPTrDO = TiledMmaDOV::make_fragment_A(sDO);
Tensor tDQrDS = TiledMmaDSK::make_fragment_A(sDS);
Tensor tDQrKT = TiledMmaDSK::make_fragment_B(sKT);
Tensor tDKrDST = TiledMmaDSQ::make_fragment_A(sDST);
Tensor tDKrQT = TiledMmaDSQ::make_fragment_B(sQT);
Tensor tDVrP = TiledMmaPDO::make_fragment_A(sP);
Tensor tDVrDOT = TiledMmaPDO::make_fragment_B(sDOT);
TiledMmaQK tiled_mma_qk;
TiledMmaDOV tiled_mma_dov;
TiledMmaDSK tiled_mma_dsk;
TiledMmaDSQ tiled_mma_dsq;
TiledMmaPDO tiled_mma_pdo;
tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::Zero;
tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::Zero;
Tensor tSTtST = partition_fragment_C(tiled_mma_qk, select<0,1>(TileShapeQK{}));
tSTtST.data() = TmemAllocation::kS;
Tensor tDPTtDPT = partition_fragment_C(tiled_mma_dov, select<0,1>(TileShapeDOV{}));
tDPTtDPT.data() = TmemAllocation::kDP;
Tensor tDQtDQ = partition_fragment_C(tiled_mma_dsk, select<0,1>(TileShapeDSK{}));
tDQtDQ.data() = TmemAllocation::kDQ;
Tensor tDKtDK = partition_fragment_C(tiled_mma_dsq, select<0,1>(TileShapeDSQ{}));
tDKtDK.data() = TmemAllocation::kDK;
Tensor tDVtDV = partition_fragment_C(tiled_mma_pdo, select<0,1>(TileShapePDO{}));
tDVtDV.data() = TmemAllocation::kDV;
auto pipeline_load_mma_q_release_state = pipeline_load_mma_q_consumer_state;
pipeline_load_mma_q.consumer_wait(pipeline_load_mma_q_consumer_state);
pipeline_mma_compute_s.producer_acquire(pipeline_mma_compute_s_producer_state);
// S = Q*K
tiled_mma_qk.accumulate_ = UMMA::ScaleOut::Zero;
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tSTrQ); ++k_block) {
cute::gemm(tiled_mma_qk,
tSTrQ(_,_,k_block,pipeline_load_mma_q_consumer_state.index()),
tSTrK(_,_,k_block,_0{}),
tSTtST);
tiled_mma_qk.accumulate_ = UMMA::ScaleOut::One;
}
++pipeline_load_mma_q_consumer_state;
pipeline_mma_compute_s.producer_commit(pipeline_mma_compute_s_producer_state);
++pipeline_mma_compute_s_producer_state;
pipeline_load_mma_do.consumer_wait(pipeline_load_mma_do_consumer_state);
pipeline_mma_compute_dp.producer_acquire(pipeline_mma_compute_dp_producer_state);
pipeline_mma_reduce_dq.producer_acquire(pipeline_mma_reduce_dq_producer_state);
// dP = dO*V
tiled_mma_dov.accumulate_ = UMMA::ScaleOut::Zero;
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tDPTrV); ++k_block) {
cute::gemm(tiled_mma_dov,
tDPTrDO(_,_,k_block,pipeline_load_mma_do_consumer_state.index()),
tDPTrV(_,_,k_block,_0{}),
tDPTtDPT);
tiled_mma_dov.accumulate_ = UMMA::ScaleOut::One;
}
pipeline_mma_compute_dp.producer_commit(pipeline_mma_compute_dp_producer_state);
++pipeline_mma_compute_dp_producer_state;
pipeline_compute_mma_p.consumer_wait(pipeline_compute_mma_p_consumer_state);
// dV = P*dO
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tDVrP); ++k_block) {
cute::gemm(tiled_mma_pdo,
tDVrP(_,_,k_block,_0{}),
tDVrDOT(_,_,k_block,pipeline_load_mma_do_consumer_state.index()),
tDVtDV);
tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::One;
}
pipeline_compute_mma_p.consumer_release(pipeline_compute_mma_p_consumer_state);
++pipeline_compute_mma_p_consumer_state;
pipeline_load_mma_do.consumer_release(pipeline_load_mma_do_consumer_state);
++pipeline_load_mma_do_consumer_state;
iter_count -= 1;
// in tmem, S & P overlap
// and dP and dQ overlap
// so we need to acquire dQ and dP at the same time
while (iter_count > 0) {
pipeline_load_mma_q.consumer_wait(pipeline_load_mma_q_consumer_state);
pipeline_mma_compute_s.producer_acquire(pipeline_mma_compute_s_producer_state);
// S = Q*K
tiled_mma_qk.accumulate_ = UMMA::ScaleOut::Zero;
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tSTrQ); ++k_block) {
cute::gemm(tiled_mma_qk,
tSTrQ(_,_,k_block,pipeline_load_mma_q_consumer_state.index()),
tSTrK(_,_,k_block,_0{}),
tSTtST);
tiled_mma_qk.accumulate_ = UMMA::ScaleOut::One;
}
++pipeline_load_mma_q_consumer_state;
pipeline_mma_compute_s.producer_commit(pipeline_mma_compute_s_producer_state);
++pipeline_mma_compute_s_producer_state;
pipeline_compute_mma_ds.consumer_wait(pipeline_compute_mma_ds_consumer_state);
// we need to acquire dP here, because tmem dQ == tmem dP
pipeline_mma_compute_dp.producer_acquire(pipeline_mma_compute_dp_producer_state);
// dQ = dS*K
tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::Zero;
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tDQrDS); ++k_block) {
cute::gemm(tiled_mma_dsk,
tDQrDS(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()),
tDQrKT(_,_,k_block,_0{}),
tDQtDQ);
tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::One;
}
pipeline_mma_reduce_dq.producer_commit(pipeline_mma_reduce_dq_producer_state);
++pipeline_mma_reduce_dq_producer_state;
// dK = dS*Q
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tDKrDST); ++k_block) {
cute::gemm(tiled_mma_dsq,
tDKrDST(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()),
tDKrQT(_,_,k_block,pipeline_load_mma_q_release_state.index()),
tDKtDK);
tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::One;
}
pipeline_load_mma_q.consumer_release(pipeline_load_mma_q_release_state);
++pipeline_load_mma_q_release_state;
pipeline_compute_mma_ds.consumer_release(pipeline_compute_mma_ds_consumer_state);
++pipeline_compute_mma_ds_consumer_state;
// we grab dq here, because in tmem dq == dp
pipeline_mma_reduce_dq.producer_acquire(pipeline_mma_reduce_dq_producer_state);
pipeline_load_mma_do.consumer_wait(pipeline_load_mma_do_consumer_state);
// dP = dO*V
tiled_mma_dov.accumulate_ = UMMA::ScaleOut::Zero;
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tDPTrV); ++k_block) {
cute::gemm(tiled_mma_dov,
tDPTrDO(_,_,k_block,pipeline_load_mma_do_consumer_state.index()),
tDPTrV(_,_,k_block,_0{}),
tDPTtDPT);
tiled_mma_dov.accumulate_ = UMMA::ScaleOut::One;
}
pipeline_mma_compute_dp.producer_commit(pipeline_mma_compute_dp_producer_state);
++pipeline_mma_compute_dp_producer_state;
pipeline_compute_mma_p.consumer_wait(pipeline_compute_mma_p_consumer_state);
// dV = P*dO
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tDVrP); ++k_block) {
cute::gemm(tiled_mma_pdo,
tDVrP(_,_,k_block,_0{}),
tDVrDOT(_,_,k_block,pipeline_load_mma_do_consumer_state.index()),
tDVtDV);
tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::One;
}
pipeline_compute_mma_p.consumer_release(pipeline_compute_mma_p_consumer_state);
++pipeline_compute_mma_p_consumer_state;
pipeline_load_mma_do.consumer_release(pipeline_load_mma_do_consumer_state);
++pipeline_load_mma_do_consumer_state;
iter_count -= 1;
}
// signal to the epilogue that dV is ready
pipeline_mma_compute_dkdv.producer_acquire(pipeline_mma_compute_dkdv_producer_state);
pipeline_mma_compute_dkdv.producer_commit(pipeline_mma_compute_dkdv_producer_state);
++pipeline_mma_compute_dkdv_producer_state;
pipeline_mma_compute_dkdv.producer_acquire(pipeline_mma_compute_dkdv_producer_state);
pipeline_compute_mma_ds.consumer_wait(pipeline_compute_mma_ds_consumer_state);
// dK = dS*Q
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tDKrDST); ++k_block) {
cute::gemm(tiled_mma_dsq,
tDKrDST(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()),
tDKrQT(_,_,k_block,pipeline_load_mma_q_release_state.index()),
tDKtDK);
tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::One;
}
// signal to epilgue that dK is ready
pipeline_mma_compute_dkdv.producer_commit(pipeline_mma_compute_dkdv_producer_state);
++pipeline_mma_compute_dkdv_producer_state;
// we've already acquired mma_reduce_dq in the loop
// dQ = dS*K
tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::Zero;
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tDQrDS); ++k_block) {
cute::gemm(tiled_mma_dsk,
tDQrDS(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()),
tDQrKT(_,_,k_block,_0{}),
tDQtDQ);
tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::One;
}
pipeline_mma_reduce_dq.producer_commit(pipeline_mma_reduce_dq_producer_state);
++pipeline_mma_reduce_dq_producer_state;
pipeline_load_mma_q.consumer_release(pipeline_load_mma_q_release_state);
++pipeline_load_mma_q_release_state;
pipeline_compute_mma_ds.consumer_release(pipeline_compute_mma_ds_consumer_state);
++pipeline_compute_mma_ds_consumer_state;
}
template<class TensorG, class TensorR, class TensorC, class TensorShape>
CUTLASS_DEVICE void store(
TensorG gmem,
TensorR const& regs,
TensorC const& coord,
TensorShape const& tensor_shape) {
Tensor preds = cute::lazy::transform(coord, [&](auto const& c) { return elem_less(c, tensor_shape); });
auto copy_op = make_cotiled_copy(
Copy_Atom<UniversalCopy<uint128_t>, Element>{},
make_layout(make_shape(_1{}, Int<sizeof(uint128_t) / sizeof(Element)>{})),
regs.layout()
);
auto thr_copy = copy_op.get_slice(_0{});
Tensor quantized_regs = quantize(regs);
Tensor tCr = thr_copy.partition_S(quantized_regs);
Tensor tCg = thr_copy.partition_D(gmem);
Tensor tPc = thr_copy.partition_D(preds);
copy_if(copy_op, tPc, tCr, tCg);
}
template<class BlkCoord, class BlkOffset, class ProblemShape_>
CUTLASS_DEVICE void epilogue_clear(
BlkCoord const& blk_coord,
BlkOffset const& blk_offset,
ProblemShape_ const& problem_shape,
MainloopArguments const& mainloop_args,
EpilogueArguments const& epilogue_args) {
auto [Q, K, D, D_VO, HB] = problem_shape;
auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord;
auto mDK_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk);
auto mDK = domain_offset(select<1,2,4>(blk_offset), mDK_in);
auto gDK = local_tile(mDK, TileShapeDSQ{}, make_coord(_,_,_), Step<_1, _1, X>{})
(_, _, blk_coord_k, _0{}, blk_coord_batch);
Tensor cDK = domain_offset(
make_coord(get<1>(blk_coord) * TileShapeK{}, _0{}),
make_identity_tensor(take<0,2>(TileShapeDSQ{}))
);
auto mDV_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv);
auto mDV = domain_offset(select<1,3,4>(blk_offset), mDV_in);
auto gDV = local_tile(mDV, TileShapePDO{}, make_coord(_,_,_), Step<_1, _1, X>{})
(_, _, blk_coord_k, _0{}, blk_coord_batch);
Tensor cDV = domain_offset(
make_coord(blk_coord_k * TileShapeK{}, _0{}),
make_identity_tensor(take<0,2>(TileShapePDO{}))
);
for (int i = threadIdx.x; i < size(gDK); i += blockDim.x) {
if (elem_less(cDK(i), select<1,2>(problem_shape))) {
gDK(i) = Element(0);
}
}
for (int i = threadIdx.x; i < size(gDV); i += blockDim.x) {
if (elem_less(cDV(i), select<1,3>(problem_shape))) {
gDV(i) = Element(0);
}
}
}
template<class BlkCoord, class BlkOffset, class ProblemShape_>
CUTLASS_DEVICE void epilogue(
BlkCoord const& blk_coord,
BlkOffset const& blk_offset,
ProblemShape_ const& problem_shape,
MainloopArguments const& mainloop_args,
EpilogueArguments const& epilogue_args,
PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv,
typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) {
auto [Q, K, D, D_VO, HB] = problem_shape;
auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord;
auto load_op = SM100_TMEM_LOAD_32dp32b16x{};
auto tDKtDK = partition_fragment_C(TiledMmaDSQ{}, select<0,1>(TileShapeDSQ{}))(make_coord(_,_),_0{},_0{});
tDKtDK.data() = TmemAllocation::kDK;
auto mDK_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk);
auto mDK = domain_offset(select<1,2,4>(blk_offset), mDK_in);
auto gDK = local_tile(mDK, TileShapeDSQ{}, make_coord(_,_,_), Step<_1, _1, X>{})
(_, _, blk_coord_k, _0{}, blk_coord_batch);
Tensor cDK = domain_offset(
make_coord(get<1>(blk_coord) * TileShapeK{}, _0{}),
make_identity_tensor(take<0,2>(TileShapeDSQ{}))
);
constexpr int kNumWarpgroups = kNumComputeWarps / 4;
int dp_idx = threadIdx.x % 128;
int wg_idx = (threadIdx.x % (kNumComputeWarps * NumThreadsPerWarp)) / 128;
auto split_wg = [&](auto const& t) {
if constexpr (decltype(rank(t))::value == 3) {
auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), make_shape(Int<kNumWarpgroups>{}, size<2>(t) / Int<kNumWarpgroups>{}))));
return p(_, _, make_coord(wg_idx, _));
}
else {
auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), size<2>(t), make_shape(Int<kNumWarpgroups>{}, size<3>(t) / Int<kNumWarpgroups>{}))));
return p(_, _, _, make_coord(wg_idx, _));
}
};
auto tiled_t2r_dk = make_tmem_copy(load_op, tDKtDK);
auto thread_t2r_dk = tiled_t2r_dk.get_slice(dp_idx);
Tensor tTR_cDK = split_wg(thread_t2r_dk.partition_D(cDK));
Tensor tTR_gDK = split_wg(thread_t2r_dk.partition_D(gDK));
Tensor tTR_rDK = make_tensor<ElementAcc>(shape(tTR_cDK));
Tensor tTR_tDK = split_wg(thread_t2r_dk.partition_S(tDKtDK));
auto tDVtDV = partition_fragment_C(TiledMmaPDO{}, select<0,1>(TileShapePDO{}))(make_coord(_,_),_0{},_0{});
tDVtDV.data() = TmemAllocation::kDV;
auto mDV_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv);
auto mDV = domain_offset(select<1,3,4>(blk_offset), mDV_in);
auto gDV = local_tile(mDV, TileShapePDO{}, make_coord(_,_,_), Step<_1, _1, X>{})
(_, _, blk_coord_k, _0{}, blk_coord_batch);
Tensor cDV = domain_offset(
make_coord(blk_coord_k * TileShapeK{}, _0{}),
make_identity_tensor(take<0,2>(TileShapePDO{}))
);
auto tiled_t2r_dv = make_tmem_copy(load_op, tDVtDV);
auto thread_t2r_dv = tiled_t2r_dv.get_slice(dp_idx);
Tensor tTR_cDV = split_wg(thread_t2r_dv.partition_D(cDV));
Tensor tTR_gDV = split_wg(thread_t2r_dv.partition_D(gDV));
Tensor tTR_rDV = make_tensor<ElementAcc>(shape(tTR_cDV));
Tensor tTR_tDV = split_wg(thread_t2r_dv.partition_S(tDVtDV));
pipeline_mma_compute_dkdv.consumer_wait(pipeline_mma_compute_dkdv_consumer_state);
// load tDVtDV
cute::copy(tiled_t2r_dv, tTR_tDV, tTR_rDV);
// store tDVgDV
store(tTR_gDV, tTR_rDV, tTR_cDV, select<1,3>(problem_shape));
cutlass::arch::fence_view_async_tmem_load();
pipeline_mma_compute_dkdv.consumer_release(pipeline_mma_compute_dkdv_consumer_state);
++pipeline_mma_compute_dkdv_consumer_state;
pipeline_mma_compute_dkdv.consumer_wait(pipeline_mma_compute_dkdv_consumer_state);
// load tDKtDK
cute::copy(tiled_t2r_dk, tTR_tDK, tTR_rDK);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(tTR_rDK); i++) {
tTR_rDK(i) = mainloop_args.softmax_scale * tTR_rDK(i);
}
// store tDKgDK
store(tTR_gDK, tTR_rDK, tTR_cDK, select<1,2>(problem_shape));
cutlass::arch::fence_view_async_tmem_load();
pipeline_mma_compute_dkdv.consumer_release(pipeline_mma_compute_dkdv_consumer_state);
++pipeline_mma_compute_dkdv_consumer_state;
}
template<class BlkCoord, class BlkOffset, class ProblemShape_>
CUTLASS_DEVICE void compute(
BlkCoord const& blk_coord,
BlkOffset const& blk_offset,
ProblemShape_ const& problem_shape,
int iter_index,
int iter_count,
MainloopArguments const& mainloop_args,
EpilogueArguments const& epilogue_args,
TensorStorage& shared_tensors,
PipelineLoadComputeLSE& pipeline_load_compute_lse,
typename PipelineLoadComputeLSE::PipelineState& pipeline_load_compute_lse_consumer_state,
PipelineLoadComputeSumOdO& pipeline_load_compute_sum_odo,
typename PipelineLoadComputeSumOdO::PipelineState& pipeline_load_compute_sum_odo_consumer_state,
PipelineMmaComputeS& pipeline_mma_compute_s,
typename PipelineMmaComputeS::PipelineState& pipeline_mma_compute_s_consumer_state,
PipelineMmaComputeDP& pipeline_mma_compute_dp,
typename PipelineMmaComputeDP::PipelineState& pipeline_mma_compute_dp_consumer_state,
PipelineComputeMmaP& pipeline_compute_mma_p,
typename PipelineComputeMmaP::PipelineState& pipeline_compute_mma_p_producer_state,
PipelineComputeMmaDS& pipeline_compute_mma_ds,
typename PipelineComputeMmaDS::PipelineState& pipeline_compute_mma_ds_producer_state,
PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv,
typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) {
auto [Q, K, D, D_VO, HB] = problem_shape;
// in tmem, S & P overlap
// and dP and dQ overlap
// there are two compute wg's that cooperatively compute softmax
// they are striped by this tmem atom, i.e. wg0 has 16 elems, then wg1 etc
auto load_op = SM100_TMEM_LOAD_16dp32b32x{};
Tensor tSTtST = partition_fragment_C(TiledMmaQK{}, select<0,1>(TileShapeQK{}))(make_coord(_,_),_0{},_0{});
tSTtST.data() = TmemAllocation::kS;
Tensor tDPTtDPT = partition_fragment_C(TiledMmaDOV{}, select<0,1>(TileShapeDOV{}))(make_coord(_,_),_0{},_0{});
tDPTtDPT.data() = TmemAllocation::kDP;
Tensor cST = make_identity_tensor(take<0,2>(TileShapeQK{}));
Tensor cDPT = make_identity_tensor(take<0,2>(TileShapeDOV{}));
Tensor cPT = make_identity_tensor(take<0,2>(TileShapeQK{}));
constexpr int kNumWarpgroups = kNumComputeWarps / 4;
int dp_idx = threadIdx.x % 128;
int wg_idx = (threadIdx.x % (kNumComputeWarps * NumThreadsPerWarp)) / 128;
auto tiled_t2r = make_tmem_copy(load_op, tSTtST);
auto thread_t2r = tiled_t2r.get_slice(dp_idx);
auto split_wg = [&](auto const& t) {
if constexpr (decltype(size<1>(t))::value > 1) {
if constexpr (decltype(rank(t))::value == 3) {
auto p = t.compose(make_layout(make_shape(size<0>(t), make_shape(Int<kNumWarpgroups>{}, size<1>(t) / Int<kNumWarpgroups>{}), size<2>(t))));
return p(_, make_coord(wg_idx, _), _);
}
else {
auto p = t.compose(make_layout(make_shape(size<0>(t), make_shape(Int<kNumWarpgroups>{}, size<1>(t) / Int<kNumWarpgroups>{}), size<2>(t), size<3>(t))));
return p(_, make_coord(wg_idx, _), _, _);
}
}
else {
if constexpr (decltype(rank(t))::value == 3) {
auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), make_shape(Int<kNumWarpgroups>{}, size<2>(t) / Int<kNumWarpgroups>{}))));
return p(_, _, make_coord(wg_idx, _));
}
else {
auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), size<2>(t), make_shape(Int<kNumWarpgroups>{}, size<3>(t) / Int<kNumWarpgroups>{}))));
return p(_, _, _, make_coord(wg_idx, _));
}
}
};
Tensor tTR_cST_p = thread_t2r.partition_D(cST);
Tensor tTR_cST = split_wg(tTR_cST_p);
Tensor tTR_rST = make_tensor<ElementAcc>(shape(tTR_cST));
Tensor tTR_tST = split_wg(thread_t2r.partition_S(tSTtST));
Tensor tTR_cDPT_p = thread_t2r.partition_D(cDPT);
Tensor tTR_cPT_p = thread_t2r.partition_D(cPT);
Tensor tTR_cDPT = split_wg(tTR_cDPT_p);
Tensor tTR_rDPT = make_tensor<ElementAcc>(shape(tTR_cDPT));
Tensor tTR_tDPT = split_wg(thread_t2r.partition_S(tDPTtDPT));
Tensor sLSE = make_tensor(make_smem_ptr(shared_tensors.smem_lse.begin()), SmemLayoutLSE{});
Tensor sSumOdO = make_tensor(make_smem_ptr(shared_tensors.smem_sum_odo.begin()), SmemLayoutSumOdO{});
bool is_residual_k = get<1>(blk_coord) * TileShapeK{} + TileShapeK{} >= get<1>(problem_shape);
int last_iter = iter_count - 1 + iter_index;
CUTLASS_PRAGMA_NO_UNROLL
while (iter_count > 0) {
// wait for S and P
pipeline_mma_compute_s.consumer_wait(pipeline_mma_compute_s_consumer_state);
pipeline_compute_mma_p.producer_acquire(pipeline_compute_mma_p_producer_state);
// wait for LSE
pipeline_load_compute_lse.consumer_wait(pipeline_load_compute_lse_consumer_state);
auto dispatch_bool = [](bool b, auto fn) {
if (b) {
fn(cute::true_type{});
}
else {
fn(cute::false_type{});
}
};
bool leading_causal_masking = false;
if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<true>, Mask>) {
leading_causal_masking = warp_uniform(iter_index == get<1>(blk_coord));
} else if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<false>, Mask>) {
int offset = get<1>(problem_shape) - get<0>(problem_shape);
int kv_left = get<1>(blk_coord) * TileShapeK{};
int kv_right = kv_left + TileShapeK{} - 1;
int q_left = iter_index * TileShapeQ{} + offset;
int q_right = q_left + TileShapeQ{} - 1;
leading_causal_masking = warp_uniform(!((q_left > kv_right) || (q_right < kv_left)));
}
bool trailing_residual_masking = false;
if constexpr (std::is_base_of_v<cutlass::fmha::collective::ResidualMaskForBackward, Mask>) {
trailing_residual_masking = warp_uniform((iter_index == last_iter) || is_residual_k);
}
dispatch_bool(leading_causal_masking || trailing_residual_masking, [&](auto is_masked_tile) {
// compute P = softmax(S, LSE)
cute::copy(tiled_t2r, tTR_tST, tTR_rST);
if constexpr (decltype(is_masked_tile)::value) {
Mask{}.apply_mask(tTR_rST, [&](int i) {
auto c_transpose = tTR_cST(i);
return make_coord(get<0>(c_transpose) + iter_index * TileShapeQ{}, get<1>(c_transpose) + get<1>(blk_coord) * TileShapeK{});
}, problem_shape);
}
ElementAcc log2_e = static_cast<ElementAcc>(M_LOG2E);
float2 softmax_scale_log2_e;
softmax_scale_log2_e.x = mainloop_args.softmax_scale * log2_e;
softmax_scale_log2_e.y = mainloop_args.softmax_scale * log2_e;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(tTR_rST); i += 2) {
float2 acc;
float2 lse;
float2 out;
acc.x = tTR_rST(i);
acc.y = tTR_rST(i + 1);
lse.x = sLSE(get<0>(tTR_cST(i)), pipeline_load_compute_lse_consumer_state.index());
lse.y = sLSE(get<0>(tTR_cST(i+1)), pipeline_load_compute_lse_consumer_state.index());
cute::fma(out, softmax_scale_log2_e, acc, lse);
tTR_rST(i) = ::exp2f(out.x);
tTR_rST(i+1) = ::exp2f(out.y);
}
auto tRT_rST = quantize(tTR_rST);
Tensor sP = make_tensor(make_smem_ptr((Element*) shared_tensors.smem_p.begin()), SmemLayoutP{})
(_, _, _, pipeline_compute_mma_p_producer_state.index());
cutlass::arch::fence_view_async_tmem_load();
cutlass::arch::NamedBarrier(
kNumComputeWarps * NumThreadsPerWarp,
cutlass::arch::ReservedNamedBarriers::TransformBarrier
).arrive_and_wait();
auto sP_pi = as_position_independent_swizzle_tensor(sP);
auto thread_layout = make_ordered_layout(
make_shape(_64{}, _32{}, _2{}, _2{}),
make_stride(_3{}, _0{}, _1{}, _2{})
);
auto sP_pi_slice_p = sP_pi.compose(thread_layout)(((dp_idx/32) * 16) + (dp_idx % 16) , _, (dp_idx % 32 / 16), _).compose(make_layout(shape(tTR_cPT_p)));
auto sP_pi_slice = split_wg(sP_pi_slice_p);
copy_aligned(tRT_rST, sP_pi_slice);
});
// notify for P
cutlass::arch::fence_view_async_shared();
pipeline_compute_mma_p.producer_commit(pipeline_compute_mma_p_producer_state);
++pipeline_compute_mma_p_producer_state;
// release S
pipeline_mma_compute_s.consumer_release(pipeline_mma_compute_s_consumer_state);
++pipeline_mma_compute_s_consumer_state;
// release LSE
pipeline_load_compute_lse.consumer_release(pipeline_load_compute_lse_consumer_state);
++pipeline_load_compute_lse_consumer_state;
// wait for OdO
pipeline_load_compute_sum_odo.consumer_wait(pipeline_load_compute_sum_odo_consumer_state);
// wait for dP
pipeline_mma_compute_dp.consumer_wait(pipeline_mma_compute_dp_consumer_state);
// wait for dS
// in principle, we could defer waiting for dS, and move in the freeing of dP
// however, that would force us to keep dS in registers longer
pipeline_compute_mma_ds.producer_acquire(pipeline_compute_mma_ds_producer_state);
// compute dS = dsoftmax(P, dP, sum_OdO)
cute::copy(tiled_t2r, tTR_tDPT, tTR_rDPT);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(tTR_rDPT); i += 2) {
float2 st;
st.x = tTR_rST(i);
st.y = tTR_rST(i+1);
float2 dpt;
dpt.x = tTR_rDPT(i);
dpt.y = tTR_rDPT(i+1);
float2 odo;
odo.x = sSumOdO(get<0>(tTR_cDPT(i)), pipeline_load_compute_sum_odo_consumer_state.index());
odo.y = sSumOdO(get<0>(tTR_cDPT(i+1)), pipeline_load_compute_sum_odo_consumer_state.index());
float2 dif;
// sum odo is negated during preprocess
cute::add(dif, dpt, odo);
float2 out;
cute::mul(out, dif, st);
tTR_rDPT(i) = out.x;
tTR_rDPT(i+1) = out.y;
}
auto tTR_rDST = quantize(tTR_rDPT);
// release dP
cutlass::arch::fence_view_async_tmem_load();
pipeline_mma_compute_dp.consumer_release(pipeline_mma_compute_dp_consumer_state);
++pipeline_mma_compute_dp_consumer_state;
Tensor sDS = make_tensor(make_smem_ptr((Element*) shared_tensors.smem_ds_t.begin()), SmemLayoutDST{})
(_, _, _, pipeline_compute_mma_ds_producer_state.index());
auto thread_layout = make_ordered_layout(
make_shape(_64{}, _32{}, _2{}, _2{}),
make_stride(_3{}, _0{}, _1{}, _2{})
);
auto sDS_pi = as_position_independent_swizzle_tensor(sDS);
auto sDS_pi_slice_p = sDS_pi.compose(thread_layout)(((dp_idx/32) * 16) + (dp_idx % 16) , _, (dp_idx % 32 / 16), _).compose(make_layout(shape (tTR_cDPT_p)));
auto sDS_pi_slice = split_wg(sDS_pi_slice_p);
copy_aligned(tTR_rDST, sDS_pi_slice);
// notify for dS
cutlass::arch::fence_view_async_shared();
pipeline_compute_mma_ds.producer_commit(pipeline_compute_mma_ds_producer_state);
++pipeline_compute_mma_ds_producer_state;
// release OdO
pipeline_load_compute_sum_odo.consumer_release(pipeline_load_compute_sum_odo_consumer_state);
++pipeline_load_compute_sum_odo_consumer_state;
iter_count -= 1;
iter_index += 1;
}
epilogue(
blk_coord, blk_offset, problem_shape, mainloop_args, epilogue_args,
pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_consumer_state
);
}
template<class BlkCoord, class ProblemShape_>
CUTLASS_DEVICE void reduce(
BlkCoord const& blk_coord,
ProblemShape_ const& problem_shape,
int iter_index,
int iter_count,
MainloopArguments const& mainloop_args,
MainloopParams const& mainloop_params,
TensorStorage& shared_tensors,
PipelineMmaReduceDQ& pipeline_mma_reduce_dq,
typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_consumer_state,
PipelineReduceTmaStore& pipeline_reduce_tma_store,
typename PipelineReduceTmaStore::PipelineState& pipeline_reduce_tma_store_producer_state) {
using X = Underscore;
auto [Q, K, D, D_VO, HB] = problem_shape;
auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord;
// must match TileShapeDQ
auto load_op = SM100_TMEM_LOAD_16dp32b16x{};
auto tDQtDQ = partition_fragment_C(TiledMmaDSK{}, select<0,1>(TileShapeDSK{}))(make_coord(_,_),_0{},_0{});
tDQtDQ.data() = TmemAllocation::kDQ;
Tensor mDQ = mainloop_params.tma_red_dq.get_tma_tensor(make_shape(Q, D, HB));
auto gDQ = local_tile(mDQ, TileShapeQK{}, make_coord(_,_,_), Step<_1, X, _1>{})
(_, _, _, _0{}, blk_coord_batch);
Tensor cDQ = make_identity_tensor(take<0,2>(TileShapeDSK{}));
Tensor sDQ = make_tensor(make_smem_ptr(shared_tensors.smem_dq.begin()), SmemLayoutDQ{});
int thread_idx = threadIdx.x % (kNumReduceWarps * NumThreadsPerWarp);
auto tiled_t2r = make_tmem_copy(load_op, tDQtDQ);
auto thread_t2r = tiled_t2r.get_slice(thread_idx);
Tensor tTR_cDQ = thread_t2r.partition_D(cDQ);
Tensor tTR_gDQ = thread_t2r.partition_D(gDQ);
Tensor tTR_sDQ = thread_t2r.partition_D(sDQ);
Tensor tTR_tDQ = thread_t2r.partition_S(tDQtDQ);
auto block_tma = mainloop_params.tma_red_dq.get_slice(_0{});
Tensor tDQsDQ = block_tma.partition_S(sDQ);
Tensor tDQcDQ = block_tma.partition_S(cDQ);
Tensor tDQgDQ = block_tma.partition_D(gDQ);
int lane_predicate = (threadIdx.x % (kNumReduceWarps * NumThreadsPerWarp)) == 0;
while (iter_count > 0) {
pipeline_mma_reduce_dq.consumer_wait(pipeline_mma_reduce_dq_consumer_state);
Tensor tTR_rDQ = make_tensor<ElementAcc>(shape(tTR_cDQ));
// load dQ from tmem to rmem
cute::copy(tiled_t2r, tTR_tDQ, tTR_rDQ);
cutlass::arch::fence_view_async_tmem_load();
pipeline_mma_reduce_dq.consumer_release(pipeline_mma_reduce_dq_consumer_state);
++pipeline_mma_reduce_dq_consumer_state;
// we don't have enough smem to dump it all to smem, so we do it in stages
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<2>(tTR_cDQ); i++) {
if (lane_predicate) {
pipeline_reduce_tma_store.producer_acquire(pipeline_reduce_tma_store_producer_state);
}
// wait in all threads for the acquire to complete
cutlass::arch::NamedBarrier(
kNumReduceWarps * NumThreadsPerWarp,
cutlass::arch::ReservedNamedBarriers::TransposeBarrier
).arrive_and_wait();
cute::copy(tTR_rDQ(_, _, i), tTR_sDQ(_, _, _0{}, pipeline_reduce_tma_store_producer_state.index()));
// wait for the stores to all be visible to the TMA
cutlass::arch::fence_view_async_shared();
cutlass::arch::NamedBarrier(
kNumReduceWarps * NumThreadsPerWarp,
cutlass::arch::ReservedNamedBarriers::TransposeBarrier
).arrive_and_wait();
if (lane_predicate) {
// launch tma store
copy(mainloop_params.tma_red_dq, tDQsDQ(_,_,_0{}, pipeline_reduce_tma_store_producer_state.index()), tDQgDQ(_,_,i,iter_index));
pipeline_reduce_tma_store.producer_commit(pipeline_reduce_tma_store_producer_state);
}
++pipeline_reduce_tma_store_producer_state;
}
iter_count -= 1;
iter_index += 1;
}
}
CUTLASS_DEVICE void operator()(Params const& params, char* smem) {
#if defined(KERUTILS_ENABLE_SM100A)
int warp_idx = cutlass::canonical_warp_idx_sync();
auto role = warp_idx_to_role(warp_idx);
uint32_t lane_predicate = cute::elect_one_sync();
if (role == WarpRole::Load && lane_predicate) {
prefetch_tma_descriptor(params.mainloop_params.tma_load_q.get_tma_descriptor());
prefetch_tma_descriptor(params.mainloop_params.tma_load_k.get_tma_descriptor());
prefetch_tma_descriptor(params.mainloop_params.tma_load_v.get_tma_descriptor());
prefetch_tma_descriptor(params.mainloop_params.tma_load_do.get_tma_descriptor());
}
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem);
int initializing_warp = 0;
typename PipelineLoadMmaQ::Params pipeline_load_mma_q_params;
if (role == WarpRole::Load) {
pipeline_load_mma_q_params.role = PipelineLoadMmaQ::ThreadCategory::Producer;
}
if (role == WarpRole::Mma) {
pipeline_load_mma_q_params.role = PipelineLoadMmaQ::ThreadCategory::Consumer;
}
pipeline_load_mma_q_params.is_leader = lane_predicate && (role == WarpRole::Load);
// Also loads K in the first iteration
pipeline_load_mma_q_params.transaction_bytes = kTransactionsBytesLoadQ;
pipeline_load_mma_q_params.initializing_warp = initializing_warp++;
PipelineLoadMmaQ pipeline_load_mma_q(shared_storage.pipelines.load_mma_q, pipeline_load_mma_q_params,
ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});
typename PipelineLoadMmaDO::Params pipeline_load_mma_do_params;
if (role == WarpRole::Load) {
pipeline_load_mma_do_params.role = PipelineLoadMmaDO::ThreadCategory::Producer;
}
if (role == WarpRole::Mma) {
pipeline_load_mma_do_params.role = PipelineLoadMmaDO::ThreadCategory::Consumer;
}
pipeline_load_mma_do_params.is_leader = lane_predicate && (role == WarpRole::Load);
// Also loads V in the first iteration
pipeline_load_mma_do_params.transaction_bytes = kTransactionsBytesLoadDO;
pipeline_load_mma_do_params.initializing_warp = initializing_warp++;
PipelineLoadMmaDO pipeline_load_mma_do(shared_storage.pipelines.load_mma_do, pipeline_load_mma_do_params,
ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});
typename PipelineLoadComputeLSE::Params pipeline_load_compute_lse_params;
if (role == WarpRole::Load) {
pipeline_load_compute_lse_params.role = PipelineLoadComputeLSE::ThreadCategory::Producer;
}
if (role == WarpRole::Compute) {
pipeline_load_compute_lse_params.role = PipelineLoadComputeLSE::ThreadCategory::Consumer;
}
pipeline_load_compute_lse_params.producer_arv_count = NumThreadsPerWarp;
pipeline_load_compute_lse_params.consumer_arv_count = kNumComputeWarps * NumThreadsPerWarp;
pipeline_load_compute_lse_params.initializing_warp = initializing_warp++;
PipelineLoadComputeLSE pipeline_load_compute_lse(
shared_storage.pipelines.load_compute_lse,
pipeline_load_compute_lse_params,
/*barrier init*/ cute::true_type{});
typename PipelineLoadComputeSumOdO::Params pipeline_load_compute_sum_odo_params;
if (role == WarpRole::Load) {
pipeline_load_compute_sum_odo_params.role = PipelineLoadComputeSumOdO::ThreadCategory::Producer;
}
if (role == WarpRole::Compute) {
pipeline_load_compute_sum_odo_params.role = PipelineLoadComputeSumOdO::ThreadCategory::Consumer;
}
pipeline_load_compute_sum_odo_params.producer_arv_count = NumThreadsPerWarp;
pipeline_load_compute_sum_odo_params.consumer_arv_count = kNumComputeWarps * NumThreadsPerWarp;
pipeline_load_compute_sum_odo_params.initializing_warp = initializing_warp++;
PipelineLoadComputeSumOdO pipeline_load_compute_sum_odo(
shared_storage.pipelines.load_compute_sum_odo,
pipeline_load_compute_sum_odo_params,
/*barrier init*/ cute::true_type{});
typename PipelineMmaComputeS::Params pipeline_mma_compute_s_params;
if (role == WarpRole::Mma) {
pipeline_mma_compute_s_params.role = PipelineMmaComputeS::ThreadCategory::Producer;
}
if (role == WarpRole::Compute) {
pipeline_mma_compute_s_params.role = PipelineMmaComputeS::ThreadCategory::Consumer;
}
pipeline_mma_compute_s_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp;
pipeline_mma_compute_s_params.initializing_warp = initializing_warp++;
PipelineMmaComputeS pipeline_mma_compute_s(
shared_storage.pipelines.mma_compute_s,
pipeline_mma_compute_s_params,
ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});
typename PipelineMmaComputeDP::Params pipeline_mma_compute_dp_params;
if (role == WarpRole::Mma) {
pipeline_mma_compute_dp_params.role = PipelineMmaComputeDP::ThreadCategory::Producer;
}
if (role == WarpRole::Compute) {
pipeline_mma_compute_dp_params.role = PipelineMmaComputeDP::ThreadCategory::Consumer;
}
pipeline_mma_compute_dp_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp;
pipeline_mma_compute_dp_params.initializing_warp = initializing_warp++;
PipelineMmaComputeDP pipeline_mma_compute_dp(
shared_storage.pipelines.mma_compute_dp,
pipeline_mma_compute_dp_params,
ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});
typename PipelineMmaReduceDQ::Params pipeline_mma_reduce_dq_params;
if (role == WarpRole::Mma) {
pipeline_mma_reduce_dq_params.role = PipelineMmaReduceDQ::ThreadCategory::Producer;
}
if (role == WarpRole::Reduce) {
pipeline_mma_reduce_dq_params.role = PipelineMmaReduceDQ::ThreadCategory::Consumer;
}
pipeline_mma_reduce_dq_params.consumer_arv_count = kNumReduceWarps * cutlass::NumThreadsPerWarp;
pipeline_mma_reduce_dq_params.initializing_warp = initializing_warp++;
PipelineMmaReduceDQ pipeline_mma_reduce_dq(
shared_storage.pipelines.mma_reduce_dq,
pipeline_mma_reduce_dq_params,
ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});
typename PipelineComputeMmaP::Params pipeline_compute_mma_p_params;
if (role == WarpRole::Mma) {
pipeline_compute_mma_p_params.role = PipelineComputeMmaP::ThreadCategory::Consumer;
}
if (role == WarpRole::Compute) {
pipeline_compute_mma_p_params.role = PipelineComputeMmaP::ThreadCategory::Producer;
}
pipeline_compute_mma_p_params.producer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp;
pipeline_compute_mma_p_params.consumer_arv_count = 1;
pipeline_compute_mma_p_params.initializing_warp = initializing_warp++;
PipelineComputeMmaP pipeline_compute_mma_p(
shared_storage.pipelines.compute_mma_p,
pipeline_compute_mma_p_params,
ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});
typename PipelineComputeMmaDS::Params pipeline_compute_mma_ds_params;
if (role == WarpRole::Mma) {
pipeline_compute_mma_ds_params.role = PipelineComputeMmaDS::ThreadCategory::Consumer;
}
if (role == WarpRole::Compute) {
pipeline_compute_mma_ds_params.role = PipelineComputeMmaDS::ThreadCategory::Producer;
}
pipeline_compute_mma_ds_params.producer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp;
pipeline_compute_mma_ds_params.consumer_arv_count = 1;
pipeline_compute_mma_ds_params.initializing_warp = initializing_warp++;
PipelineComputeMmaDS pipeline_compute_mma_ds(
shared_storage.pipelines.compute_mma_ds,
pipeline_compute_mma_ds_params,
ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});
typename PipelineMmaComputeDKDV::Params pipeline_mma_compute_dkdv_params;
if (role == WarpRole::Mma) {
pipeline_mma_compute_dkdv_params.role = PipelineMmaComputeDKDV::ThreadCategory::Producer;
}
if (role == WarpRole::Compute) {
pipeline_mma_compute_dkdv_params.role = PipelineMmaComputeDKDV::ThreadCategory::Consumer;
}
pipeline_mma_compute_dkdv_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp;
pipeline_mma_compute_dkdv_params.initializing_warp = initializing_warp++;
PipelineMmaComputeDKDV pipeline_mma_compute_dkdv(
shared_storage.pipelines.mma_compute_dkdv,
pipeline_mma_compute_dkdv_params,
ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});
PipelineReduceTmaStore pipeline_reduce_tma_store;
TmemAllocator tmem_allocator;
pipeline_init_arrive_relaxed(size(ClusterShape{}));
pipeline_load_mma_q.init_masks(ClusterShape{});
pipeline_load_mma_do.init_masks(ClusterShape{});
pipeline_mma_compute_s.init_masks(ClusterShape{});
pipeline_mma_compute_dp.init_masks(ClusterShape{});
pipeline_mma_reduce_dq.init_masks(ClusterShape{});
pipeline_compute_mma_p.init_masks(ClusterShape{});
pipeline_compute_mma_ds.init_masks(ClusterShape{});
pipeline_mma_compute_dkdv.init_masks(ClusterShape{});
typename decltype(pipeline_load_mma_q)::PipelineState pipeline_load_mma_q_consumer_state;
typename decltype(pipeline_load_mma_do)::PipelineState pipeline_load_mma_do_consumer_state;
typename decltype(pipeline_load_compute_lse)::PipelineState pipeline_load_compute_lse_consumer_state;
typename decltype(pipeline_load_compute_sum_odo)::PipelineState pipeline_load_compute_sum_odo_consumer_state;
typename decltype(pipeline_mma_compute_s)::PipelineState pipeline_mma_compute_s_consumer_state;
typename decltype(pipeline_mma_compute_dp)::PipelineState pipeline_mma_compute_dp_consumer_state;
typename decltype(pipeline_mma_reduce_dq)::PipelineState pipeline_mma_reduce_dq_consumer_state;
typename decltype(pipeline_compute_mma_p)::PipelineState pipeline_compute_mma_p_consumer_state;
typename decltype(pipeline_compute_mma_ds)::PipelineState pipeline_compute_mma_ds_consumer_state;
typename decltype(pipeline_mma_compute_dkdv)::PipelineState pipeline_mma_compute_dkdv_consumer_state;
auto pipeline_load_mma_q_producer_state = make_producer_start_state<decltype(pipeline_load_mma_q)>();
auto pipeline_load_mma_do_producer_state = make_producer_start_state<decltype(pipeline_load_mma_do)>();
auto pipeline_load_compute_lse_producer_state = make_producer_start_state<decltype(pipeline_load_compute_lse)>();
auto pipeline_load_compute_sum_odo_producer_state = make_producer_start_state<decltype(pipeline_load_compute_sum_odo)>();
auto pipeline_mma_compute_s_producer_state = make_producer_start_state<decltype(pipeline_mma_compute_s)>();
auto pipeline_mma_compute_dp_producer_state = make_producer_start_state<decltype(pipeline_mma_compute_dp)>();
auto pipeline_mma_reduce_dq_producer_state = make_producer_start_state<decltype(pipeline_mma_reduce_dq)>();
auto pipeline_compute_mma_p_producer_state = make_producer_start_state<decltype(pipeline_compute_mma_p)>();
auto pipeline_compute_mma_ds_producer_state = make_producer_start_state<decltype(pipeline_compute_mma_ds)>();
auto pipeline_mma_compute_dkdv_producer_state = make_producer_start_state<decltype(pipeline_mma_compute_dkdv)>();
auto pipeline_reduce_tma_store_producer_state = make_producer_start_state<decltype(pipeline_reduce_tma_store)>();
pipeline_init_wait(size(ClusterShape{}));
auto blk_coord = make_coord(_0{}, blockIdx.x, _0{}, _0{}, make_coord(blockIdx.y, blockIdx.z));
auto [problem_shape, blk_offset] = apply_variable_length_offset(
params.problem_shape,
blk_coord
);
int iter_count = ceil_div(get<0>(problem_shape), TileShapeQ{});
int iter_start = 0;
if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<true>, Mask>) {
iter_start = (get<1>(blk_coord) * TileShapeK{}) / TileShapeQ{};
} else if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<false>, Mask>) {
int offset = get<1>(problem_shape) - get<0>(problem_shape);
iter_start = max(0, (int(get<1>(blk_coord) * TileShapeK{}) - offset) / (int)TileShapeQ{});
}
if (get<1>(blk_coord) * TileShapeK{} >= get<1>(problem_shape)) {
return;
}
iter_count -= iter_start;
if (iter_count <= 0) {
epilogue_clear(
blk_coord,
blk_offset,
problem_shape,
params.mainloop,
params.epilogue
);
return;
}
if (role == WarpRole::Load) {
warpgroup_reg_set<RegisterAllocation::kLoad>();
load(
blk_coord,
blk_offset,
problem_shape,
iter_start,
iter_count,
params.mainloop,
params.mainloop_params,
shared_storage.tensors,
pipeline_load_mma_q, pipeline_load_mma_q_producer_state,
pipeline_load_mma_do, pipeline_load_mma_do_producer_state,
pipeline_load_compute_lse, pipeline_load_compute_lse_producer_state,
pipeline_load_compute_sum_odo, pipeline_load_compute_sum_odo_producer_state
);
}
else if (role == WarpRole::Mma) {
warpgroup_reg_set<RegisterAllocation::kMma>();
tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr);
__syncwarp();
mma(
blk_coord,
problem_shape,
iter_start,
iter_count,
params.mainloop,
shared_storage.tensors,
pipeline_load_mma_q, pipeline_load_mma_q_consumer_state,
pipeline_load_mma_do, pipeline_load_mma_do_consumer_state,
pipeline_mma_compute_s, pipeline_mma_compute_s_producer_state,
pipeline_mma_compute_dp, pipeline_mma_compute_dp_producer_state,
pipeline_mma_reduce_dq, pipeline_mma_reduce_dq_producer_state,
pipeline_compute_mma_p, pipeline_compute_mma_p_consumer_state,
pipeline_compute_mma_ds, pipeline_compute_mma_ds_consumer_state,
pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_producer_state
);
}
else if (role == WarpRole::Compute) {
warpgroup_reg_set<RegisterAllocation::kCompute>();
compute(
blk_coord,
blk_offset,
problem_shape,
iter_start,
iter_count,
params.mainloop,
params.epilogue,
shared_storage.tensors,
pipeline_load_compute_lse, pipeline_load_compute_lse_consumer_state,
pipeline_load_compute_sum_odo, pipeline_load_compute_sum_odo_consumer_state,
pipeline_mma_compute_s, pipeline_mma_compute_s_consumer_state,
pipeline_mma_compute_dp, pipeline_mma_compute_dp_consumer_state,
pipeline_compute_mma_p, pipeline_compute_mma_p_producer_state,
pipeline_compute_mma_ds, pipeline_compute_mma_ds_producer_state,
pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_consumer_state
);
cutlass::arch::NamedBarrier(
kNumComputeWarps * NumThreadsPerWarp,
cutlass::arch::ReservedNamedBarriers::EpilogueBarrier
).arrive_and_wait();
if (warp_idx % kNumComputeWarps == 0) {
uint32_t free_stage_ptr = shared_storage.tmem_base_ptr;
tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns);
}
}
else if (role == WarpRole::Reduce) {
warpgroup_reg_set<RegisterAllocation::kReduce>();
reduce(
blk_coord,
problem_shape,
iter_start,
iter_count,
params.mainloop,
params.mainloop_params,
shared_storage.tensors,
pipeline_mma_reduce_dq, pipeline_mma_reduce_dq_consumer_state,
pipeline_reduce_tma_store, pipeline_reduce_tma_store_producer_state
);
pipeline_reduce_tma_store.producer_tail(pipeline_reduce_tma_store_producer_state);
}
else {
warpgroup_reg_set<RegisterAllocation::kEmpty>();
/* no-op */
}
#else
if (cute::thread0()) {
CUTE_INVALID_CONTROL_PATH("This kernel only supports sm100\n");
}
#endif
}
static dim3 get_block_shape() {
dim3 block(MaxThreadsPerBlock, 1, 1);
return block;
}
static dim3 get_grid_shape(Params const& params) {
auto [Q, K, D, D_VO, HB] = params.problem_shape;
auto [H, B] = HB;
dim3 grid(ceil_div(K, TileShapeK{}), H, B);
return grid;
}
};
} // namespace cutlass::fmha::kernel
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cute/layout.hpp"
#include "cutlass/arch/arch.h"
#include "cutlass/kernel_hardware_info.h"
#include "cutlass/pipeline/pipeline.hpp"
#include "cute/arch/tmem_allocator_sm100.hpp"
#include <kerutils/kerutils.cuh> // for KERUTILS_ENABLE_SM100A
#include "../kernel/fmha_options.hpp"
#include "../kernel/fmha_tile_scheduler.hpp"
#include "../kernel/fmha_causal_tile_scheduler.hpp"
#include "../collective/fmha_fusion.hpp"
#include "../collective/fmha_common.hpp"
namespace cutlass::fmha::kernel {
using namespace cute;
using namespace cutlass::fmha::collective;
struct Sm100FmhaCtxKernelWarpspecializedSchedule {
enum class WarpRole {
Softmax0,
Softmax1,
Correction,
MMA,
Load,
Epilogue,
Empty
};
static constexpr WarpRole warp_idx_to_WarpRole(int warp_idx) {
int wg_idx = warp_idx / 4; // warp_idx
if (wg_idx == 0) return WarpRole::Softmax0; // 0 - 3
if (wg_idx == 1) return WarpRole::Softmax1; // 4 - 7
if (wg_idx == 2) return WarpRole::Correction; // 8 - 11
if (warp_idx == 12) return WarpRole::MMA; // 12
if (warp_idx == 13) return WarpRole::Load; // 13
if (warp_idx == 14) return WarpRole::Epilogue; // 14
return WarpRole::Empty; // 15
}
static const int NumWarpsSoftmax = 4;
static const int NumWarpsCorrection = 4;
static const int NumWarpsEpilogue = 1;
static const int NumWarpsLoad = 1;
static const bool kDebugUsingPrintf = false;
static const int NumRegsSoftmax = 192;
static const int NumRegsCorrection = 96 - (kDebugUsingPrintf ? 16 : 0);
static const int NumRegsOther = 32 + (kDebugUsingPrintf ? 16 : 0);
static const int NumRegsEmpty = 24;
static const int NumWarps = 16;
};
struct Sm100MlaFwdCtxKernelWarpspecializedSchedule {
enum class WarpRole {
Softmax0,
Softmax1,
Correction,
MMA,
Load,
Epilogue,
Empty
};
static constexpr WarpRole warp_idx_to_WarpRole(int warp_idx) {
int wg_idx = warp_idx / 4; // warp_idx
if (wg_idx == 0) return WarpRole::Softmax0; // 0 - 3
if (wg_idx == 1) return WarpRole::Softmax1; // 4 - 7
if (wg_idx == 2) return WarpRole::Correction; // 8 - 11
if (warp_idx == 12) return WarpRole::MMA; // 12
if (warp_idx == 13) return WarpRole::Load; // 13
if (warp_idx == 14) return WarpRole::Epilogue; // 14
return WarpRole::Empty; // 15
}
static const int NumWarpsSoftmax = 4;
static const int NumWarpsCorrection = 4;
static const int NumWarpsEpilogue = 1;
static const int NumWarpsLoad = 1;
static const bool kDebugUsingPrintf = false;
static const int NumRegsSoftmax = 184;
static const int NumRegsCorrection = 96 - (kDebugUsingPrintf ? 16 : 0);
static const int NumRegsOther = 48 + (kDebugUsingPrintf ? 16 : 0);
static const int NumRegsEmpty = 24;
static const int NumWarps = 16;
};
template<
class ProblemShapeIn,
class CollectiveMainloop,
class CollectiveEpilogue,
class TileScheduler,
class KernelSchedule = Sm100FmhaCtxKernelWarpspecializedSchedule
>
struct Sm100FmhaFwdKernelTmaWarpspecialized {
using TileShape = typename CollectiveMainloop::TileShape;
using ProblemShape = ProblemShapeIn;
using WarpRole = typename KernelSchedule::WarpRole;
constexpr WarpRole warp_idx_to_WarpRole(int warp_idx) {
return KernelSchedule::warp_idx_to_WarpRole(warp_idx);
}
static const int NumWarpsSoftmax = KernelSchedule::NumWarpsSoftmax;
static const int NumWarpsCorrection = KernelSchedule::NumWarpsCorrection;
static const int NumWarpsEpilogue = KernelSchedule::NumWarpsEpilogue;
static const int NumWarpsLoad = KernelSchedule::NumWarpsLoad;
static_assert(NumWarpsEpilogue == CollectiveEpilogue::NumWarpsEpilogue);
static_assert(NumWarpsLoad == CollectiveEpilogue::NumWarpsLoad);
static const int NumRegsSoftmax = KernelSchedule::NumRegsSoftmax;
static const int NumRegsCorrection = KernelSchedule::NumRegsCorrection;
static const int NumRegsOther = KernelSchedule::NumRegsOther;
static const int NumRegsEmpty = 24;
static const int NumWarps = KernelSchedule::NumWarps;
static constexpr bool IsMla = std::is_same_v<KernelSchedule, Sm100MlaFwdCtxKernelWarpspecializedSchedule>;
using ClusterShape = typename CollectiveMainloop::ClusterShape;
using TmemAllocator = cute::TMEM::Allocator1Sm;
struct SharedStorage {
using UnionType = union {
typename CollectiveMainloop::TensorStorage mainloop;
typename CollectiveEpilogue::TensorStorage epilogue;
};
using StructType = struct {
typename CollectiveMainloop::TensorStorage mainloop;
typename CollectiveEpilogue::TensorStorage epilogue;
};
static constexpr bool IsPersistent = std::is_same_v<TileScheduler, PersistentTileScheduler> || std::is_same_v<TileScheduler, CausalPersistentTileScheduler>;
using MainloopEpilogueStorage = std::conditional_t<IsPersistent,
std::conditional_t<IsMla,
std::conditional_t<CollectiveMainloop::IsOrderLoadEpilogue, UnionType, StructType>,
StructType>,
UnionType>;
MainloopEpilogueStorage mainloop_epilogue;
struct PipelineStorage {
alignas(16) typename CollectiveMainloop::PipelineQ::SharedStorage load_q;
alignas(16) typename CollectiveMainloop::PipelineKV::SharedStorage load_kv;
alignas(16) typename CollectiveMainloop::PipelineS::SharedStorage mma_s0;
alignas(16) typename CollectiveMainloop::PipelineS::SharedStorage mma_s1;
alignas(16) typename CollectiveMainloop::PipelineC::SharedStorage s0_corr;
alignas(16) typename CollectiveMainloop::PipelineC::SharedStorage s1_corr;
alignas(16) typename CollectiveMainloop::PipelineO::SharedStorage mma_corr;
alignas(16) typename CollectiveMainloop::PipelineE::SharedStorage corr_epi;
alignas(16) typename CollectiveMainloop::OrderBarrierSoftmax::SharedStorage order_s01;
} pipelines;
uint32_t tmem_base_ptr;
};
static constexpr int SharedStorageSize = sizeof(SharedStorage);
struct Arguments {
ProblemShape problem_shape;
typename CollectiveMainloop::Arguments mainloop;
typename CollectiveEpilogue::Arguments epilogue;
cutlass::KernelHardwareInfo hw_info;
};
struct Params {
ProblemShape problem_shape;
typename CollectiveMainloop::Params mainloop;
typename CollectiveEpilogue::Params epilogue;
typename TileScheduler::Params tile_scheduler;
};
static const int MinBlocksPerMultiprocessor = 1;
static const int MaxThreadsPerBlock = NumWarps * cutlass::NumThreadsPerWarp;
using ArchTag = cutlass::arch::Sm100;
static size_t get_workspace_size(Arguments const& args) { return 0; }
static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) {
return cutlass::Status::kSuccess;
}
static bool can_implement(Arguments const& args) {
return CollectiveMainloop::can_implement(args.problem_shape, args.mainloop);
}
static dim3 get_grid_shape(Params const& params) {
return TileScheduler::get_grid_shape(params.tile_scheduler);
}
static dim3 get_block_shape() {
dim3 block(MaxThreadsPerBlock, 1, 1);
return block;
}
static Params to_underlying_arguments(Arguments const& args, void* workspace) {
return Params{
args.problem_shape,
CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace),
CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace),
TileScheduler::to_underlying_arguments(args.problem_shape, args.hw_info, ClusterShape{}, TileShape{})
};
}
CUTLASS_DEVICE auto apply_batch(const Params &params, ProblemShape const& problem_shape, int batch_idx) {
return apply_variable_length(params.problem_shape, batch_idx);
}
CUTLASS_DEVICE void operator()(const Params &params, char* smem) {
#if defined(KERUTILS_ENABLE_SM100A)
TileScheduler tile_scheduler{params.tile_scheduler};
int warp_idx = cutlass::canonical_warp_idx_sync();
auto role = warp_idx_to_WarpRole(warp_idx);
uint32_t lane_predicate = cute::elect_one_sync();
if (role == WarpRole::Load && lane_predicate) {
CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
}
if (role == WarpRole::Epilogue && lane_predicate) {
CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue);
}
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem);
auto get_epilogue_storage = [&]() {
if constexpr (IsMla && CollectiveMainloop::IsOrderLoadEpilogue) {
return reinterpret_cast<typename CollectiveEpilogue::TensorStorage *>(shared_storage.mainloop_epilogue.mainloop.smem_o.data());
} else {
return &shared_storage.mainloop_epilogue.epilogue;
}
};
typename CollectiveEpilogue::TensorStorage & epilogue_storage = *get_epilogue_storage();
typename CollectiveMainloop::PipelineQ::Params pipeline_load_q_params;
if (role == WarpRole::Load) {
pipeline_load_q_params.role = CollectiveMainloop::PipelineQ::ThreadCategory::Producer;
}
if (role == WarpRole::MMA) {
pipeline_load_q_params.role = CollectiveMainloop::PipelineQ::ThreadCategory::Consumer;
}
pipeline_load_q_params.is_leader = lane_predicate && (role == WarpRole::Load);
pipeline_load_q_params.transaction_bytes = CollectiveMainloop::TransactionBytesLoadQ;
typename CollectiveMainloop::PipelineQ pipeline_load_q(
shared_storage.pipelines.load_q,
pipeline_load_q_params,
ClusterShape{}, cute::true_type{}, /*mask calc*/cute::false_type{});
typename CollectiveMainloop::PipelineKV::Params pipeline_load_kv_params;
if (role == WarpRole::Load) {
pipeline_load_kv_params.role = CollectiveMainloop::PipelineKV::ThreadCategory::Producer;
}
if (role == WarpRole::MMA) {
pipeline_load_kv_params.role = CollectiveMainloop::PipelineKV::ThreadCategory::Consumer;
}
pipeline_load_kv_params.is_leader = lane_predicate && (role == WarpRole::Load);
pipeline_load_kv_params.transaction_bytes = CollectiveMainloop::TransactionBytesLoadK;
typename CollectiveMainloop::PipelineKV pipeline_load_kv(
shared_storage.pipelines.load_kv,
pipeline_load_kv_params,
ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});
typename CollectiveMainloop::PipelineS::Params pipeline_mma_s0_params;
if (role == WarpRole::MMA) {
pipeline_mma_s0_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Producer;
}
if (role == WarpRole::Softmax0) {
pipeline_mma_s0_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Consumer;
}
pipeline_mma_s0_params.consumer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp;
typename CollectiveMainloop::PipelineS pipeline_mma_s0(
shared_storage.pipelines.mma_s0,
pipeline_mma_s0_params,
ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});
typename CollectiveMainloop::PipelineS::Params pipeline_mma_s1_params;
if (role == WarpRole::MMA) {
pipeline_mma_s1_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Producer;
}
if (role == WarpRole::Softmax1) {
pipeline_mma_s1_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Consumer;
}
pipeline_mma_s1_params.consumer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp;
typename CollectiveMainloop::PipelineS pipeline_mma_s1(
shared_storage.pipelines.mma_s1,
pipeline_mma_s1_params,
ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});
typename CollectiveMainloop::PipelineC::Params pipeline_s0_corr_params;
if (role == WarpRole::Softmax0) {
pipeline_s0_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Producer;
}
if (role == WarpRole::Correction) {
pipeline_s0_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Consumer;
}
pipeline_s0_corr_params.producer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp;
pipeline_s0_corr_params.consumer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp;
typename CollectiveMainloop::PipelineC pipeline_s0_corr(
shared_storage.pipelines.s0_corr,
pipeline_s0_corr_params,
/*barrier init*/ cute::true_type{});
typename CollectiveMainloop::PipelineC::Params pipeline_s1_corr_params;
if (role == WarpRole::Softmax1) {
pipeline_s1_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Producer;
}
if (role == WarpRole::Correction) {
pipeline_s1_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Consumer;
}
pipeline_s1_corr_params.producer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp;
pipeline_s1_corr_params.consumer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp;
typename CollectiveMainloop::PipelineC pipeline_s1_corr(
shared_storage.pipelines.s1_corr,
pipeline_s1_corr_params,
/*barrier init*/ cute::true_type{});
typename CollectiveMainloop::PipelineO::Params pipeline_mma_corr_params;
if (role == WarpRole::MMA) {
pipeline_mma_corr_params.role = CollectiveMainloop::PipelineO::ThreadCategory::Producer;
}
if (role == WarpRole::Correction) {
pipeline_mma_corr_params.role = CollectiveMainloop::PipelineO::ThreadCategory::Consumer;
}
pipeline_mma_corr_params.consumer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp;
typename CollectiveMainloop::PipelineO pipeline_mma_corr(
shared_storage.pipelines.mma_corr,
pipeline_mma_corr_params,
ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});
typename CollectiveMainloop::PipelineE::Params pipeline_corr_epi_params;
if (role == WarpRole::Correction) {
pipeline_corr_epi_params.role = CollectiveMainloop::PipelineE::ThreadCategory::Producer;
}
if (role == WarpRole::Epilogue) {
pipeline_corr_epi_params.role = CollectiveMainloop::PipelineE::ThreadCategory::Consumer;
}
pipeline_corr_epi_params.producer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp;
pipeline_corr_epi_params.consumer_arv_count = NumWarpsEpilogue * cutlass::NumThreadsPerWarp;
typename CollectiveMainloop::PipelineE pipeline_corr_epi(
shared_storage.pipelines.corr_epi,
pipeline_corr_epi_params,
/*barrier init*/ cute::true_type{});
typename CollectiveMainloop::OrderBarrierSoftmax::Params params_order_s01;
params_order_s01.group_id = role == WarpRole::Softmax1 ? 1 : 0;
params_order_s01.group_size = NumWarpsSoftmax * cutlass::NumThreadsPerWarp;
typename CollectiveMainloop::OrderBarrierSoftmax order_s01(
shared_storage.pipelines.order_s01, params_order_s01);
TmemAllocator tmem_allocator;
__syncthreads();
pipeline_load_q.init_masks(ClusterShape{});
pipeline_load_kv.init_masks(ClusterShape{});
pipeline_mma_s0.init_masks(ClusterShape{});
pipeline_mma_s1.init_masks(ClusterShape{});
pipeline_mma_corr.init_masks(ClusterShape{});
typename CollectiveMainloop::PipelineQ::PipelineState pipeline_load_q_consumer_state;
typename CollectiveMainloop::PipelineQ::PipelineState pipeline_load_q_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineQ>();
typename CollectiveMainloop::PipelineKV::PipelineState pipeline_load_kv_consumer_state;
typename CollectiveMainloop::PipelineKV::PipelineState pipeline_load_kv_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineKV>();
typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s0_consumer_state;
typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s0_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineS>();
typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s1_consumer_state;
typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s1_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineS>();
typename CollectiveMainloop::PipelineC::PipelineState pipeline_s0_corr_consumer_state;
typename CollectiveMainloop::PipelineC::PipelineState pipeline_s0_corr_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineC>();
typename CollectiveMainloop::PipelineC::PipelineState pipeline_s1_corr_consumer_state;
typename CollectiveMainloop::PipelineC::PipelineState pipeline_s1_corr_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineC>();
typename CollectiveMainloop::PipelineE::PipelineState pipeline_corr_epi_consumer_state;
typename CollectiveMainloop::PipelineE::PipelineState pipeline_corr_epi_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineE>();
typename CollectiveMainloop::PipelineO::PipelineState pipeline_mma_corr_consumer_state;
typename CollectiveMainloop::PipelineO::PipelineState pipeline_mma_corr_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineO>();
CollectiveMainloop mainloop;
CollectiveEpilogue epilogue{params.epilogue};
if (role == WarpRole::Softmax0 || role == WarpRole::Softmax1) {
warpgroup_reg_set<NumRegsSoftmax>();
CUTLASS_PRAGMA_NO_UNROLL
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
auto blk_coord = tile_scheduler.get_block_coord();
auto logical_problem_shape = apply_batch(params,
params.problem_shape, get<2,1>(blk_coord));
if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) {
continue;
}
if (get<1>(logical_problem_shape) == 0) {
continue;
}
bool is_softmax_0 = role == WarpRole::Softmax0;
mainloop.softmax(
is_softmax_0 ? 0 : 1, blk_coord,
params.mainloop, logical_problem_shape,
is_softmax_0 ? pipeline_mma_s0 : pipeline_mma_s1,
is_softmax_0 ? pipeline_mma_s0_consumer_state : pipeline_mma_s1_consumer_state,
is_softmax_0 ? pipeline_s0_corr : pipeline_s1_corr,
is_softmax_0 ? pipeline_s0_corr_producer_state : pipeline_s1_corr_producer_state,
order_s01
);
}
}
else if (role == WarpRole::Correction) {
cutlass::arch::warpgroup_reg_dealloc<NumRegsCorrection>();
bool has_valid = false;
CUTLASS_PRAGMA_NO_UNROLL
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
auto blk_coord = tile_scheduler.get_block_coord();
auto logical_problem_shape = apply_batch(params,
params.problem_shape, get<2,1>(blk_coord));
if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) {
continue;
}
has_valid = true;
if (get<1>(logical_problem_shape) == 0) {
mainloop.correction_empty(
blk_coord,
params.mainloop, logical_problem_shape,
params.problem_shape,
epilogue_storage,
pipeline_corr_epi, pipeline_corr_epi_producer_state,
epilogue
);
continue;
}
mainloop.correction(
blk_coord,
params.mainloop, logical_problem_shape,
params.problem_shape,
epilogue_storage,
pipeline_s0_corr, pipeline_s0_corr_consumer_state,
pipeline_s1_corr, pipeline_s1_corr_consumer_state,
pipeline_mma_corr, pipeline_mma_corr_consumer_state,
pipeline_corr_epi, pipeline_corr_epi_producer_state,
epilogue
);
}
if constexpr (NumWarpsEpilogue == 0) {
static_assert(NumWarpsCorrection == 1);
if (has_valid) {
uint32_t free_stage_ptr = shared_storage.tmem_base_ptr;
tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns);
}
}
}
else if (role == WarpRole::MMA) {
warpgroup_reg_set<NumRegsOther>();
bool allocated = false;
CUTLASS_PRAGMA_NO_UNROLL
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
auto blk_coord = tile_scheduler.get_block_coord();
auto logical_problem_shape = apply_batch(params,
params.problem_shape, get<2,1>(blk_coord));
if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) {
continue;
}
if (!allocated) {
tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr);
__syncwarp();
allocated = true;
}
if (get<1>(logical_problem_shape) == 0) {
continue;
}
mainloop.mma(
blk_coord,
params.mainloop, logical_problem_shape,
shared_storage.mainloop_epilogue.mainloop,
pipeline_load_q, pipeline_load_q_consumer_state,
pipeline_load_kv, pipeline_load_kv_consumer_state,
pipeline_mma_s0, pipeline_mma_s0_producer_state,
pipeline_mma_s1, pipeline_mma_s1_producer_state,
pipeline_mma_corr, pipeline_mma_corr_producer_state
);
}
}
else if (role == WarpRole::Load) {
warpgroup_reg_set<NumRegsOther>();
if constexpr (IsMla && CollectiveMainloop::IsOrderLoadEpilogue) {
cutlass::arch::NamedBarrier::arrive((NumWarpsLoad + NumWarpsEpilogue) * NumThreadsPerWarp,
cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
}
CUTLASS_PRAGMA_NO_UNROLL
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
auto blk_coord = tile_scheduler.get_block_coord();
auto logical_problem_shape = apply_batch(params,
params.problem_shape, get<2,1>(blk_coord));
if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) {
continue;
}
if (get<1>(logical_problem_shape) == 0) {
continue;
}
mainloop.load(
blk_coord, logical_problem_shape,
params.mainloop, params.problem_shape,
shared_storage.mainloop_epilogue.mainloop,
pipeline_load_q, pipeline_load_q_producer_state,
pipeline_load_kv, pipeline_load_kv_producer_state
);
}
}
else if (role == WarpRole::Epilogue) {
warpgroup_reg_set<NumRegsOther>();
bool has_valid = false;
CUTLASS_PRAGMA_NO_UNROLL
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
auto blk_coord = tile_scheduler.get_block_coord();
auto logical_problem_shape = apply_batch(params,
params.problem_shape, get<2,1>(blk_coord));
if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) {
continue;
}
has_valid = true;
epilogue.store(
blk_coord, logical_problem_shape,
params.epilogue, params.problem_shape,
epilogue_storage,
pipeline_corr_epi, pipeline_corr_epi_consumer_state
);
}
static_assert(NumWarpsEpilogue <= 1);
if constexpr (NumWarpsEpilogue == 1) {
if(has_valid) {
uint32_t free_stage_ptr = shared_storage.tmem_base_ptr;
tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns);
}
}
}
else if (role == WarpRole::Empty) {
warpgroup_reg_set<NumRegsEmpty>();
/* no-op, donate regs and exit */
}
#else
if (cute::thread0()) {
CUTE_INVALID_CONTROL_PATH("This kernel only supports sm100\n");
}
#endif
}
};
} // namespace cutlass::fmha::kernel
#pragma once
#include <cute/tensor.hpp>
#include <kerutils/kerutils.cuh>
namespace sm100 {
/*
Load K/V indices from global memory, and generate validity mask
Each thread loads 8 indices
Should be called by lanes 0 ~ (BLOCK_TOPK/8)
*/
CUTE_DEVICE
char load_indices_and_generate_mask(
int lane_idx,
int* gIndices,
int s_kv,
int abs_pos_start,
int topk_length
) {
int indices[8];
KU_LDG_256(
gIndices + lane_idx*8,
indices,
".nc",
"no_allocate",
"evict_normal",
"256B"
);
auto is_valid = [&](int rel_pos_in_lane, int index) -> char {
int abs_pos = abs_pos_start + lane_idx*8 + rel_pos_in_lane;
return index >= 0 && index < s_kv && abs_pos < topk_length;
};
char is_ks_valid_mask = \
is_valid(7, indices[7]) << 7 |
is_valid(6, indices[6]) << 6 |
is_valid(5, indices[5]) << 5 |
is_valid(4, indices[4]) << 4 |
is_valid(3, indices[3]) << 3 |
is_valid(2, indices[2]) << 2 |
is_valid(1, indices[1]) << 1 |
is_valid(0, indices[0]) << 0;
return is_ks_valid_mask;
}
/*
Get P from Tensor Memory, reduce P within shared memory, perform masking, and store back if necessary
Initially, since dual gemm is used, we have two P pieces in Tensor Memory, one occupying rows 0 ~ 63 while the other occupying rows 64 ~ 127. We'd like to have them reduced into one single P piece, stored in registers with layout:
N N --- (topk)
+-------+-------+
| | |
32 | Warp0 | Warp2 |
| | |
+-------+-------+
| | |
32 | Warp1 | Warp3 |
| | |
+-------+-------+
|
(head)
where N = NUM_ELEMS_PER_THREAD
*/
template<
int NUM_ELEMS_PER_THREAD,
int TMEM_COL_START,
int BARRIER_WARP02_SYNC_ID,
int BARRIER_WARP13_SYNC_ID,
bool STORE_BACK_P
>
CUTE_DEVICE
void retrieve_mask_and_reduce_p(
char* k_validness_base,
int local_warp_idx,
int lane_idx,
auto slot_bar_P_empty_arrival,
float p_exchange_buf[4][32*NUM_ELEMS_PER_THREAD],
float p[NUM_ELEMS_PER_THREAD]
) {
using namespace cute;
using cutlass::arch::NamedBarrier;
static_assert(BARRIER_WARP13_SYNC_ID == BARRIER_WARP02_SYNC_ID+1);
float p_peer[NUM_ELEMS_PER_THREAD];
if (local_warp_idx < 2) {
ku::tmem_ld_32dp32bNx<NUM_ELEMS_PER_THREAD>(TMEM_COL_START, p);
ku::tmem_ld_32dp32bNx<NUM_ELEMS_PER_THREAD>(TMEM_COL_START + NUM_ELEMS_PER_THREAD, p_peer);
} else {
ku::tmem_ld_32dp32bNx<NUM_ELEMS_PER_THREAD>(TMEM_COL_START, p_peer);
ku::tmem_ld_32dp32bNx<NUM_ELEMS_PER_THREAD>(TMEM_COL_START + NUM_ELEMS_PER_THREAD, p);
}
cutlass::arch::fence_view_async_tmem_load();
ku::tcgen05_before_thread_sync();
slot_bar_P_empty_arrival();
// Mask invalid tokens
// We put masking before reduction, since (-inf) + anything (except nan and +inf) is (-inf), which guarantees correctness, and this can overlap with smem load
static_assert(NUM_ELEMS_PER_THREAD == 32);
uint32_t is_k_valid = *(uint32_t*)(k_validness_base + (local_warp_idx>=2?NUM_ELEMS_PER_THREAD/8:0));
CUTE_UNROLL
for (int i = 0; i < NUM_ELEMS_PER_THREAD; i += 1) {
if (!(is_k_valid >> i & 1))
p[i] = -CUDART_INF_F;
}
// Reduce P within the cluster
{
// Store
// Warp 0, 1 store their right (col 32 ~ 63) part, while warp 2, 3 store their left (row 0 ~ 31) part
CUTE_UNROLL
for (int i = 0; i < NUM_ELEMS_PER_THREAD/4; ++i) {
ku::st_shared(&p_exchange_buf[local_warp_idx^2][i*32*4 + lane_idx*4], *(float4*)(p_peer + i*4));
}
NamedBarrier::arrive_and_wait(64, BARRIER_WARP02_SYNC_ID + (local_warp_idx&1));
CUTE_UNROLL
for (int i = 0; i < NUM_ELEMS_PER_THREAD/4; ++i) {
float2 t[2];
*(float4*)t = *(float4*)(&p_exchange_buf[local_warp_idx][i*32*4 + lane_idx*4]);
float2* cur_p = (float2*)(p + i*4);
cur_p[0] = ku::float2_add(cur_p[0], t[0]);
cur_p[1] = ku::float2_add(cur_p[1], t[1]);
}
}
if constexpr (STORE_BACK_P) {
CUTE_UNROLL
for (int i = 0; i < NUM_ELEMS_PER_THREAD/4; ++i) {
ku::st_shared(&p_exchange_buf[local_warp_idx][i*32*4 + lane_idx*4], *(float4*)(p+i*4));
}
}
}
/*
Rescale O in Tensor Memory.
O should occupy 128 rows x (D_V/2) columns in Tensor Memory.
*/
template<
int D_V,
int CHUNK_SIZE,
int TMEM_COL_START
>
CUTE_DEVICE
void rescale_O(
float scale_factor
) {
float2 scale_factor_float2 = {scale_factor, scale_factor};
float2 o[CHUNK_SIZE/2];
CUTE_UNROLL
for (int chunk_idx = 0; chunk_idx < (D_V/2)/CHUNK_SIZE; ++chunk_idx) {
// Load O
ku::tmem_ld_32dp32bNx<CHUNK_SIZE>(TMEM_COL_START + chunk_idx*CHUNK_SIZE, o);
cutlass::arch::fence_view_async_tmem_load();
// Mult
for (int i = 0; i < CHUNK_SIZE/2; ++i) {
o[i] = ku::float2_mul(o[i], scale_factor_float2);
}
// Store O
ku::tmem_st_32dp32bNx<CHUNK_SIZE>(TMEM_COL_START + chunk_idx*CHUNK_SIZE, o);
cutlass::arch::fence_view_async_tmem_store();
}
}
template<int NUM_ELEMS_PER_THREAD>
CUTE_DEVICE
float get_max(
float p[NUM_ELEMS_PER_THREAD]
) {
float local_max = -CUDART_INF_F;
CUTE_UNROLL
for (int i = 0; i < NUM_ELEMS_PER_THREAD; ++i) {
local_max = max(local_max, p[i]);
}
return local_max;
}
/*
Calculate s := exp2f(p*scale - new_max) and its sum
*/
template<int NUM_ELEMS_PER_THREAD>
CUTE_DEVICE
float get_s_from_p(
nv_bfloat162 s[NUM_ELEMS_PER_THREAD/2],
float p[NUM_ELEMS_PER_THREAD],
float scale,
float new_max
) {
float2 cur_sum = float2 {0.0f, 0.0f};
float2 neg_new_max_float2 = float2 {-new_max, -new_max};
float2 scale_float2 = float2 {scale, scale};
CUTE_UNROLL
for (int i = 0; i < NUM_ELEMS_PER_THREAD/2; i += 1) {
float2 d = ku::float2_fma(float2{p[i*2], p[i*2+1]}, scale_float2, neg_new_max_float2);
d.x = exp2f(d.x);
d.y = exp2f(d.y);
cur_sum = ku::float2_add(cur_sum, d);
s[i] = __float22bfloat162_rn(d);
}
return cur_sum.x + cur_sum.y;
}
}
#pragma once
#include <math_constants.h>
#include <cute/tensor.hpp>
#include <kerutils/kerutils.cuh>
#include "params.h"
#include "defines.h"
namespace sm100::fwd::head128 {
using namespace cute;
template<
typename Shape_Q, typename TMA_Q,
typename Shape_O, typename TMA_O
>
struct TmaParams {
Shape_Q shape_Q; TMA_Q tma_Q;
Shape_O shape_O; TMA_O tma_O;
CUtensorMap tensor_map_kv;
};
struct float2x2 {
float2 lo, hi;
};
template<int D_QK>
struct KernelTemplate {
static constexpr int D_Q = D_QK;
static constexpr int D_K = D_QK;
static constexpr int D_V = 512;
static constexpr float MAX_INIT_VAL = -1e30; // We use this number as the initial value for mi (max logits) to avoid -inf - (-inf) = nan
static constexpr int B_H = 128; // For 2 CTAs
static constexpr int B_TOPK = 128; // For 2 CTAs
static constexpr int NUM_BUFS = 2;
static constexpr int NUM_THREADS = 256 + 128 + 128; // 128 scale & exp threads, 128x2 TMA threads, 32 UTCMMA threads
static constexpr int D_tQ = 384, NUM_tQ_TILES = D_tQ / 64;
static constexpr int D_sQ = D_QK-D_tQ, NUM_sQ_TILES = D_sQ / 64;
static_assert(D_sQ%64 == 0 && D_tQ%64 == 0 && D_sQ + D_tQ == D_Q);
// Tensor memory columns
struct tmem_cols {
// 0 ~ 256: output
// 256 ~ 320: P
// 320 ~ 512: Q[D_QK-D_tQ:]
static constexpr int o = 0;
static constexpr int p = 256;
static constexpr int q = 512 - D_tQ/2;
static_assert(p+64 <= q);
};
template<int NUM_TILES>
using SmemLayoutQTiles = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_H/2>, Int<64*NUM_TILES>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
template<int NUM_TILES>
using SmemLayoutOTiles = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_H/2>, Int<64*NUM_TILES>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
using SmemLayoutO = SmemLayoutOTiles<8>;
template<int NUM_TILES>
using SmemLayoutKTiles = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_TOPK/2>, Int<64*NUM_TILES>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
using SmemLayoutV = decltype(coalesce(tile_to_shape(
UMMA::Layout_MN_SW128_Atom<bf16>{},
Shape<Int<256>, Int<B_TOPK>>{},
Step<_2, _1>{}
), Shape<_1, _1>{}));
template<int NUM_TILES>
using SmemLayoutSTiles = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_INTER_Atom<bf16>{},
Shape<Int<B_H/2>, Int<64*NUM_TILES>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
struct SharedMemoryPlan {
union {
array_aligned<bf16, cosize_v<SmemLayoutQTiles<D_Q/64>>> q_full;
struct {
array_aligned<bf16, cosize_v<SmemLayoutQTiles<NUM_sQ_TILES>>> sq;
array_aligned<bf16, cosize_v<SmemLayoutV>> v;
// NOTE K is not overlapped with q_full, so we can do k copy-in while performing S->T copy for q
static_assert(cosize_v<SmemLayoutQTiles<D_Q/64>> <= cosize_v<SmemLayoutQTiles<NUM_sQ_TILES>> + cosize_v<SmemLayoutV>);
array_aligned<bf16, cosize_v<SmemLayoutKTiles<D_K/64>>> k;
} s;
array_aligned<bf16, cosize_v<SmemLayoutO>> o;
} u;
array_aligned<bf16, cosize_v<SmemLayoutSTiles<2>>> s;
float p[(B_H/2)*B_TOPK];
char is_k_valid[NUM_BUFS][B_TOPK/8];
transac_bar_t bar_prologue_q, bar_prologue_utccp;
transac_bar_t bar_qk_part_done[NUM_BUFS], bar_qk_done[NUM_BUFS]; // Pi = QKi^T done (i.e. Ki free)
transac_bar_t bar_sv_part_done[NUM_BUFS], bar_sv_done[NUM_BUFS]; // O += SiVi done (i.e. Vi free)
transac_bar_t bar_k_part0_ready[NUM_BUFS], bar_k_part1_ready[NUM_BUFS];
transac_bar_t bar_v_part0_ready[NUM_BUFS], bar_v_part1_ready[NUM_BUFS]; // Vi is ready
transac_bar_t bar_p_free[NUM_BUFS];
transac_bar_t bar_so_ready[NUM_BUFS]; // S and O are ready
transac_bar_t bar_k_valid_ready[NUM_BUFS], bar_k_valid_free[NUM_BUFS];
array_aligned<uint32_t, 1> tmem_start_addr;
float rowwise_max_buf[128], rowwise_li_buf[128];
};
using TiledMMA_P_tQ = decltype(make_tiled_mma(
SM100_MMA_F16BF16_2x1SM_TS_NOELECT<bf16, bf16, float, B_H, B_TOPK, UMMA::Major::K, UMMA::Major::K>{}
));
using TiledMMA_P_sQ = decltype(make_tiled_mma(
SM100_MMA_F16BF16_2x1SM_SS_NOELECT<bf16, bf16, float, B_H, B_TOPK, UMMA::Major::K, UMMA::Major::K>{}
));
using TiledMMA_O = decltype(make_tiled_mma(
SM100_MMA_F16BF16_2x1SM_SS_NOELECT<bf16, bf16, float, B_H, 256, UMMA::Major::K, UMMA::Major::MN>{},
Layout<Shape<_1, _1, _1>>{},
Tile<Int<128>, Layout<Shape<_128, _2, _2>, Stride<_1, _256, _128>>, _16>{} // We use this permutation layout to let CTA0 takes V[:, 0:256] and CTA1 takes V[:, 256:512]
));
template<typename TmaParams>
static __device__ void
sparse_attn_fwd_kernel_devfunc(const SparseAttnFwdParams &params, const TmaParams &tma_params);
};
}
#include "../phase1.h"
#include "../phase1.cuh"
namespace sm100::fwd::head128 {
template void run_fwd_phase1_kernel<512>(const SparseAttnFwdParams& params);
}
#include "../phase1.h"
#include "../phase1.cuh"
namespace sm100::fwd::head128 {
template void run_fwd_phase1_kernel<576>(const SparseAttnFwdParams& params);
}
#pragma once
#include "phase1.h"
#include <math_constants.h>
#include <cute/tensor.hpp>
#include <cutlass/cluster_launch.hpp>
#include <cutlass/arch/reg_reconfig.h>
#include <cutlass/arch/arch.h>
#include <cutlass/cuda_host_adapter.hpp>
#include "params.h"
#include "utils.h"
#include "sm100/helpers.h"
#include "config.h"
namespace sm100::fwd::head128 {
using namespace cute;
CUTE_DEVICE int32x8_t ldg_256_indices(void* src_ptr) {
int32x8_t val;
asm volatile("ld.global.nc.L1::evict_normal.L2::evict_normal.L2::256B.v8.s32 {%0, %1, %2, %3, %4, %5, %6, %7}, [%8];"
: "=r"(val.a0), "=r"(val.a1), "=r"(val.a2), "=r"(val.a3),
"=r"(val.a4), "=r"(val.a5), "=r"(val.a6), "=r"(val.a7)
: "l"(src_ptr)
);
return val;
}
/*
Pipeline Overview:
| Copy | MMA | Scale & Exp |
K0
V0
P0 = QK0^T
K1 S0 = exp(P0)
scale(O) w.r.t P0
P1 = QK1^T
K2 S1 = exp(P1)
O += S0V0
V1 scale(O) w.r.t P1
P2 = QK2^T
K3 S2 = exp(P2)
O += S1V1
V2 scale(O) w.r.t P2
P3 = QK3^T
K4 S3 = exp(P3)
O += S2V2
V3 scale(O) w.r.t P3
...
O += S(n-3)V(n-3)
V(n-2) scale(O) w.r.t P(n-2)
P(n-1) = QK(n-1)^T
S(n-1) = exp(P(n-1))
O += S(n-2)V(n-2)
V(n-1) scale(O) w.r.t P(n-1)
O += S(n-1)V(n-1)
*/
template<int D_QK>
template<typename TmaParams>
__device__ void
KernelTemplate<D_QK>::sparse_attn_fwd_kernel_devfunc(const SparseAttnFwdParams &params, const TmaParams &tma_params) {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000 && __CUDA_ARCH__ < 1200)) || (defined(__CLION_IDE__) || defined(__VSCODE_IDE__))
const int cta_idx = blockIdx.x % 2;
const int s_q_idx = blockIdx.x / 2;
const int warp_idx = cutlass::canonical_warp_idx_sync();
const int lane_idx = threadIdx.x % 32;
const int topk_length = params.topk_length != nullptr ? __ldg(params.topk_length + s_q_idx) : params.topk;
const int num_k_blocks = max(cute::ceil_div(topk_length, (int)B_TOPK), 1); // num_k_blocks always >= 1
const int warpgroup_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0);
const int idx_in_warpgroup = threadIdx.x % 128;
// Prefetch TMA descriptors
if (threadIdx.x == 0) {
cute::prefetch_tma_descriptor(tma_params.tma_Q.get_tma_descriptor());
cute::prefetch_tma_descriptor(tma_params.tma_O.get_tma_descriptor());
cute::prefetch_tma_descriptor(&(tma_params.tensor_map_kv));
}
// Define shared tensors
extern __shared__ char wksp_buf[];
SharedMemoryPlan &plan = *reinterpret_cast<SharedMemoryPlan*>(wksp_buf);
Tensor sQ_full = make_tensor(make_smem_ptr(plan.u.q_full.data()), SmemLayoutQTiles<D_Q/64>{});
int* gIndices = params.indices + s_q_idx*params.stride_indices_s_q; // [topk]
// Allocate tmem tensors
TiledMMA tiled_mma_P_tQ = TiledMMA_P_tQ{};
TiledMMA tiled_mma_P_sQ = TiledMMA_P_sQ{};
TiledMMA tiled_mma_O = TiledMMA_O{};
Tensor tP = partition_fragment_C(tiled_mma_P_tQ, Shape<Int<B_H/2>, Int<B_TOPK>>{});
Tensor tQr = tiled_mma_P_tQ.get_slice(_0{}).make_fragment_A(
partition_shape_A(tiled_mma_P_tQ, Shape<Int<B_H/2>, Int<D_tQ>>{})
);
Tensor tO = partition_fragment_C(tiled_mma_O, Shape<Int<B_H/2>, Int<D_V>>{});
tP.data().get() = tmem_cols::p;
tQr.data().get() = tmem_cols::q;
tO.data().get() = tmem_cols::o;
if (warp_idx == 0) {
if (elect_one_sync()) {
// Initialize barriers
plan.bar_prologue_q.init(1);
plan.bar_prologue_utccp.init(1);
CUTE_UNROLL
for (int i = 0; i < NUM_BUFS; ++i) {
plan.bar_qk_part_done[i].init(1);
plan.bar_qk_done[i].init(1);
plan.bar_sv_part_done[i].init(1);
plan.bar_sv_done[i].init(1);
plan.bar_k_part0_ready[i].init(1);
plan.bar_k_part1_ready[i].init(1);
plan.bar_v_part0_ready[i].init(1);
plan.bar_v_part1_ready[i].init(1);
plan.bar_p_free[i].init(128*2);
plan.bar_so_ready[i].init(128*2);
plan.bar_k_valid_ready[i].init(16);
plan.bar_k_valid_free[i].init(128);
}
fence_barrier_init();
}
}
cute::cluster_sync(); // We must add a cluster_sync() here, or TMA from CTA1 may launch before barrier initialization in CTA0
if (warp_idx == 0) {
if (elect_one_sync()) {
// Copy Q
Tensor gQ = flat_divide(
tma_params.tma_Q.get_tma_tensor(tma_params.shape_Q)(_, _, s_q_idx),
Tile<Int<B_H/2>>{}
)(_, cta_idx, _);
ku::launch_tma_copy(tma_params.tma_Q, gQ, sQ_full, plan.bar_prologue_q, TMA::CacheHintSm90::EVICT_FIRST);
}
// Initialize TMEM
cute::TMEM::Allocator2Sm().allocate(512, plan.tmem_start_addr.data());
TRAP_ONLY_DEVICE_ASSERT(plan.tmem_start_addr.data()[0] == 0);
cute::TMEM::Allocator2Sm().release_allocation_lock();
}
__syncthreads(); // Wait for TMEM allocation
if (warpgroup_idx == 0) {
cutlass::arch::warpgroup_reg_alloc<144>();
// Scale & Exp warps
// The following three numbers are
// - mi: max_logits used to scale Pi (i.e. O := exp2(Pi*scale - mi) @ V)
// - li: sumexp, i.e. li := sum(exp(Pi*scale - mi))
// - real_mi: real max logits, i.e. real_mi := max(Pi*scale)
// where Pi is the i-th row of P, P := QK^T
// mi and real_mi are always consistent within the two threads that
// controls one row (i.e. thread 0+64, 1+65, 2+66, ...) after every update
float mi = MAX_INIT_VAL;
float li = 0.0f;
float real_mi = -CUDART_INF_F;
const float2 scale = float2 {params.sm_scale_div_log2, params.sm_scale_div_log2};
uint128_t* sS_base = (uint128_t*)plan.s.data() + idx_in_warpgroup%64 + 64*((idx_in_warpgroup/64)*8);
float* sP_base = plan.p + idx_in_warpgroup%64*4 + (idx_in_warpgroup/64)*((B_H/2)*(B_TOPK/2));
CUTE_NO_UNROLL
for (int k = 0; k < num_k_blocks; ++k) {
// Wait for P
plan.bar_qk_done[k%NUM_BUFS].wait((k/NUM_BUFS)&1);
ku::tcgen05_after_thread_sync();
// Load P
float2 p[(B_TOPK/2)/2];
ku::tmem_ld_32dp32bNx<B_TOPK/2>(tmem_cols::p, p);
cutlass::arch::fence_view_async_tmem_load();
ku::tcgen05_before_thread_sync();
plan.bar_p_free[k%NUM_BUFS].arrive(0u);
// Mask
plan.bar_k_valid_ready[k%NUM_BUFS].wait((k/NUM_BUFS)&1);
// The following code enables NVCC to use R2P instruction
// Although we perform 2x LDS.32 instructions here, don't worry, NVCC will
// convert them to one LDS.64 instruction. However, if we write LDS.64
// here, NVCC won't use R2P.
uint32_t is_k_valid_lo = *(uint32_t*)(plan.is_k_valid[k%NUM_BUFS] + (idx_in_warpgroup>=64?B_TOPK/8/2:0));
uint32_t is_k_valid_hi = *(uint32_t*)(plan.is_k_valid[k%NUM_BUFS] + (idx_in_warpgroup>=64?B_TOPK/8/2:0) + 4);
float* p_float = (float*)p;
CUTE_UNROLL
for (int i = 0; i < (B_TOPK/2)/2; i += 1) {
if (!(is_k_valid_lo >> i & 1))
p_float[i] = -CUDART_INF_F;
}
CUTE_UNROLL
for (int i = 0; i < (B_TOPK/2)/2; i += 1) {
if (!(is_k_valid_hi >> i & 1))
p_float[i+(B_TOPK/2)/2] = -CUDART_INF_F;
}
// Get rowwise max of Pi
float cur_pi_max = -CUDART_INF_F;
CUTE_UNROLL
for (int i = 0; i < (B_TOPK/2); i += 1) {
cur_pi_max = max(cur_pi_max, p_float[i]);
}
cur_pi_max *= params.sm_scale_div_log2;
plan.bar_k_valid_free[k%NUM_BUFS].arrive();
NamedBarrier::arrive_and_wait(128, 0); // Wait for rowwise_max_buf and sP to be ready
plan.rowwise_max_buf[idx_in_warpgroup] = cur_pi_max;
NamedBarrier::arrive_and_wait(128, 0); // TODO Name these barriers
cur_pi_max = max(cur_pi_max, plan.rowwise_max_buf[idx_in_warpgroup^64]);
real_mi = max(real_mi, cur_pi_max);
bool should_scale_o = __any_sync(0xffffffff, cur_pi_max - mi > 6.0f);
// By this point:
// - cur_pi_max, real_mi, and mi is identical within each row (i.e. thread 0+64, 1+65, ...)
// - should_scale_o is identical among threads 0~31+64~95; and is identical among threads 32~63+96~127
// Calc scale factor, and scale li
float new_max, scale_for_old;
if (!should_scale_o) {
// Don't scale O
scale_for_old = 1.0f;
new_max = mi;
} else {
new_max = max(cur_pi_max, mi);
scale_for_old = exp2f(mi - new_max);
}
mi = new_max; // mi is still identical within each row
li *= scale_for_old;
// Calculate S
__nv_bfloat162 s[(B_TOPK/2)/2];
float2 neg_new_max = float2 {-new_max, -new_max};
CUTE_UNROLL
for (int i = 0; i < (B_TOPK/2)/2; i += 1) {
float2 d = ku::float2_fma(p[i], scale, neg_new_max);
d.x = exp2f(d.x);
d.y = exp2f(d.y);
li += d.x + d.y; // NOTE: Theoretically we could use FFMA2 here but actually this is faster...
s[i] = __float22bfloat162_rn(d);
}
// Wait for last SV gemm, write S
if (k > 0) {
plan.bar_sv_done[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1);
}
CUTE_UNROLL
for (int i = 0; i < B_TOPK/2/8; i += 1) {
sS_base[64*i] = *(uint128_t*)(s + i*4);
}
// Scale O
if (k > 0 && should_scale_o) {
float2 scale_for_old_float2 = float2 {scale_for_old, scale_for_old};
// plan.bar_sv_done[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1); // NOTE: We have waited for last SV gemm before
ku::tcgen05_after_thread_sync();
static constexpr int CHUNK_SIZE = 32;
float2 o[CHUNK_SIZE/2];
CUTE_UNROLL
for (int chunk_idx = 0; chunk_idx < (D_V/2)/CHUNK_SIZE; ++chunk_idx) {
// Load O
ku::tmem_ld_32dp32bNx<CHUNK_SIZE>(tmem_cols::o + chunk_idx*CHUNK_SIZE, o);
cutlass::arch::fence_view_async_tmem_load();
// Mult
for (int i = 0; i < CHUNK_SIZE/2; ++i) {
o[i] = ku::float2_mul(o[i], scale_for_old_float2);
}
// Store O
ku::tmem_st_32dp32bNx<CHUNK_SIZE>(tmem_cols::o + chunk_idx*CHUNK_SIZE, o);
cutlass::arch::fence_view_async_tmem_store();
}
ku::tcgen05_before_thread_sync();
}
fence_view_async_shared();
plan.bar_so_ready[k%NUM_BUFS].arrive(0u);
}
// Epilogue
if (real_mi == -CUDART_INF_F) {
// real_mi == -CUDART_INF_F <=> No valid TopK indices
// We set li to 0 to fit the definition that li := exp(x[i] - mi)
li = 0.0f;
mi = -CUDART_INF_F;
}
// Exchange li
plan.rowwise_li_buf[idx_in_warpgroup] = li;
NamedBarrier::arrive_and_wait(128, 0);
li += plan.rowwise_li_buf[idx_in_warpgroup^64];
// Store mi and li
if (idx_in_warpgroup < 64) {
int global_index = s_q_idx*params.h_q + cta_idx*(B_H/2) + idx_in_warpgroup;
float cur_lse = logf(li) + mi*CUDART_LN2_F;
cur_lse = cur_lse == -CUDART_INF_F ? +CUDART_INF_F : cur_lse;
params.max_logits[global_index] = real_mi*CUDART_LN2_F;
params.lse[global_index] = cur_lse;
}
// Wait for the last GEMM
plan.bar_sv_done[(num_k_blocks-1)%NUM_BUFS].wait(((num_k_blocks-1)/NUM_BUFS)&1);
ku::tcgen05_after_thread_sync();
// Store O
float attn_sink = params.attn_sink == nullptr ? -CUDART_INF_F : __ldg(params.attn_sink + cta_idx*B_H/2 + (idx_in_warpgroup%64))*CUDART_L2E_F;
float output_scale = __fdividef(1.0f, li + exp2f(attn_sink - mi));
Tensor sO = make_tensor(make_smem_ptr(plan.u.o.data()), SmemLayoutO{});
constexpr int B_EPI = 64;
Tensor tma_gO = flat_divide(
tma_params.tma_O.get_tma_tensor(tma_params.shape_O)(_, _, s_q_idx),
Shape<Int<B_H/2>, Int<B_EPI>>{}
)(_, _, cta_idx, _);
Tensor sO_divided = flat_divide(
sO,
Shape<Int<B_H/2>, Int<B_EPI>>{}
)(_, _, _0{}, _);
auto thr_tma = tma_params.tma_O.get_slice(_0{});
float2 o[B_EPI/2];
bool have_valid_indices = __any_sync(0xffffffff, li != 0); // Prevent some threads' li == 0 and some threads' li != 0 which lead to deadlock during ku::tmem_ld
if (!have_valid_indices) {
// If there are no valid indices, we set o[i] to 0 and don't load from TMEM
CUTE_UNROLL
for (int i = 0; i < B_EPI/2; ++i)
o[i].x = o[i].y = 0.0f;
output_scale = 1.0f;
}
float2 output_scale_float2 = make_float2(output_scale, output_scale);
CUTE_UNROLL
for (int k = 0; k < (D_V/2)/B_EPI; ++k) {
// Load O from tO
if (have_valid_indices) {
ku::tmem_ld_32dp32bNx<B_EPI>(tmem_cols::o + k*B_EPI, o);
cutlass::arch::fence_view_async_tmem_load();
}
// Convert and store
CUTE_UNROLL
for (int i = 0; i < B_EPI/8; ++i) {
__nv_bfloat162 o_bf16[4];
CUTE_UNROLL
for (int j = 0; j < 4; ++j) {
float2 d = ku::float2_mul(o[i*4+j], output_scale_float2);
o_bf16[j] = __float22bfloat162_rn(d);
}
int smem_row = idx_in_warpgroup % 64;
int smem_col = (idx_in_warpgroup/64)*(D_V/2) + k*B_EPI + i*8;
*(uint128_t*)(&sO(smem_row, smem_col)) = *(uint128_t*)(o_bf16);
}
// Sync
fence_view_async_shared();
NamedBarrier::arrive_and_wait(128, 0);
if (warp_idx == 0 && elect_one_sync()) {
cute::copy(
tma_params.tma_O,
thr_tma.partition_S(sO_divided(_, _, k)),
thr_tma.partition_D(tma_gO(_, _, k))
);
}
if (warp_idx == 1 && elect_one_sync()) {
int k2 = k + (D_V/B_EPI/2);
cute::copy(
tma_params.tma_O,
thr_tma.partition_S(sO_divided(_, _, k2)),
thr_tma.partition_D(tma_gO(_, _, k2))
);
}
}
if (warp_idx == 0) {
cute::TMEM::Allocator2Sm().free(0, 512);
}
} else if (warpgroup_idx == 1) {
// Producer warp for K
cutlass::arch::warpgroup_reg_dealloc<96>();
int warp_idx = cutlass::canonical_warp_idx_sync() - 4;
constexpr int NUM_WARPS = 4, NUM_LOCAL_ROWS_PER_WARP = (B_TOPK/2)/4/NUM_WARPS;
if (elect_one_sync()) {
bf16* sK_base = plan.u.s.k.data() + warp_idx*4*64;
CUTE_NO_UNROLL
for (int k = 0; k < num_k_blocks; ++k) {
int4 indices[NUM_LOCAL_ROWS_PER_WARP];
int max_indices = -1, min_indices = params.s_kv;
CUTE_UNROLL
for (int local_row = 0; local_row < NUM_LOCAL_ROWS_PER_WARP; ++local_row) {
indices[local_row] = __ldg((int4*)(gIndices + k*B_TOPK + cta_idx*(B_TOPK/2)) + local_row*NUM_WARPS + warp_idx);
max_indices = max(max_indices, int4_max(indices[local_row]));
min_indices = min(min_indices, int4_min(indices[local_row]));
}
bool is_all_rows_invalid = min_indices == params.s_kv || max_indices == -1;
bool should_skip_tma = is_all_rows_invalid && k >= NUM_BUFS;
auto load_part_ki = [&](transac_bar_t &bar, int local_col_start, int local_col_end) {
CUTE_UNROLL
for (int local_row = 0; local_row < NUM_LOCAL_ROWS_PER_WARP; ++local_row) {
CUTE_UNROLL
for (int local_col = local_col_start; local_col < local_col_end; ++local_col)
ku::tma_gather4_cta_group_2<true>(
&(tma_params.tensor_map_kv),
bar,
sK_base + local_row*(4*NUM_WARPS)*64 + local_col*((B_TOPK/2)*64),
local_col*64,
indices[local_row],
(int64_t)TMA::CacheHintSm90::EVICT_LAST
);
}
};
int cur_buf = k%NUM_BUFS;
if (k > 0) {
plan.bar_qk_part_done[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1);
}
if (!should_skip_tma) {
load_part_ki(plan.bar_k_part0_ready[cur_buf], 0, D_sQ/64);
} else {
// NOTE: TMA has performance issues when all indices are the same (even if those indices are invalid), so we detect whether all indices in our block are invalid (by inspecting their MIN and MAX, for performance reasons), and skip the copy if all indices are invalid.
// NOTE: We can also skip the initial zero-fill procedure (which prevents NaN from appearing in K/V buf if the first TMA copy is skipped) by disabling skipping on the first NUM_BUFS TMAs.
// NOTE: We only do this for K to save some checking overhead, since after doing this for K, cases where topk indices are all invalid are faster than the other cases
plan.bar_k_part0_ready[cur_buf].complete_transaction(0u, NUM_LOCAL_ROWS_PER_WARP*4*D_sQ*sizeof(bf16), 1u);
}
if (k > 0) {
plan.bar_qk_done[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1);
}
if (!should_skip_tma) {
load_part_ki(plan.bar_k_part1_ready[cur_buf], D_sQ/64, D_K/64);
} else {
plan.bar_k_part1_ready[cur_buf].complete_transaction(0u, NUM_LOCAL_ROWS_PER_WARP*4*D_tQ*sizeof(bf16), 1u);
}
}
}
} else if (warpgroup_idx == 2) {
// Producer warps for V
cutlass::arch::warpgroup_reg_dealloc<96>();
int warp_idx = cutlass::canonical_warp_idx_sync() - 8;
constexpr int NUM_WARPS = 4;
if (elect_one_sync()) {
// Wait for UTCCP
plan.bar_prologue_utccp.wait(0);
bf16* sV_base = plan.u.s.v.data() + warp_idx*4*64;
CUTE_NO_UNROLL
for (int k = 0; k < num_k_blocks; ++k) {
auto load_part_vi = [&](transac_bar_t &bar, int local_row_start, int local_row_end) {
CUTE_UNROLL
for (int local_row = local_row_start; local_row < local_row_end; ++local_row) {
int4 token_idxs = __ldg((int4*)(gIndices + k*B_TOPK) + local_row*NUM_WARPS + warp_idx);
CUTE_UNROLL
for (int local_col = 0; local_col < (D_V/2)/64; ++local_col)
ku::tma_gather4_cta_group_2<true>(
&(tma_params.tensor_map_kv),
bar,
sV_base + local_row*(4*NUM_WARPS)*64 + local_col*(B_TOPK*64),
local_col*64 + (cta_idx?256:0),
token_idxs,
(int64_t)TMA::CacheHintSm90::EVICT_LAST
);
}
};
int cur_buf = k%NUM_BUFS;
if (k > 0) {
plan.bar_sv_part_done[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1);
}
load_part_vi(plan.bar_v_part0_ready[cur_buf], 0, (B_TOPK/2)/4/NUM_WARPS);
if (k > 0) {
plan.bar_sv_done[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1);
}
load_part_vi(plan.bar_v_part1_ready[cur_buf], (B_TOPK/2)/4/NUM_WARPS, B_TOPK/4/NUM_WARPS);
}
}
} else {
cutlass::arch::warpgroup_reg_alloc<168>();
// MMA warp
if (cta_idx == 0 && warp_idx == 12 && elect_one_sync()) {
// S -> T copy for Q
UMMA::SmemDescriptor sQ_desc = UMMA::make_umma_desc<UMMA::Major::K>(
make_tensor(
make_smem_ptr(plan.u.q_full.data() + (B_H/2)*D_sQ),
tile_to_shape(
UMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_H/2>, Int<64>>{}
)
)
);
plan.bar_prologue_q.arrive_and_expect_tx(B_H*D_K*sizeof(bf16));
plan.bar_prologue_q.wait(0);
ku::tcgen05_after_thread_sync();
CUTE_UNROLL
for (int tile_idx = 0; tile_idx < NUM_tQ_TILES; ++tile_idx) {
// A tile is 64 rows * 64 cols (128B)
CUTE_UNROLL
for (int subtile_idx = 0; subtile_idx < 8; ++subtile_idx) {
// A subtile is 64 rows * 8 cols (128b)
SM100_UTCCP_2x64dp128bitlw0213_2cta::copy(
sQ_desc + tile_idx*((B_H/2)*128/16) + subtile_idx*(16/16), // Remember that 4 LSBs are not included
tmem_cols::q + tile_idx*32 + subtile_idx*4
);
}
}
ku::umma_arrive_multicast_2x1SM_noelect(plan.bar_prologue_utccp, 1|2);
CUTE_NO_UNROLL
for (int k = 0; k < num_k_blocks+1; ++k) {
if (k < num_k_blocks) {
// Pi = QKi^T
int cur_buf = k%NUM_BUFS;
Tensor sQl = make_tensor(make_smem_ptr(plan.u.s.sq.data()), SmemLayoutQTiles<NUM_sQ_TILES>{});
Tensor sKl = make_tensor(make_smem_ptr(plan.u.s.k.data()), SmemLayoutKTiles<NUM_sQ_TILES>{});
Tensor sKr = make_tensor(make_smem_ptr(plan.u.s.k.data()+64*D_sQ), SmemLayoutKTiles<NUM_tQ_TILES>{});
// Wait for K (part0)
plan.bar_k_part0_ready[cur_buf].arrive_and_expect_tx(B_TOPK*D_sQ*sizeof(bf16));
plan.bar_k_part0_ready[cur_buf].wait((k/NUM_BUFS)&1);
if (k > 0) {
plan.bar_p_free[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1);
}
ku::tcgen05_after_thread_sync();
ku::utcmma_ss(tiled_mma_P_sQ, sQl, sKl, tP, true);
ku::umma_arrive_multicast_2x1SM_noelect(plan.bar_qk_part_done[cur_buf], 1|2);
// Wait for K (part1)
plan.bar_k_part1_ready[cur_buf].arrive_and_expect_tx(B_TOPK*(D_K-D_sQ)*sizeof(bf16));
plan.bar_k_part1_ready[cur_buf].wait((k/NUM_BUFS)&1);
ku::tcgen05_after_thread_sync();
ku::utcmma_ts(tiled_mma_P_tQ, tQr, sKr, tP, false);
ku::umma_arrive_multicast_2x1SM_noelect(plan.bar_qk_done[cur_buf], 1|2);
}
if (k > 0) {
// O += S(i-1)V(i-1)
int cur_buf = (k-1)%NUM_BUFS;
Tensor sS = make_tensor(make_smem_ptr(plan.s.data()), SmemLayoutSTiles<2>{});
Tensor sV = make_tensor(make_smem_ptr(plan.u.s.v.data()), SmemLayoutV{});
Tensor sS_divided = flat_divide(sS, Tile<Int<B_H/2>, _64>{})(_, _, _0{}, _); // (B_H/2, 64, 2)
Tensor sV_divided = flat_divide(sV, Tile<Int<D_V/2>, _64>{})(_, _, _0{}, _); // (D_V/2, 64, 2)
// Wait for S(i-1) and O to be scaled
plan.bar_so_ready[cur_buf].wait(((k-1)/NUM_BUFS)&1);
// Wait for V (part0), and issue O += sS @ sV
plan.bar_v_part0_ready[cur_buf].arrive_and_expect_tx((B_TOPK/2)*D_V*sizeof(bf16));
plan.bar_v_part0_ready[cur_buf].wait(((k-1)/NUM_BUFS)&1);
ku::tcgen05_after_thread_sync();
ku::utcmma_ss(tiled_mma_O, sS_divided(_, _, _0{}), sV_divided(_, _, _0{}), tO, k == 1);
ku::umma_arrive_multicast_2x1SM_noelect(plan.bar_sv_part_done[cur_buf], 1|2);
// Wait for V (part1), and issue O += sS @ sV
plan.bar_v_part1_ready[cur_buf].arrive_and_expect_tx((B_TOPK/2)*D_V*sizeof(bf16));
plan.bar_v_part1_ready[cur_buf].wait(((k-1)/NUM_BUFS)&1);
ku::tcgen05_after_thread_sync();
ku::utcmma_ss(tiled_mma_O, sS_divided(_, _, _1{}), sV_divided(_, _, _1{}), tO, false);
ku::umma_arrive_multicast_2x1SM_noelect(plan.bar_sv_done[cur_buf], 1|2);
}
}
} else if (warp_idx == 13) {
// KV valid loading warp
static_assert(B_TOPK == 128);
if (lane_idx < 16) {
CUTE_NO_UNROLL
for (int k = 0; k < num_k_blocks; ++k) {
int cur_buf = k%NUM_BUFS;
int32x8_t indices = ldg_256_indices(gIndices + k*B_TOPK + lane_idx*8);
auto is_valid = [&](int rel_pos_in_lane, int index) -> char {
int abs_pos = k*B_TOPK + lane_idx*8 + rel_pos_in_lane;
return index >= 0 && index < params.s_kv && abs_pos < topk_length;
};
char is_ks_valid_mask = \
is_valid(7, indices.a7) << 7 |
is_valid(6, indices.a6) << 6 |
is_valid(5, indices.a5) << 5 |
is_valid(4, indices.a4) << 4 |
is_valid(3, indices.a3) << 3 |
is_valid(2, indices.a2) << 2 |
is_valid(1, indices.a1) << 1 |
is_valid(0, indices.a0) << 0;
plan.bar_k_valid_free[cur_buf].wait((k/NUM_BUFS)&1^1);
plan.is_k_valid[cur_buf][lane_idx] = is_ks_valid_mask;
plan.bar_k_valid_ready[cur_buf].arrive();
}
}
}
}
#else
if (cute::thread0()) {
CUTE_INVALID_CONTROL_PATH("This kernel only supports sm100");
}
#endif
}
template<typename Kernel, typename TmaParams>
__global__ void __launch_bounds__(Kernel::NUM_THREADS, 1, 2)
sparse_attn_fwd_kernel(__grid_constant__ const SparseAttnFwdParams params, __grid_constant__ const TmaParams tma_params) {
Kernel::sparse_attn_fwd_kernel_devfunc(params, tma_params);
}
template<int D_QK>
void run_fwd_phase1_kernel(const SparseAttnFwdParams& params) {
static_assert(D_QK == 576 || D_QK == 512);
using Kernel = KernelTemplate<D_QK>;
KU_ASSERT(params.h_kv == 1);
KU_ASSERT(params.topk % Kernel::B_TOPK == 0); // To save some boundry checkings
KU_ASSERT(params.h_q == Kernel::B_H); // To save some calculation
KU_ASSERT(params.d_qk == D_QK);
auto shape_Q = make_shape(params.h_q, params.d_qk, params.s_q);
auto tma_Q = cute::make_tma_copy(
SM100_TMA_2SM_LOAD_NOSPLIT{},
make_tensor(
make_gmem_ptr((bf16*)params.q),
make_layout(
shape_Q,
make_stride(params.stride_q_h_q, _1{}, params.stride_q_s_q)
)
),
(typename Kernel::template SmemLayoutQTiles<D_QK/64>){}
);
auto shape_O = make_shape(params.h_q, params.d_v, params.s_q);
auto tma_O = cute::make_tma_copy(
SM90_TMA_STORE{},
make_tensor(
make_gmem_ptr((bf16*)params.out),
make_layout(
shape_O,
make_stride(params.d_v, _1{}, params.h_q*params.d_v)
)
),
(typename Kernel::template SmemLayoutOTiles<1>){}
);
CUtensorMap tensor_map_kv;
{
uint64_t size[2] = {D_QK, (unsigned long)params.s_kv};
uint64_t stride[1] = {params.stride_kv_s_kv*sizeof(bf16)};
uint32_t box_size[2] = {64, 1};
uint32_t elem_stride[2] = {1, 1};
CUresult res = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)(
&tensor_map_kv,
CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
2,
params.kv,
size,
stride,
box_size,
elem_stride,
CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B,
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE
);
KU_ASSERT(res == CUresult::CUDA_SUCCESS);
}
TmaParams<
decltype(shape_Q), decltype(tma_Q),
decltype(shape_O), decltype(tma_O)
> tma_params = {
shape_Q, tma_Q,
shape_O, tma_O,
tensor_map_kv
};
auto kernel = &sparse_attn_fwd_kernel<Kernel, decltype(tma_params)>;
constexpr size_t smem_size = sizeof(typename Kernel::SharedMemoryPlan);
KU_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
cutlass::ClusterLaunchParams launch_params = {
dim3(2*params.s_q, 1, 1),
dim3(Kernel::NUM_THREADS, 1, 1),
dim3(2, 1, 1),
smem_size,
params.stream
};
KU_CUTLASS_CHECK(cutlass::launch_kernel_on_cluster(
launch_params, (void*)kernel, params, tma_params
));
}
}
#pragma once
#include "params.h"
namespace sm100::fwd::head128 {
template<int D_QK>
void run_fwd_phase1_kernel(const SparseAttnFwdParams& params);
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment