Unverified Commit 41b611f7 authored by Zeyu WANG's avatar Zeyu WANG Committed by GitHub
Browse files

Add more GPU architctures support (#76)



* Add more GPU architctures support

* Merge fmha and mla runner

* add varlen & non varlen support, and add incontiguous tensor support

* update readme

* add varlen api

---------
Co-authored-by: default avatardianzhangc <dianzhangc@nvidia.com>
parent 9edee0c0
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
namespace cutlass::fmha::kernel {
////////////////////////////////////////////////////////////////////////////////
// Swizzle Q tile and H tile to improve L2 cache hit rate,
// and launch the longest main loop first to keep most SMs busy.
struct CausalIndividualTileScheduler {
static constexpr int TileQ = 16;
static constexpr int TileH = 8;
static constexpr int TileSize = TileQ * TileH;
struct Params {
dim3 grid;
int tile_max_q;
FastDivmod divmod_tile_col;
FastDivmod divmod_tile_size;
FastDivmod divmod_tile_head;
};
bool valid_ = true;
Params params;
CUTLASS_DEVICE
CausalIndividualTileScheduler(Params const& params) : params(params) {}
template<class ProblemSize, class ClusterShape, class TileShape>
static Params to_underlying_arguments(
ProblemSize const& problem_size, KernelHardwareInfo hw_info,
ClusterShape const& cluster_shape, TileShape const& tile_shape) {
using namespace cute;
dim3 grid(size<3,0>(problem_size), round_up(ceil_div(size<0>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape)), size<3,1>(problem_size));
// gridDim.x must multiple of TileH
const int tile_col_count = grid.x / TileH;
const int tile_max_q = grid.y / TileQ * TileQ;
return Params{ grid , tile_max_q, tile_col_count, TileSize, TileH};
}
static dim3 get_grid_shape(Params const& params) {
return params.grid;
}
CUTLASS_DEVICE
bool is_valid() {
return valid_;
}
CUTLASS_DEVICE
auto get_block_coord() {
using namespace cute;
const int block_idx = blockIdx.y * gridDim.x + blockIdx.x;
int tile_idx, tile_tail;
params.divmod_tile_size(tile_idx, tile_tail, block_idx);
int tile_row_idx, tile_col_idx;
params.divmod_tile_col(tile_row_idx,tile_col_idx, tile_idx);
int row_offset_in_tail, col_offset_in_tail;
params.divmod_tile_head(row_offset_in_tail,col_offset_in_tail, tile_tail);
const int row_idx = tile_row_idx * TileQ + row_offset_in_tail;
const int col_idx = tile_col_idx * TileH + col_offset_in_tail;
// last q tile launch first
if(blockIdx.y >= params.tile_max_q) {
return make_coord(int(gridDim.y - 1 - blockIdx.y), _0{}, make_coord(int(blockIdx.x), int(blockIdx.z)));
}
return make_coord(int(gridDim.y) - 1 - row_idx, _0{}, make_coord(col_idx, int(blockIdx.z)));
}
CUTLASS_DEVICE
CausalIndividualTileScheduler& operator++() {
valid_ = false;
return *this;
}
};
////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
// Launch order: H Q B
struct CausalPersistentTileScheduler {
struct Params {
int num_blocks;
FastDivmod divmod_h;
FastDivmod divmod_m_block;
FastDivmod divmod_b;
KernelHardwareInfo hw_info;
};
int block_idx = 0;
Params params;
CUTLASS_DEVICE
CausalPersistentTileScheduler(Params const& params) : block_idx(blockIdx.x), params(params) {}
template<class ProblemSize, class ClusterShape, class TileShape>
static Params to_underlying_arguments(
ProblemSize const& problem_size, KernelHardwareInfo hw_info,
ClusterShape const& cluster_shape, TileShape const& tile_shape) {
using namespace cute;
// Get SM count if needed, otherwise use user supplied SM count
int sm_count = hw_info.sm_count;
if (sm_count <= 0) {
CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
" For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
}
CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
hw_info.sm_count = sm_count;
int num_m_blocks = cutlass::round_up(ceil_div(size<0>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape));
int num_blocks = num_m_blocks * size<3,0>(problem_size) * size<3,1>(problem_size);
return Params {
num_blocks,
{ size<3,0>(problem_size) }, { num_m_blocks}, { size<3,1>(problem_size) },
hw_info
};
}
static dim3 get_grid_shape(Params const& params) {
dim3 grid(std::min(params.num_blocks, params.hw_info.sm_count), 1, 1);
return grid;
}
CUTLASS_DEVICE
bool is_valid() {
return block_idx < params.num_blocks;
}
CUTLASS_DEVICE
auto get_block_coord() {
using namespace cute;
int block_decode = block_idx;
int m_block, bidb, bidh;
params.divmod_h(block_decode, bidh, block_decode);
params.divmod_m_block(block_decode, m_block, block_decode);
params.divmod_b(block_decode, bidb, block_decode);
return make_coord(m_block, _0{}, make_coord(bidh, bidb));
}
CUTLASS_DEVICE
CausalPersistentTileScheduler& operator++() {
block_idx += gridDim.x;
return *this;
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::fmha::kernel
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cute/layout.hpp"
namespace cutlass::fmha::kernel {
using namespace cute;
template<class ProblemShape, class Element, class ElementAcc>
struct FmhaKernelBwdConvert {
struct Arguments {
ProblemShape problem_shape;
const ElementAcc* ptr_src_dQ;
tuple<int, _1, tuple<int, int>> stride_src_dQ;
const ElementAcc* ptr_src_dK;
tuple<int, _1, tuple<int, int>> stride_src_dK;
const ElementAcc* ptr_src_dV;
tuple<int, _1, tuple<int, int>> stride_src_dV;
Element* ptr_dest_dQ;
tuple<int, _1, tuple<int, int>> stride_dest_dQ;
Element* ptr_dest_dK;
tuple<int, _1, tuple<int, int>> stride_dest_dK;
Element* ptr_dest_dV;
tuple<int, _1, tuple<int, int>> stride_dest_dV;
ElementAcc scale = 1.0;
};
using Params = Arguments;
using ClusterShape = Shape<_1, _1, _1>;
static constexpr int SharedStorageSize = 0;
static const int MinBlocksPerMultiprocessor = 1;
static const int MaxThreadsPerBlock = 128;
using ArchTag = cutlass::arch::Sm90;
static const int kBlockSeq = 8;
static size_t get_workspace_size(Arguments const& args) { return 0; }
static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) {
return cutlass::Status::kSuccess;
}
static const int kNumThreadsD = 16;
static const int kNumThreadsSeq = MaxThreadsPerBlock / kNumThreadsD;
static const int kElementsPerLoad = 4;
static const int kIterationsSeq = kBlockSeq / kNumThreadsSeq;
static bool can_implement(Arguments const& args) {
return get<2>(args.problem_shape) % kElementsPerLoad == 0 && get<3>(args.problem_shape) % kElementsPerLoad == 0;
}
static dim3 get_grid_shape(Params const& params) {
dim3 grid(size<4,0>(params.problem_shape), size<4,1>(params.problem_shape), ceil_div(std::max(size<0>(params.problem_shape), size<1>(params.problem_shape)), kBlockSeq));
return grid;
}
static dim3 get_block_shape() {
dim3 block(kNumThreadsD, kNumThreadsSeq, 1);
return block;
}
static Params to_underlying_arguments(Arguments const& args, void* workspace) {
return args;
}
template<class StrideSrc, class StrideDest, class Count>
CUTLASS_DEVICE void copy(Params const& params, const ElementAcc* ptr_src, StrideSrc const& stride_src, Element* ptr_dest, StrideDest const& stride_dest, Count const& count, int d_dim) {
auto ptr_src_bh = ptr_src + get<2,0>(stride_src) * blockIdx.x + get<2,1>(stride_src) * blockIdx.y;
auto ptr_dest_bh = ptr_dest + get<2,0>(stride_dest) * blockIdx.x + get<2,1>(stride_dest) * blockIdx.y;
int seqlen = count;
if constexpr (is_variable_length_v<decltype(count)>) {
int offset = count.cumulative_length[blockIdx.y];
ptr_dest_bh += offset * get<0>(stride_dest);
seqlen = count.cumulative_length[blockIdx.y + 1] - offset;
}
for (int idx_s_t = threadIdx.y; idx_s_t < kBlockSeq; idx_s_t += kNumThreadsSeq) {
int idx_s = idx_s_t + kBlockSeq * blockIdx.z;
if (idx_s >= seqlen) continue;
auto ptr_src_bhs = ptr_src_bh + idx_s * get<0>(stride_src);
auto ptr_dest_bhs = ptr_dest_bh + idx_s * get<0>(stride_dest);
for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < d_dim; idx_d += kElementsPerLoad * kNumThreadsD) {
ElementAcc value_src[kElementsPerLoad];
Element value_dest[kElementsPerLoad];
using VecSrc = uint_bit_t<sizeof_bits_v<ElementAcc> * kElementsPerLoad>;
using VecDest = uint_bit_t<sizeof_bits_v<Element> * kElementsPerLoad>;
*reinterpret_cast<VecSrc*>(value_src) = *reinterpret_cast<const VecSrc*>(&ptr_src_bhs[idx_d]);
for (int v = 0; v < kElementsPerLoad; v++) {
value_dest[v] = static_cast<Element>(params.scale * value_src[v]);
}
*reinterpret_cast<VecDest*>(&ptr_dest_bhs[idx_d]) = *reinterpret_cast<const VecDest*>(value_dest);
}
}
}
CUTLASS_DEVICE void operator()(const Params &params, char* smem) {
if (params.ptr_src_dQ != nullptr) {
copy(params, params.ptr_src_dQ, params.stride_src_dQ, params.ptr_dest_dQ, params.stride_dest_dQ, get<0>(params.problem_shape), get<2>(params.problem_shape));
}
if (params.ptr_src_dK != nullptr) {
copy(params, params.ptr_src_dK, params.stride_src_dK, params.ptr_dest_dK, params.stride_dest_dK, get<1>(params.problem_shape), get<2>(params.problem_shape));
}
if (params.ptr_src_dV != nullptr) {
copy(params, params.ptr_src_dV, params.stride_src_dV, params.ptr_dest_dV, params.stride_dest_dV, get<1>(params.problem_shape), get<3>(params.problem_shape));
}
}
};
} // namespace cutlass::fmha::kernel
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cute/layout.hpp"
namespace cutlass::fmha::kernel {
using namespace cute;
template<class ProblemShape, class Element, class ElementAcc>
struct FmhaKernelBwdSumOdO {
struct Arguments {
ProblemShape problem_shape;
const Element* ptr_O;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_O;
const Element* ptr_dO;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_dO;
ElementAcc* ptr_sum_OdO;
cute::tuple<cute::_1, cute::tuple<int, int>> stride_sum_OdO;
const ElementAcc* ptr_lse = nullptr;
cute::tuple<cute::_1, cute::tuple<int, int>> stride_lse;
ElementAcc* ptr_scaled_lse = nullptr;
cute::tuple<cute::_1, cute::tuple<int, int>> stride_scaled_lse;
ElementAcc sum_odo_scale = 1.0;
ElementAcc lse_scale = 1.0;
};
using Params = Arguments;
using ClusterShape = Shape<_1, _1, _1>;
static constexpr int SharedStorageSize = 0;
static const int MinBlocksPerMultiprocessor = 1;
static const int MaxThreadsPerBlock = 128;
using ArchTag = cutlass::arch::Sm100;
static size_t get_workspace_size(Arguments const& args) { return 0; }
static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) {
return cutlass::Status::kSuccess;
}
static const int kBlockQ = 16;
static const int kNumThreadsD = 8;
static const int kNumThreadsQ = MaxThreadsPerBlock / kNumThreadsD;
static const int kElementsPerLoad = 2;
static const int kIterationsQ = kBlockQ / kNumThreadsQ;
static bool can_implement(Arguments const& args) {
return get<2>(args.problem_shape) % kElementsPerLoad == 0 && get<3>(args.problem_shape) % kElementsPerLoad == 0;
}
static dim3 get_grid_shape(Params const& params) {
dim3 grid(ceil_div(size<0>(params.problem_shape), kBlockQ), size<4,0>(params.problem_shape), size<4,1>(params.problem_shape));
return grid;
}
static dim3 get_block_shape() {
dim3 block(kNumThreadsD, kNumThreadsQ, 1);
return block;
}
static Params to_underlying_arguments(Arguments const& args, void* workspace) {
return args;
}
CUTLASS_DEVICE void operator()(const Params &params, char* smem) {
auto ptr_O_bh = params.ptr_O + blockIdx.y * get<2,0>(params.stride_O) + blockIdx.z * get<2,1>(params.stride_O);
auto ptr_dO_bh = params.ptr_dO + blockIdx.y * get<2,0>(params.stride_dO) + blockIdx.z * get<2,1>(params.stride_dO);
auto ptr_sum_OdO_bh = params.ptr_sum_OdO + blockIdx.y * get<1,0>(params.stride_sum_OdO) + blockIdx.z * get<1,1>(params.stride_sum_OdO);
auto ptr_lse_bh = params.ptr_lse + blockIdx.y * get<1,0>(params.stride_lse) + blockIdx.z * get<1,1>(params.stride_lse);
auto ptr_scaled_lse_bh = params.ptr_scaled_lse + blockIdx.y * get<1,0>(params.stride_scaled_lse) + blockIdx.z * get<1,1>(params.stride_scaled_lse);
auto problem_q = get<0>(params.problem_shape);
int seqlen_q = problem_q;
if constexpr (is_variable_length_v<decltype(problem_q)>) {
int offset = problem_q.cumulative_length[blockIdx.z];
ptr_O_bh += offset * get<0>(params.stride_O);
ptr_dO_bh += offset * get<0>(params.stride_dO);
ptr_lse_bh += offset * get<0>(params.stride_lse);
seqlen_q = problem_q.cumulative_length[blockIdx.z + 1] - offset;
}
CUTLASS_PRAGMA_UNROLL
for (int idx_q_t = threadIdx.y; idx_q_t < kBlockQ; idx_q_t += kNumThreadsQ) {
int idx_q = idx_q_t + kBlockQ * blockIdx.x;
if (idx_q >= seqlen_q) continue;
ElementAcc acc = 0;
auto ptr_O_bhq = ptr_O_bh + idx_q * get<0>(params.stride_O);
auto ptr_dO_bhq = ptr_dO_bh + idx_q * get<0>(params.stride_dO);
auto ptr_sum_OdO_bhq = ptr_sum_OdO_bh + idx_q * get<0>(params.stride_sum_OdO);
auto ptr_lse_bhq = ptr_lse_bh + idx_q * get<0>(params.stride_lse);
auto ptr_scaled_lse_bhq = ptr_scaled_lse_bh + idx_q * get<0>(params.stride_scaled_lse);
for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<3>(params.problem_shape); idx_d += kElementsPerLoad * kNumThreadsD) {
Element value_O[kElementsPerLoad];
Element value_dO[kElementsPerLoad];
using Vec = uint_bit_t<sizeof_bits_v<Element> * kElementsPerLoad>;
*reinterpret_cast<Vec*>(value_O) = *reinterpret_cast<const Vec*>(&ptr_O_bhq[idx_d]);
*reinterpret_cast<Vec*>(value_dO) = *reinterpret_cast<const Vec*>(&ptr_dO_bhq[idx_d]);
for (int v = 0; v < kElementsPerLoad; v++) {
acc += value_O[v] * value_dO[v];
}
}
for (int i = 1; i < kNumThreadsD; i *= 2) {
acc += __shfl_xor_sync((uint32_t)-1, acc, i, kNumThreadsD);
}
if (threadIdx.x == 0) {
*ptr_sum_OdO_bhq = params.sum_odo_scale * acc;
if (params.ptr_scaled_lse) {
*ptr_scaled_lse_bhq = params.lse_scale * *ptr_lse_bhq;
}
}
}
}
};
} // namespace cutlass::fmha::kernel
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
namespace cutlass::fmha::kernel {
template<auto kTag, typename Default, typename... Options>
struct find_option;
template<auto kTag, typename Default>
struct find_option<kTag, Default> {
using option_value = Default;
};
template<auto kTag, typename Default, typename Option, typename... Options>
struct find_option<kTag, Default, Option, Options...> :
std::conditional_t<
Option::tag == kTag,
Option,
find_option<kTag, Default, Options...>
>
{};
template<auto kTag, typename Default, typename... Options>
using find_option_t = typename find_option<kTag, Default, Options...>::option_value;
enum class Tag {
kIsPersistent,
kNumMmaWarpGroups,
kLoadsQSeparately,
kIsMainloopLocked,
kIsEpilogueLocked,
kStagesQ,
kStagesKV,
kEpilogueKind,
kBlocksPerSM,
kClusterM,
kAccQK
};
template<auto kTag, class Value>
struct Option {
static constexpr auto tag = kTag;
using option_value = Value;
};
} // namespace cutlass::fmha::kernel
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/kernel_hardware_info.h"
namespace cutlass::fmha::kernel {
////////////////////////////////////////////////////////////////////////////////
struct IndividualTileScheduler {
struct Params {
dim3 grid;
};
bool valid_ = true;
CUTLASS_DEVICE
IndividualTileScheduler(Params const&) {}
template<class ProblemSize, class ClusterShape, class TileShape>
static Params to_underlying_arguments(
ProblemSize const& problem_size, KernelHardwareInfo hw_info,
ClusterShape const& cluster_shape, TileShape const& tile_shape) {
using namespace cute;
dim3 grid(round_up(ceil_div(size<0>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape)), size<3,0>(problem_size), size<3,1>(problem_size));
return Params{ grid };
}
static dim3 get_grid_shape(Params const& params) {
return params.grid;
}
CUTLASS_DEVICE
bool is_valid() {
return valid_;
}
CUTLASS_DEVICE
auto get_block_coord() {
using namespace cute;
return make_coord(blockIdx.x, _0{}, make_coord(blockIdx.y, blockIdx.z));
}
CUTLASS_DEVICE
IndividualTileScheduler& operator++() {
valid_ = false;
return *this;
}
};
////////////////////////////////////////////////////////////////////////////////
struct PersistentTileScheduler {
struct Params {
int num_blocks;
FastDivmod divmod_m_block;
FastDivmod divmod_h;
FastDivmod divmod_b;
KernelHardwareInfo hw_info;
};
int block_idx = 0;
Params params;
CUTLASS_DEVICE
PersistentTileScheduler(Params const& params) : block_idx(blockIdx.x), params(params) {}
template<class ProblemSize, class ClusterShape, class TileShape>
static Params to_underlying_arguments(
ProblemSize const& problem_size, KernelHardwareInfo hw_info,
ClusterShape const& cluster_shape, TileShape const& tile_shape) {
using namespace cute;
// Get SM count if needed, otherwise use user supplied SM count
int sm_count = hw_info.sm_count;
if (sm_count <= 0) {
CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
" For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
}
CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
hw_info.sm_count = sm_count;
int num_m_blocks = cutlass::round_up(ceil_div(size<0>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape));
int num_blocks = num_m_blocks * size<3,0>(problem_size) * size<3,1>(problem_size);
return Params {
num_blocks,
{ num_m_blocks}, { size<3,0>(problem_size) }, { size<3,1>(problem_size) },
hw_info
};
}
static dim3 get_grid_shape(Params const& params) {
dim3 grid(std::min(params.num_blocks, params.hw_info.sm_count), 1, 1);
return grid;
}
CUTLASS_DEVICE
bool is_valid() {
return block_idx < params.num_blocks;
}
CUTLASS_DEVICE
auto get_block_coord() {
using namespace cute;
int block_decode = block_idx;
int m_block, bidb, bidh;
params.divmod_m_block(block_decode, m_block, block_decode);
params.divmod_b(block_decode, bidb, block_decode);
params.divmod_h(block_decode, bidh, block_decode);
return make_coord(m_block, _0{}, make_coord(bidh, bidb));
}
CUTLASS_DEVICE
PersistentTileScheduler& operator++() {
block_idx += gridDim.x;
return *this;
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::fmha::kernel
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cute/arch/simd_sm100.hpp"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/memory_sm80.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "collective/fmha_common.hpp"
#include <cmath>
namespace cutlass::fmha::kernel {
using namespace cutlass::fmha::collective;
using namespace cute;
template<
class ProblemShape,
class Element,
class ElementAcc,
class TileShape,
class Mask
>
struct Sm100FmhaBwdKernelTmaWarpSpecialized {
using TileShapeQ = decltype(get<0>(TileShape{}));
static_assert(std::is_same_v<TileShapeQ, _128>, "tile shape K must be 128");
using TileShapeK = decltype(get<1>(TileShape{}));
static_assert(std::is_same_v<TileShapeK, _128>, "tile shape K must be 128");
using TileShapeDQK = decltype(get<2>(TileShape{}));
using TileShapeDVO = decltype(get<2>(TileShape{}));
using TmemAllocator = cute::TMEM::Allocator1Sm;
struct TmemAllocation {
static constexpr uint32_t kDK = 0; // TileShapeK x TileShapeDQK x acc
static constexpr uint32_t kDV = kDK + TileShapeDQK{}; // TileShapeK x TileShapeDVO x acc
static constexpr uint32_t kDQ = kDV + TileShapeDVO{}; // TileShapeQ x TileShapeDQK x acc
static constexpr uint32_t kDP = kDQ; // TileShapeK x TileShapeQ x inp
static constexpr uint32_t kS = kDQ + max(TileShapeQ{}, TileShapeDQK{});
static constexpr uint32_t kP = kS;
static constexpr uint32_t kTotal = kS + TileShapeQ{};
};
static_assert(
static_cast<int>(TmemAllocation::kTotal) <= TmemAllocator::Sm100TmemCapacityColumns,
"using too much tmem"
);
enum class WarpRole {
Empty = 0x0, Load = 0x1, Mma = 0x2, Compute = 0x3, Reduce = 0x4
};
static constexpr unsigned long long kWarpAssignment = 0x12'3333'3333'4444ull;
static constexpr int kNumComputeWarps = 8;
static constexpr int kNumReduceWarps = 4;
CUTLASS_DEVICE WarpRole warp_idx_to_role(int warp_idx) {
return static_cast<WarpRole>((kWarpAssignment >> (4 * warp_idx)) & 0xF);
}
struct RegisterAllocation {
static constexpr int kWarpgroup0 = 160-8;
static constexpr int kWarpgroup1 = 128;
static constexpr int kWarpgroup2 = 96;
static constexpr int kReduce = kWarpgroup0;
static constexpr int kCompute = kWarpgroup1;
static constexpr int kMma = kWarpgroup2;
static constexpr int kEmpty = kWarpgroup2;
static constexpr int kLoad = kWarpgroup2;
static_assert(kWarpgroup0 + 2 * kWarpgroup1 + kWarpgroup2 <= 512);
};
using ArchTag = cutlass::arch::Sm100;
using ClusterShape = Shape<_1, _1, _1>;
using Schedule = cutlass::gemm::KernelTmaWarpSpecialized1SmSm100;
static constexpr int MinBlocksPerMultiprocessor = 1;
static constexpr int kNumWarps = kNumComputeWarps + kNumReduceWarps + 4;
static constexpr int MaxThreadsPerBlock = NumThreadsPerWarp * kNumWarps;
static constexpr int Alignment = 128 / sizeof_bits_v<Element>;
static constexpr int kStages = 2;
using TensorStrideContiguousK = Stride<int, _1, Stride<int, int>>;
using TensorStrideContiguousMN = Stride<_1, int, Stride<int, int>>;
// compute S
using CollectiveMmaKQ = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
Element, TensorStrideContiguousK, Alignment,
Element, TensorStrideContiguousK, Alignment,
ElementAcc,
Shape<TileShapeK, TileShapeQ, TileShapeDQK>,
ClusterShape, cutlass::gemm::collective::StageCount<kStages>,
Schedule>::CollectiveOp;
using TileShapeKQ = typename CollectiveMmaKQ::TileShape;
using TiledMmaKQ = typename CollectiveMmaKQ::TiledMma;
// compute dP
using CollectiveMmaVDO = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
Element, TensorStrideContiguousK, Alignment,
Element, TensorStrideContiguousK, Alignment,
ElementAcc,
Shape<TileShapeK, TileShapeQ, TileShapeDVO>,
ClusterShape, cutlass::gemm::collective::StageCount<kStages>,
Schedule>::CollectiveOp;
using TileShapeVDO = typename CollectiveMmaVDO::TileShape;
using TiledMmaVDO = typename CollectiveMmaVDO::TiledMma;
// compute dV
using CollectiveMmaPDO = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
// needs to match ordering of S calculation
Element, TensorStrideContiguousK, Alignment,
Element, TensorStrideContiguousMN, Alignment,
ElementAcc,
Shape<TileShapeK, TileShapeDVO, TileShapeQ>,
ClusterShape, cutlass::gemm::collective::StageCount<kStages>,
Schedule>::CollectiveOp;
using TileShapePDO = typename CollectiveMmaPDO::TileShape;
using TiledMmaPDO = decltype(to_tiled_mma_sm100_ts(typename CollectiveMmaPDO::TiledMma{}));
// compute dK
using CollectiveMmaDSQ = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
// somewhat arbitrary since we dump to smem, need to agree with the next one
Element, TensorStrideContiguousK , Alignment,
Element, TensorStrideContiguousMN, Alignment,
ElementAcc,
Shape<TileShapeK, TileShapeDQK, TileShapeQ>,
ClusterShape, cutlass::gemm::collective::StageCount<kStages>,
Schedule>::CollectiveOp;
using TileShapeDSQ = typename CollectiveMmaDSQ::TileShape;
using TiledMmaDSQ = typename CollectiveMmaDSQ::TiledMma;
// compute dQ
using CollectiveMmaDSK = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
// somewhat arbitrary since we dump to smem, need to agree with the previous one
Element, TensorStrideContiguousMN, Alignment,
Element, TensorStrideContiguousMN, Alignment,
ElementAcc,
Shape<TileShapeQ, TileShapeDQK, TileShapeK>,
ClusterShape, cutlass::gemm::collective::StageCount<kStages>,
Schedule>::CollectiveOp;
using TileShapeDSK = typename CollectiveMmaDSK::TileShape;
using TiledMmaDSK = typename CollectiveMmaDSK::TiledMma;
// pipelines are named Pipeline<Producer><Consumer><Resource>
static constexpr int kStagesComputeSmem = 1;
using PipelineLoadMmaQ = PipelineTmaUmmaAsync<2, ClusterShape>;
using PipelineLoadMmaDO = PipelineTmaUmmaAsync<1, ClusterShape>;
using PipelineLoadComputeLSE = PipelineAsync<1>;
using PipelineLoadComputeSumOdO = PipelineAsync<1>;
using PipelineMmaComputeS = PipelineUmmaAsync<1>;
using PipelineMmaComputeDP = PipelineUmmaAsync<1>;
using PipelineMmaReduceDQ = PipelineUmmaAsync<1>;
using PipelineComputeMmaP = PipelineUmmaConsumerAsync<1>;
using PipelineComputeMmaDS = PipelineUmmaConsumerAsync<kStagesComputeSmem>;
using PipelineMmaComputeDKDV = PipelineUmmaAsync<2>;
static constexpr int kStagesReduceTmaStore = 2;
using PipelineReduceTmaStore = PipelineTmaStore<kStagesReduceTmaStore>;
struct PipelineStorage {
alignas(16) typename PipelineLoadMmaQ::SharedStorage load_mma_q;
alignas(16) typename PipelineLoadMmaDO::SharedStorage load_mma_do;
alignas(16) typename PipelineLoadComputeLSE::SharedStorage load_compute_lse;
alignas(16) typename PipelineLoadComputeSumOdO::SharedStorage load_compute_sum_odo;
alignas(16) typename PipelineMmaComputeS::SharedStorage mma_compute_s;
alignas(16) typename PipelineMmaComputeDP::SharedStorage mma_compute_dp;
alignas(16) typename PipelineMmaReduceDQ::SharedStorage mma_reduce_dq;
alignas(16) typename PipelineComputeMmaP::SharedStorage compute_mma_p;
alignas(16) typename PipelineComputeMmaDS::SharedStorage compute_mma_ds;
alignas(16) typename PipelineMmaComputeDKDV::SharedStorage mma_compute_dkdv;
};
template<class Layout, class Stages = _1>
static CUTE_DEVICE constexpr auto restage(Layout const& layout, Stages stages = {}) {
return composition(layout, make_tuple(_, _, _, make_layout(stages)));
}
using SmemLayoutK = decltype(restage(typename CollectiveMmaKQ::SmemLayoutA{}));
using SmemLayoutV = decltype(restage(typename CollectiveMmaVDO::SmemLayoutA{}));
using SmemLayoutQ = decltype(restage(typename CollectiveMmaKQ::SmemLayoutB{}, _2{}));
using SmemLayoutDO = decltype(restage(typename CollectiveMmaVDO::SmemLayoutB{}, _1{}));
using SmemLayoutDS = decltype(restage(typename CollectiveMmaDSK::SmemLayoutA{}, Int<kStagesComputeSmem>{}));
using SmemLayoutLSE = Layout<Shape<TileShapeQ, _1>>;
using SmemLayoutSumOdO = Layout<Shape<TileShapeQ, _1>>;
using SmemLayoutQT = decltype(restage(typename CollectiveMmaDSQ::SmemLayoutB{}, _2{}));
using SmemLayoutKT = decltype(restage(typename CollectiveMmaDSK::SmemLayoutB{}));
using SmemLayoutDST = decltype(restage(typename CollectiveMmaDSQ::SmemLayoutA{}, Int<kStagesComputeSmem>{}));
using SmemLayoutDOT = decltype(restage(typename CollectiveMmaPDO::SmemLayoutB{}, _1{}));
using TileShapeDQ = _32;
using SmemAtomDQ = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<
cute::UMMA::Major::K, ElementAcc, TileShapeQ, TileShapeDQ
>());
using SmemShapeDQ = Shape<TileShapeQ, TileShapeDQ, Int<kStagesReduceTmaStore>>;
using SmemLayoutDQ = decltype(tile_to_shape(SmemAtomDQ{}, SmemShapeDQ{}, Step<_2, _1, _3>{}));
struct TensorStorage {
union {
alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutK>> smem_k;
alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutKT>> smem_k_t;
};
alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutV>> smem_v;
union {
alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutQT>> smem_q_t;
};
union {
alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutDO>> smem_do;
alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutDOT>> smem_do_t;
};
union {
alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutDS>> smem_ds;
alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutDST>> smem_ds_t;
};
alignas(1024) cute::array<ElementAcc, cute::cosize_v<SmemLayoutDQ>> smem_dq;
alignas(16) cute::array<ElementAcc, cute::cosize_v<SmemLayoutLSE>> smem_lse;
alignas(16) cute::array<ElementAcc, cute::cosize_v<SmemLayoutSumOdO>> smem_sum_odo;
};
static constexpr int kTransactionsBytesLoadQ = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutQ{})) * cute::sizeof_bits_v<Element>);
static constexpr int kTransactionsBytesLoadDO = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutDO{})) * cute::sizeof_bits_v<Element>);
static constexpr int kTransactionsBytesLoadK = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v<Element>);
static constexpr int kTransactionsBytesLoadV = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v<Element>);
struct SharedStorage {
TensorStorage tensors;
PipelineStorage pipelines;
uint32_t tmem_base_ptr;
};
// this is tight enough that it won't work with sizeof due to padding for alignment
static constexpr int SharedStorageSize = offsetof(SharedStorage, tmem_base_ptr) + sizeof(uint32_t);
static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "using too much smem");
using TensorStride = TensorStrideContiguousK; // S D (H B)
using RowTensorStride = Stride<_1, Stride<int, int>>; // S (H B)
struct MainloopArguments {
const Element* ptr_q;
TensorStride stride_q;
const Element* ptr_k;
TensorStride stride_k;
const Element* ptr_v;
TensorStride stride_v;
const Element* ptr_do;
TensorStride stride_do;
const ElementAcc* ptr_lse;
RowTensorStride stride_lse;
const ElementAcc* ptr_sum_odo;
RowTensorStride stride_sum_odo;
ElementAcc* ptr_dq_acc;
TensorStride stride_dq_acc;
ElementAcc softmax_scale = 1.0f / sqrtf(TileShapeDQK{});
};
using TMA_K = typename CollectiveMmaKQ::Params::TMA_A;
using TMA_V = typename CollectiveMmaVDO::Params::TMA_A;
using TMA_Q = typename CollectiveMmaKQ::Params::TMA_B;
using TMA_DO = typename CollectiveMmaVDO::Params::TMA_B;
using TMA_DQ = decltype(make_tma_copy(SM90_TMA_REDUCE_ADD{},
make_tensor((const ElementAcc*)nullptr, make_shape(1, 1, make_shape(1, 1)), TensorStride{}),
SmemLayoutDQ{}(_, _, _0{})
));
struct MainloopParams {
TMA_K tma_load_k;
TMA_V tma_load_v;
TMA_Q tma_load_q;
TMA_DO tma_load_do;
TMA_DQ tma_red_dq;
};
struct EpilogueArguments {
Element* ptr_dk;
TensorStride stride_dk;
Element* ptr_dv;
TensorStride stride_dv;
};
struct Arguments {
ProblemShape problem_shape;
MainloopArguments mainloop;
EpilogueArguments epilogue;
KernelHardwareInfo hw_info;
};
struct Params {
ProblemShape problem_shape;
MainloopArguments mainloop;
MainloopParams mainloop_params;
EpilogueArguments epilogue;
KernelHardwareInfo hw_info;
};
static bool can_implement(Arguments const& args) {
auto [Q, K, D, D_VO, HB] = args.problem_shape;
auto [H, B] = HB;
if (Q <= 0 || K <= 0 || D <= 0 || D_VO <= 0 || H <= 0 || B <= 0) {
return false;
}
if (D % Alignment != 0 || D_VO % Alignment != 0) {
return false;
}
return true;
}
static Status initialize_workspace(Arguments const&, void*, cudaStream_t) {
return Status::kSuccess;
}
static Params to_underlying_arguments(Arguments const& args, void*) {
auto [Q_, K_, D, D_VO, HB] = args.problem_shape;
int Q = Q_;
int K = K_;
if constexpr (is_variable_length_v<decltype(Q_)>) {
Q = Q_.total_length;
}
if constexpr (is_variable_length_v<decltype(K_)>) {
K = K_.total_length;
}
auto params_kq = CollectiveMmaKQ::to_underlying_arguments(
make_shape(K, Q, D, HB),
typename CollectiveMmaKQ::Arguments {
args.mainloop.ptr_k, args.mainloop.stride_k,
args.mainloop.ptr_q, args.mainloop.stride_q,
}, /*workspace=*/nullptr);
auto params_vdo = CollectiveMmaVDO::to_underlying_arguments(
make_shape(K, Q, D_VO, HB),
typename CollectiveMmaVDO::Arguments {
args.mainloop.ptr_v, args.mainloop.stride_v,
args.mainloop.ptr_do, args.mainloop.stride_do,
}, /*workspace=*/nullptr);
TMA_DQ tma_red_dq = make_tma_copy(
SM90_TMA_REDUCE_ADD{},
make_tensor(args.mainloop.ptr_dq_acc, make_shape(Q_, D, HB), args.mainloop.stride_dq_acc),
SmemLayoutDQ{}(_, _, _0{})
);
return Params{
args.problem_shape,
args.mainloop,
MainloopParams{
params_kq.tma_load_a,
params_vdo.tma_load_a,
params_kq.tma_load_b,
params_vdo.tma_load_b,
tma_red_dq
},
args.epilogue,
args.hw_info
};
}
template<class T>
static CUTLASS_DEVICE auto quantize(T const& input) {
constexpr int AlignmentS = 4;
auto output = make_tensor<Element>(shape(input));
auto input_vec = recast<Array<ElementAcc, AlignmentS>>(input);
auto output_vec = recast<Array<Element, AlignmentS>>(output);
cutlass::NumericArrayConverter<Element, ElementAcc, AlignmentS> epilogue_op;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(input_vec); i++) {
output_vec(i) = epilogue_op(input_vec(i));
}
return output;
}
template<class BlkCoord, class BlkOffset, class ProblemShape_>
CUTLASS_DEVICE void load(
BlkCoord const& blk_coord,
BlkOffset const& blk_offset,
ProblemShape_ const& problem_shape,
int iter_index,
int iter_count,
MainloopArguments const& mainloop_args,
MainloopParams const& mainloop_params,
TensorStorage& shared_tensors,
PipelineLoadMmaQ& pipeline_load_mma_q,
typename PipelineLoadMmaQ::PipelineState& pipeline_load_mma_q_producer_state,
PipelineLoadMmaDO& pipeline_load_mma_do,
typename PipelineLoadMmaDO::PipelineState& pipeline_load_mma_do_producer_state,
PipelineLoadComputeLSE& pipeline_load_compute_lse,
typename PipelineLoadComputeLSE::PipelineState& pipeline_load_compute_lse_producer_state,
PipelineLoadComputeSumOdO& pipeline_load_compute_sum_odo,
typename PipelineLoadComputeSumOdO::PipelineState& pipeline_load_compute_sum_odo_producer_state) {
auto [Q, K, D, D_VO, HB] = problem_shape;
using X = Underscore;
uint16_t mcast_mask = 0;
auto mK_in = mainloop_params.tma_load_k.get_tma_tensor(make_shape(K, D, HB));
auto mV_in = mainloop_params.tma_load_v.get_tma_tensor(make_shape(K, D_VO, HB));
auto mQ_in = mainloop_params.tma_load_q.get_tma_tensor(make_shape(Q, D, HB));
auto mDO_in = mainloop_params.tma_load_do.get_tma_tensor(make_shape(Q, D_VO, HB));
auto mK = domain_offset(select<1,2,4>(blk_offset), mK_in);
auto mV = domain_offset(select<1,3,4>(blk_offset), mV_in);
auto mQ = domain_offset(select<0,2,4>(blk_offset), mQ_in);
auto mDO = domain_offset(select<0,3,4>(blk_offset), mDO_in);
auto gK = local_tile(mK, TileShapeKQ{}, make_coord(_,_,_), Step<_1, X, _1>{});
auto gQ = local_tile(mQ, TileShapeKQ{}, make_coord(_,_,_), Step<X, _1, _1>{});
auto gV = local_tile(mV, TileShapeVDO{}, make_coord(_,_,_), Step<_1, X, _1>{});
auto gDO = local_tile(mDO, TileShapeVDO{}, make_coord(_,_,_), Step<X, _1, _1>{});
ThrMMA cta_mma_kq = TiledMmaKQ{}.get_slice(_0{});
ThrMMA cta_mma_vdo = TiledMmaVDO{}.get_slice(_0{});
auto tSTgK = cta_mma_kq.partition_A(gK);
auto tSTgQ = cta_mma_kq.partition_B(gQ);
auto tDPTgV = cta_mma_vdo.partition_A(gV);
auto tDPTgDO = cta_mma_vdo.partition_B(gDO);
auto sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{});
auto sK = make_tensor(make_smem_ptr(shared_tensors.smem_k.begin()), SmemLayoutK{});
auto sV = make_tensor(make_smem_ptr(shared_tensors.smem_v.begin()), SmemLayoutV{});
auto sDO = make_tensor(make_smem_ptr(shared_tensors.smem_do.begin()), SmemLayoutDO{});
auto [tKgK_mkl, tKsK] = tma_partition(
mainloop_params.tma_load_k, _0{}, make_layout(_1{}),
group_modes<0,3>(sK), group_modes<0,3>(tSTgK));
auto [tQgQ_mkl, tQsQ] = tma_partition(
mainloop_params.tma_load_q, _0{}, make_layout(_1{}),
group_modes<0,3>(sQ), group_modes<0,3>(tSTgQ));
auto [tVgV_mkl, tVsV] = tma_partition(
mainloop_params.tma_load_v, _0{}, make_layout(_1{}),
group_modes<0,3>(sV), group_modes<0,3>(tDPTgV));
auto [tDOgDO_mkl, tDOsDO] = tma_partition(
mainloop_params.tma_load_do, _0{}, make_layout(_1{}),
group_modes<0,3>(sDO), group_modes<0,3>(tDPTgDO));
// set up lse and sum_odo
auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord;
pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state);
auto tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state);
pipeline_load_mma_q.producer_expect_transaction(pipeline_load_mma_q_producer_state, kTransactionsBytesLoadK);
// load K
if (cute::elect_one_sync()) {
cute::copy(
mainloop_params.tma_load_k.with(*tma_barrier, mcast_mask),
tKgK_mkl(_, blk_coord_k, _0{}, blk_coord_batch),
tKsK(_, _0{})
);
}
// load Q
if (cute::elect_one_sync()) {
cute::copy(
mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask),
tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch),
tQsQ(_, pipeline_load_mma_q_producer_state.index())
);
}
++pipeline_load_mma_q_producer_state;
pipeline_load_compute_lse.producer_acquire(pipeline_load_compute_lse_producer_state);
// load LSE
// 32 threads loading 128 values of 32b each
// so 4*32b=128b
int thread_idx = threadIdx.x % NumThreadsPerWarp;
int smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * 4;
int gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4;
auto mLSE = make_tensor(mainloop_args.ptr_lse, make_shape(Q, HB), mainloop_args.stride_lse);
for (int i = 0; i < 4; i++) {
cutlass::arch::cp_async_zfill<4>(
shared_tensors.smem_lse.begin() + smem_idx + i,
&mLSE(gmem_idx + i, blk_coord_batch),
gmem_idx + i < Q
);
}
pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive);
++pipeline_load_compute_lse_producer_state;
pipeline_load_mma_do.producer_acquire(pipeline_load_mma_do_producer_state);
tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state);
pipeline_load_mma_do.producer_expect_transaction(pipeline_load_mma_do_producer_state, kTransactionsBytesLoadV);
// load V
if (cute::elect_one_sync()) {
cute::copy(
mainloop_params.tma_load_v.with(*tma_barrier, mcast_mask),
tVgV_mkl(_, blk_coord_k, _0{}, blk_coord_batch),
tVsV(_, _0{})
);
}
// load dO
if (cute::elect_one_sync()) {
cute::copy(
mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask),
tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch),
tDOsDO(_, pipeline_load_mma_do_producer_state.index())
);
}
++pipeline_load_mma_do_producer_state;
pipeline_load_compute_sum_odo.producer_acquire(pipeline_load_compute_sum_odo_producer_state);
// load sum_OdO
smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * 4;
gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4;
auto mSumOdO = make_tensor(mainloop_args.ptr_sum_odo, make_shape(Q, HB), mainloop_args.stride_sum_odo);
for (int i = 0; i < 4; i++) {
cutlass::arch::cp_async_zfill<4>(
shared_tensors.smem_sum_odo.begin() + smem_idx + i,
&mSumOdO(gmem_idx + i, blk_coord_batch),
gmem_idx + i < Q
);
}
pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive);
++pipeline_load_compute_sum_odo_producer_state;
iter_count -= 1;
iter_index += 1;
while (iter_count > 0) {
pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state);
tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state);
// load Q
if (cute::elect_one_sync()) {
cute::copy(
mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask),
tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch),
tQsQ(_, pipeline_load_mma_q_producer_state.index())
);
}
++pipeline_load_mma_q_producer_state;
pipeline_load_compute_lse.producer_acquire(pipeline_load_compute_lse_producer_state);
// load LSE
smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * 4;
gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4;
for (int i = 0; i < 4; i++) {
cutlass::arch::cp_async_zfill<4>(
shared_tensors.smem_lse.begin() + smem_idx + i,
&mLSE(gmem_idx + i, blk_coord_batch),
gmem_idx + i < Q
);
}
pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive);
++pipeline_load_compute_lse_producer_state;
pipeline_load_mma_do.producer_acquire(pipeline_load_mma_do_producer_state);
tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state);
// load dO
if (cute::elect_one_sync()) {
cute::copy(
mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask),
tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch),
tDOsDO(_, pipeline_load_mma_do_producer_state.index())
);
}
++pipeline_load_mma_do_producer_state;
pipeline_load_compute_sum_odo.producer_acquire(pipeline_load_compute_sum_odo_producer_state);
// load sum_OdO
smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * 4;
gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4;
for (int i = 0; i < 4; i++) {
cutlass::arch::cp_async_zfill<4>(
shared_tensors.smem_sum_odo.begin() + smem_idx + i,
&mSumOdO(gmem_idx + i, blk_coord_batch),
gmem_idx + i < Q
);
}
pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive);
++pipeline_load_compute_sum_odo_producer_state;
iter_count -= 1;
iter_index += 1;
}
}
template<class BlkCoord, class ProblemShape_>
CUTLASS_DEVICE void mma(
BlkCoord const& blk_coord,
ProblemShape_ const& problem_shape,
int iter_index,
int iter_count,
MainloopArguments const& mainloop_args,
TensorStorage& shared_tensors,
PipelineLoadMmaQ& pipeline_load_mma_q,
typename PipelineLoadMmaQ::PipelineState& pipeline_load_mma_q_consumer_state,
PipelineLoadMmaDO& pipeline_load_mma_do,
typename PipelineLoadMmaDO::PipelineState& pipeline_load_mma_do_consumer_state,
PipelineMmaComputeS& pipeline_mma_compute_s,
typename PipelineMmaComputeS::PipelineState& pipeline_mma_compute_s_producer_state,
PipelineMmaComputeDP& pipeline_mma_compute_dp,
typename PipelineMmaComputeDP::PipelineState& pipeline_mma_compute_dp_producer_state,
PipelineMmaReduceDQ& pipeline_mma_reduce_dq,
typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_producer_state,
PipelineComputeMmaP& pipeline_compute_mma_p,
typename PipelineComputeMmaP::PipelineState& pipeline_compute_mma_p_consumer_state,
PipelineComputeMmaDS& pipeline_compute_mma_ds,
typename PipelineComputeMmaDS::PipelineState& pipeline_compute_mma_ds_consumer_state,
PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv,
typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_producer_state) {
auto [Q, K, D, D_VO, HB] = problem_shape;
auto sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{});
auto sK = make_tensor(make_smem_ptr(shared_tensors.smem_k.begin()), SmemLayoutK{});
auto sV = make_tensor(make_smem_ptr(shared_tensors.smem_v.begin()), SmemLayoutV{});
auto sDO = make_tensor(make_smem_ptr(shared_tensors.smem_do.begin()), SmemLayoutDO{});
auto sQT = make_tensor(make_smem_ptr(shared_tensors.smem_q_t.begin()), SmemLayoutQT{});
auto sKT = make_tensor(make_smem_ptr(shared_tensors.smem_k_t.begin()), SmemLayoutKT{});
auto sDS = make_tensor(make_smem_ptr(shared_tensors.smem_ds.begin()), SmemLayoutDS{});
auto sDST = make_tensor(make_smem_ptr(shared_tensors.smem_ds_t.begin()), SmemLayoutDST{});
auto sP = make_tensor(make_smem_ptr((Element*) nullptr), typename CollectiveMmaPDO::SmemLayoutA{});
auto sDOT = make_tensor(make_smem_ptr(shared_tensors.smem_do_t.begin()), SmemLayoutDOT{});
Tensor tSTrK = TiledMmaKQ::make_fragment_A(sK);
Tensor tSTrQ = TiledMmaKQ::make_fragment_B(sQ);
Tensor tDPTrV = TiledMmaVDO::make_fragment_A(sV);
Tensor tDPTrDO = TiledMmaVDO::make_fragment_B(sDO);
Tensor tDQrDS = TiledMmaDSK::make_fragment_A(sDS);
Tensor tDQrKT = TiledMmaDSK::make_fragment_B(sKT);
Tensor tDKrDST = TiledMmaDSQ::make_fragment_A(sDST);
Tensor tDKrQT = TiledMmaDSQ::make_fragment_B(sQT);
Tensor tDVrP = TiledMmaPDO::make_fragment_A(sP)(_, _, _, _0{});
tDVrP.data() = TmemAllocation::kP;
Tensor tDVrDOT = TiledMmaPDO::make_fragment_B(sDOT);
TiledMmaKQ tiled_mma_kq;
TiledMmaVDO tiled_mma_vdo;
TiledMmaDSK tiled_mma_dsk;
TiledMmaDSQ tiled_mma_dsq;
TiledMmaPDO tiled_mma_pdo;
tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::Zero;
tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::Zero;
Tensor tSTtST = partition_fragment_C(tiled_mma_kq, select<0,1>(TileShapeKQ{}));
tSTtST.data() = TmemAllocation::kS;
Tensor tDPTtDPT = partition_fragment_C(tiled_mma_vdo, select<0,1>(TileShapeVDO{}));
tDPTtDPT.data() = TmemAllocation::kDP;
Tensor tDQtDQ = partition_fragment_C(tiled_mma_dsk, select<0,1>(TileShapeDSK{}));
tDQtDQ.data() = TmemAllocation::kDQ;
Tensor tDKtDK = partition_fragment_C(tiled_mma_dsq, select<0,1>(TileShapeDSQ{}));
tDKtDK.data() = TmemAllocation::kDK;
Tensor tDVtDV = partition_fragment_C(tiled_mma_pdo, select<0,1>(TileShapePDO{}));
tDVtDV.data() = TmemAllocation::kDV;
auto pipeline_load_mma_q_release_state = pipeline_load_mma_q_consumer_state;
pipeline_load_mma_q.consumer_wait(pipeline_load_mma_q_consumer_state);
pipeline_mma_compute_s.producer_acquire(pipeline_mma_compute_s_producer_state);
// S = Q*K
tiled_mma_kq.accumulate_ = UMMA::ScaleOut::Zero;
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tSTrQ); ++k_block) {
cute::gemm(tiled_mma_kq,
tSTrK(_,_,k_block,_0{}),
tSTrQ(_,_,k_block,pipeline_load_mma_q_consumer_state.index()),
tSTtST);
tiled_mma_kq.accumulate_ = UMMA::ScaleOut::One;
}
++pipeline_load_mma_q_consumer_state;
pipeline_mma_compute_s.producer_commit(pipeline_mma_compute_s_producer_state);
++pipeline_mma_compute_s_producer_state;
pipeline_load_mma_do.consumer_wait(pipeline_load_mma_do_consumer_state);
pipeline_mma_compute_dp.producer_acquire(pipeline_mma_compute_dp_producer_state);
pipeline_mma_reduce_dq.producer_acquire(pipeline_mma_reduce_dq_producer_state);
// dP = dO*V
tiled_mma_vdo.accumulate_ = UMMA::ScaleOut::Zero;
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tDPTrV); ++k_block) {
cute::gemm(tiled_mma_vdo,
tDPTrV(_,_,k_block,_0{}),
tDPTrDO(_,_,k_block,pipeline_load_mma_do_consumer_state.index()),
tDPTtDPT);
tiled_mma_vdo.accumulate_ = UMMA::ScaleOut::One;
}
pipeline_mma_compute_dp.producer_commit(pipeline_mma_compute_dp_producer_state);
++pipeline_mma_compute_dp_producer_state;
pipeline_compute_mma_p.consumer_wait(pipeline_compute_mma_p_consumer_state);
// dV = P*dO
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tDVrP); ++k_block) {
cute::gemm(tiled_mma_pdo,
tDVrP(_,_,k_block),
tDVrDOT(_,_,k_block,pipeline_load_mma_do_consumer_state.index()),
tDVtDV);
tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::One;
}
pipeline_compute_mma_p.consumer_release(pipeline_compute_mma_p_consumer_state);
++pipeline_compute_mma_p_consumer_state;
pipeline_load_mma_do.consumer_release(pipeline_load_mma_do_consumer_state);
++pipeline_load_mma_do_consumer_state;
iter_count -= 1;
// in tmem, S & P overlap
// and dP and dQ overlap
// so we need to acquire dQ and dP at the same time
while (iter_count > 0) {
pipeline_load_mma_q.consumer_wait(pipeline_load_mma_q_consumer_state);
pipeline_mma_compute_s.producer_acquire(pipeline_mma_compute_s_producer_state);
// S = Q*K
tiled_mma_kq.accumulate_ = UMMA::ScaleOut::Zero;
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tSTrQ); ++k_block) {
cute::gemm(tiled_mma_kq,
tSTrK(_,_,k_block,_0{}),
tSTrQ(_,_,k_block,pipeline_load_mma_q_consumer_state.index()),
tSTtST);
tiled_mma_kq.accumulate_ = UMMA::ScaleOut::One;
}
++pipeline_load_mma_q_consumer_state;
pipeline_mma_compute_s.producer_commit(pipeline_mma_compute_s_producer_state);
++pipeline_mma_compute_s_producer_state;
pipeline_compute_mma_ds.consumer_wait(pipeline_compute_mma_ds_consumer_state);
// we need to acquire dP here, because tmem dQ == tmem dP
pipeline_mma_compute_dp.producer_acquire(pipeline_mma_compute_dp_producer_state);
// dQ = dS*K
tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::Zero;
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tDQrDS); ++k_block) {
cute::gemm(tiled_mma_dsk,
tDQrDS(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()),
tDQrKT(_,_,k_block,_0{}),
tDQtDQ);
tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::One;
}
pipeline_mma_reduce_dq.producer_commit(pipeline_mma_reduce_dq_producer_state);
++pipeline_mma_reduce_dq_producer_state;
// dK = dS*Q
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tDKrDST); ++k_block) {
cute::gemm(tiled_mma_dsq,
tDKrDST(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()),
tDKrQT(_,_,k_block,pipeline_load_mma_q_release_state.index()),
tDKtDK);
tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::One;
}
pipeline_load_mma_q.consumer_release(pipeline_load_mma_q_release_state);
++pipeline_load_mma_q_release_state;
pipeline_compute_mma_ds.consumer_release(pipeline_compute_mma_ds_consumer_state);
++pipeline_compute_mma_ds_consumer_state;
// we grab dq here, because in tmem dq == dp
pipeline_mma_reduce_dq.producer_acquire(pipeline_mma_reduce_dq_producer_state);
pipeline_load_mma_do.consumer_wait(pipeline_load_mma_do_consumer_state);
// dP = dO*V
tiled_mma_vdo.accumulate_ = UMMA::ScaleOut::Zero;
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tDPTrV); ++k_block) {
cute::gemm(tiled_mma_vdo,
tDPTrV(_,_,k_block,_0{}),
tDPTrDO(_,_,k_block,pipeline_load_mma_do_consumer_state.index()),
tDPTtDPT);
tiled_mma_vdo.accumulate_ = UMMA::ScaleOut::One;
}
pipeline_mma_compute_dp.producer_commit(pipeline_mma_compute_dp_producer_state);
++pipeline_mma_compute_dp_producer_state;
pipeline_compute_mma_p.consumer_wait(pipeline_compute_mma_p_consumer_state);
// dV = P*dO
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tDVrP); ++k_block) {
cute::gemm(tiled_mma_pdo,
tDVrP(_,_,k_block),
tDVrDOT(_,_,k_block,pipeline_load_mma_do_consumer_state.index()),
tDVtDV);
tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::One;
}
pipeline_compute_mma_p.consumer_release(pipeline_compute_mma_p_consumer_state);
++pipeline_compute_mma_p_consumer_state;
pipeline_load_mma_do.consumer_release(pipeline_load_mma_do_consumer_state);
++pipeline_load_mma_do_consumer_state;
iter_count -= 1;
}
// signal to the epilogue that dV is ready
pipeline_mma_compute_dkdv.producer_acquire(pipeline_mma_compute_dkdv_producer_state);
pipeline_mma_compute_dkdv.producer_commit(pipeline_mma_compute_dkdv_producer_state);
++pipeline_mma_compute_dkdv_producer_state;
pipeline_mma_compute_dkdv.producer_acquire(pipeline_mma_compute_dkdv_producer_state);
pipeline_compute_mma_ds.consumer_wait(pipeline_compute_mma_ds_consumer_state);
// dK = dS*Q
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tDKrDST); ++k_block) {
cute::gemm(tiled_mma_dsq,
tDKrDST(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()),
tDKrQT(_,_,k_block,pipeline_load_mma_q_release_state.index()),
tDKtDK);
tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::One;
}
// signal to epilgue that dK is ready
pipeline_mma_compute_dkdv.producer_commit(pipeline_mma_compute_dkdv_producer_state);
++pipeline_mma_compute_dkdv_producer_state;
// we've already acquired mma_reduce_dq in the loop
// dQ = dS*K
tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::Zero;
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tDQrDS); ++k_block) {
cute::gemm(tiled_mma_dsk,
tDQrDS(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()),
tDQrKT(_,_,k_block,_0{}),
tDQtDQ);
tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::One;
}
pipeline_mma_reduce_dq.producer_commit(pipeline_mma_reduce_dq_producer_state);
++pipeline_mma_reduce_dq_producer_state;
pipeline_load_mma_q.consumer_release(pipeline_load_mma_q_release_state);
++pipeline_load_mma_q_release_state;
pipeline_compute_mma_ds.consumer_release(pipeline_compute_mma_ds_consumer_state);
++pipeline_compute_mma_ds_consumer_state;
}
template<class TensorG, class TensorR, class TensorC, class TensorShape>
CUTLASS_DEVICE void store(
TensorG gmem,
TensorR const& regs,
TensorC const& coord,
TensorShape const& tensor_shape) {
//TODO Performance of FlashMLA on hopper is dropped with latest cutlass, so here revert the to the old version.
// Tensor preds = cute::lazy::transform(coord, [&](auto const& c) { return elem_less(c, tensor_shape); });
auto copy_op = make_cotiled_copy(
Copy_Atom<UniversalCopy<uint128_t>, Element>{},
make_layout(make_shape(_1{}, Int<sizeof(uint128_t) / sizeof(Element)>{})),
regs.layout()
);
auto thr_copy = copy_op.get_slice(_0{});
Tensor quantized_regs = quantize(regs);
auto tCg = thr_copy.partition_D(gmem);
auto tCr = thr_copy.partition_S(quantize(regs));
auto tCc = thr_copy.partition_D(coord);
constexpr int R = decltype(tCr.layout())::rank;
auto tCg_v = group_modes<1, R>(tCg);
auto tCr_v = group_modes<1, R>(tCr);
auto tCc_v = group_modes<1, R>(tCc);
auto tCp_v = make_tensor<bool>(shape<1>(tCc_v));
for (int i = 0; i < size(tCp_v); ++i) {
tCp_v(i) = elem_less(tCc_v(_0{},i), tensor_shape);
}
copy_if(copy_op, tCp_v, tCr_v, tCg_v);
}
template<class BlkCoord, class BlkOffset, class ProblemShape_>
CUTLASS_DEVICE void epilogue_clear(
BlkCoord const& blk_coord,
BlkOffset const& blk_offset,
ProblemShape_ const& problem_shape,
MainloopArguments const& mainloop_args,
EpilogueArguments const& epilogue_args) {
auto [Q, K, D, D_VO, HB] = problem_shape;
auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord;
auto mDK_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk);
auto mDK = domain_offset(select<1,2,4>(blk_offset), mDK_in);
auto gDK = local_tile(mDK, TileShapeDSQ{}, make_coord(_,_,_), Step<_1, _1, X>{})
(_, _, blk_coord_k, _0{}, blk_coord_batch);
Tensor cDK = domain_offset(
make_coord(get<1>(blk_coord) * TileShapeK{}, _0{}),
make_identity_tensor(take<0,2>(TileShapeDSQ{}))
);
auto mDV_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv);
auto mDV = domain_offset(select<1,3,4>(blk_offset), mDV_in);
auto gDV = local_tile(mDV, TileShapePDO{}, make_coord(_,_,_), Step<_1, _1, X>{})
(_, _, blk_coord_k, _0{}, blk_coord_batch);
Tensor cDV = domain_offset(
make_coord(blk_coord_k * TileShapeK{}, _0{}),
make_identity_tensor(take<0,2>(TileShapePDO{}))
);
for (int i = threadIdx.x; i < size(gDK); i += blockDim.x) {
if (elem_less(cDK(i), select<1,2>(problem_shape))) {
gDK(i) = Element(0);
}
}
for (int i = threadIdx.x; i < size(gDV); i += blockDim.x) {
if (elem_less(cDV(i), select<1,3>(problem_shape))) {
gDV(i) = Element(0);
}
}
}
template<class BlkCoord, class BlkOffset, class ProblemShape_>
CUTLASS_DEVICE void epilogue(
BlkCoord const& blk_coord,
BlkOffset const& blk_offset,
ProblemShape_ const& problem_shape,
MainloopArguments const& mainloop_args,
EpilogueArguments const& epilogue_args,
PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv,
typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) {
auto [Q, K, D, D_VO, HB] = problem_shape;
auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord;
auto load_op = SM100_TMEM_LOAD_32dp32b16x{};
auto tDKtDK = partition_fragment_C(TiledMmaDSQ{}, select<0,1>(TileShapeDSQ{}))(make_coord(_,_),_0{},_0{});
tDKtDK.data() = TmemAllocation::kDK;
auto mDK_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk);
auto mDK = domain_offset(select<1,2,4>(blk_offset), mDK_in);
auto gDK = local_tile(mDK, TileShapeDSQ{}, make_coord(_,_,_), Step<_1, _1, X>{})
(_, _, blk_coord_k, _0{}, blk_coord_batch);
Tensor cDK = domain_offset(
make_coord(get<1>(blk_coord) * TileShapeK{}, _0{}),
make_identity_tensor(take<0,2>(TileShapeDSQ{}))
);
constexpr int kNumWarpgroups = kNumComputeWarps / 4;
int dp_idx = threadIdx.x % 128;
int wg_idx = (threadIdx.x % (kNumComputeWarps * NumThreadsPerWarp)) / 128;
auto split_wg = [&](auto const& t) {
if constexpr (decltype(rank(t))::value == 3) {
auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), make_shape(Int<kNumWarpgroups>{}, size<2>(t) / Int<kNumWarpgroups>{}))));
return p(_, _, make_coord(wg_idx, _));
}
else {
auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), size<2>(t), make_shape(Int<kNumWarpgroups>{}, size<3>(t) / Int<kNumWarpgroups>{}))));
return p(_, _, _, make_coord(wg_idx, _));
}
};
auto tiled_t2r_dk = make_tmem_copy(load_op, tDKtDK);
auto thread_t2r_dk = tiled_t2r_dk.get_slice(dp_idx);
Tensor tTR_cDK = split_wg(thread_t2r_dk.partition_D(cDK));
Tensor tTR_gDK = split_wg(thread_t2r_dk.partition_D(gDK));
Tensor tTR_rDK = make_tensor<ElementAcc>(shape(tTR_cDK));
Tensor tTR_tDK = split_wg(thread_t2r_dk.partition_S(tDKtDK));
auto tDVtDV = partition_fragment_C(TiledMmaDSQ{}, select<0,1>(TileShapeDSQ{}))(make_coord(_,_),_0{},_0{});
tDVtDV.data() = TmemAllocation::kDV;
auto mDV_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv);
auto mDV = domain_offset(select<1,3,4>(blk_offset), mDV_in);
auto gDV = local_tile(mDV, TileShapePDO{}, make_coord(_,_,_), Step<_1, _1, X>{})
(_, _, blk_coord_k, _0{}, blk_coord_batch);
Tensor cDV = domain_offset(
make_coord(blk_coord_k * TileShapeK{}, _0{}),
make_identity_tensor(take<0,2>(TileShapePDO{}))
);
auto tiled_t2r_dv = make_tmem_copy(load_op, tDVtDV);
auto thread_t2r_dv = tiled_t2r_dv.get_slice(dp_idx);
Tensor tTR_cDV = split_wg(thread_t2r_dv.partition_D(cDV));
Tensor tTR_gDV = split_wg(thread_t2r_dv.partition_D(gDV));
Tensor tTR_rDV = make_tensor<ElementAcc>(shape(tTR_cDV));
Tensor tTR_tDV = split_wg(thread_t2r_dv.partition_S(tDVtDV));
pipeline_mma_compute_dkdv.consumer_wait(pipeline_mma_compute_dkdv_consumer_state);
// load tDVtDV
cute::copy(tiled_t2r_dv, tTR_tDV, tTR_rDV);
// store tDVgDV
store(tTR_gDV, tTR_rDV, tTR_cDV, select<1,3>(problem_shape));
cutlass::arch::fence_view_async_tmem_load();
pipeline_mma_compute_dkdv.consumer_release(pipeline_mma_compute_dkdv_consumer_state);
++pipeline_mma_compute_dkdv_consumer_state;
pipeline_mma_compute_dkdv.consumer_wait(pipeline_mma_compute_dkdv_consumer_state);
// load tDKtDK
cute::copy(tiled_t2r_dk, tTR_tDK, tTR_rDK);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(tTR_rDK); i++) {
tTR_rDK(i) = mainloop_args.softmax_scale * tTR_rDK(i);
}
// store tDKgDK
store(tTR_gDK, tTR_rDK, tTR_cDK, select<1,2>(problem_shape));
cutlass::arch::fence_view_async_tmem_load();
pipeline_mma_compute_dkdv.consumer_release(pipeline_mma_compute_dkdv_consumer_state);
++pipeline_mma_compute_dkdv_consumer_state;
}
template<class BlkCoord, class BlkOffset, class ProblemShape_>
CUTLASS_DEVICE void compute(
BlkCoord const& blk_coord,
BlkOffset const& blk_offset,
ProblemShape_ const& problem_shape,
int iter_index,
int iter_count,
MainloopArguments const& mainloop_args,
EpilogueArguments const& epilogue_args,
TensorStorage& shared_tensors,
PipelineLoadComputeLSE& pipeline_load_compute_lse,
typename PipelineLoadComputeLSE::PipelineState& pipeline_load_compute_lse_consumer_state,
PipelineLoadComputeSumOdO& pipeline_load_compute_sum_odo,
typename PipelineLoadComputeSumOdO::PipelineState& pipeline_load_compute_sum_odo_consumer_state,
PipelineMmaComputeS& pipeline_mma_compute_s,
typename PipelineMmaComputeS::PipelineState& pipeline_mma_compute_s_consumer_state,
PipelineMmaComputeDP& pipeline_mma_compute_dp,
typename PipelineMmaComputeDP::PipelineState& pipeline_mma_compute_dp_consumer_state,
PipelineComputeMmaP& pipeline_compute_mma_p,
typename PipelineComputeMmaP::PipelineState& pipeline_compute_mma_p_producer_state,
PipelineComputeMmaDS& pipeline_compute_mma_ds,
typename PipelineComputeMmaDS::PipelineState& pipeline_compute_mma_ds_producer_state,
PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv,
typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) {
auto [Q, K, D, D_VO, HB] = problem_shape;
// in tmem, S & P overlap
// and dP and dQ overlap
// there are two compute wg's that cooperatively compute softmax
// they are striped by this tmem atom, i.e. wg0 has 16 elems, then wg1 etc
auto load_op = SM100_TMEM_LOAD_32dp32b16x{};
auto store_op = []() {
if constexpr (sizeof(Element) == 1) {
return SM100_TMEM_STORE_32dp32b4x{};
}
else {
return SM100_TMEM_STORE_32dp32b8x{};
}
}();
Tensor tSTtST = partition_fragment_C(TiledMmaKQ{}, select<0,1>(TileShapeKQ{}))(make_coord(_,_),_0{},_0{});
tSTtST.data() = TmemAllocation::kS;
Tensor tDPTtDPT = partition_fragment_C(TiledMmaVDO{}, select<0,1>(TileShapeVDO{}))(make_coord(_,_),_0{},_0{});
tDPTtDPT.data() = TmemAllocation::kDP;
Tensor cST = make_identity_tensor(take<0,2>(TileShapeKQ{}));
Tensor cDPT = make_identity_tensor(take<0,2>(TileShapeVDO{}));
constexpr int kNumWarpgroups = kNumComputeWarps / 4;
int dp_idx = threadIdx.x % 128;
int wg_idx = (threadIdx.x % (kNumComputeWarps * NumThreadsPerWarp)) / 128;
auto tiled_t2r = make_tmem_copy(load_op, tSTtST);
auto thread_t2r = tiled_t2r.get_slice(dp_idx);
auto split_wg = [&](auto const& t) {
if constexpr (decltype(size<1>(t))::value > 1) {
if constexpr (decltype(rank(t))::value == 3) {
auto p = t.compose(make_layout(make_shape(size<0>(t), make_shape(Int<kNumWarpgroups>{}, size<1>(t) / Int<kNumWarpgroups>{}), size<2>(t))));
return p(_, make_coord(wg_idx, _), _);
}
else {
auto p = t.compose(make_layout(make_shape(size<0>(t), make_shape(Int<kNumWarpgroups>{}, size<1>(t) / Int<kNumWarpgroups>{}), size<2>(t), size<3>(t))));
return p(_, make_coord(wg_idx, _), _, _);
}
}
else {
if constexpr (decltype(rank(t))::value == 3) {
auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), make_shape(Int<kNumWarpgroups>{}, size<2>(t) / Int<kNumWarpgroups>{}))));
return p(_, _, make_coord(wg_idx, _));
}
else {
auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), size<2>(t), make_shape(Int<kNumWarpgroups>{}, size<3>(t) / Int<kNumWarpgroups>{}))));
return p(_, _, _, make_coord(wg_idx, _));
}
}
};
Tensor tTR_cST_p = thread_t2r.partition_D(cST);
Tensor tTR_cST = split_wg(tTR_cST_p);
Tensor tTR_rST = make_tensor<ElementAcc>(shape(tTR_cST));
Tensor tTR_tST = split_wg(thread_t2r.partition_S(tSTtST));
Tensor tTR_cDPT_p = thread_t2r.partition_D(cDPT);
Tensor tTR_cDPT = split_wg(tTR_cDPT_p);
Tensor tTR_rDPT = make_tensor<ElementAcc>(shape(tTR_cDPT));
Tensor tTR_tDPT = split_wg(thread_t2r.partition_S(tDPTtDPT));
Tensor sLSE = make_tensor(make_smem_ptr(shared_tensors.smem_lse.begin()), SmemLayoutLSE{});
Tensor sSumOdO = make_tensor(make_smem_ptr(shared_tensors.smem_sum_odo.begin()), SmemLayoutSumOdO{});
auto sP = make_tensor(make_smem_ptr((Element*) nullptr), typename CollectiveMmaPDO::SmemLayoutA{});
auto tDVrP = TiledMmaPDO::make_fragment_A(sP)(_, _, _, _0{});
auto tDVcST = TiledMmaPDO{}.get_slice(_0{}).partition_A(cST);
tDVrP.data() = TmemAllocation::kP;
auto tiled_r2t = make_tmem_copy(store_op, tDVrP);
auto thread_r2t = tiled_r2t.get_slice(dp_idx);
auto tRT_tP = split_wg(thread_r2t.partition_D(tDVrP));
auto tRT_cST_p = thread_r2t.partition_S(tDVcST);
auto tRT_cST = split_wg(tRT_cST_p);
bool is_residual_k = get<1>(blk_coord) * TileShapeK{} + TileShapeK{} >= get<1>(problem_shape);
int last_iter = iter_count - 1 + iter_index;
CUTLASS_PRAGMA_NO_UNROLL
while (iter_count > 0) {
// wait for S and P
pipeline_mma_compute_s.consumer_wait(pipeline_mma_compute_s_consumer_state);
pipeline_compute_mma_p.producer_acquire(pipeline_compute_mma_p_producer_state);
// wait for LSE
pipeline_load_compute_lse.consumer_wait(pipeline_load_compute_lse_consumer_state);
auto dispatch_bool = [](bool b, auto fn) {
if (b) {
fn(cute::true_type{});
}
else {
fn(cute::false_type{});
}
};
bool leading_causal_masking = false;
if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<true>, Mask>) {
leading_causal_masking = warp_uniform(iter_index == get<1>(blk_coord));
} else if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<false>, Mask>) {
int offset = get<1>(problem_shape) - get<0>(problem_shape);
int kv_left = get<1>(blk_coord) * TileShapeK{};
int kv_right = kv_left + TileShapeK{} - 1;
int q_left = iter_index * TileShapeQ{} + offset;
int q_right = q_left + TileShapeQ{} - 1;
leading_causal_masking = warp_uniform(!((q_left > kv_right) || (q_right < kv_left)));
}
bool trailing_residual_masking = false;
if constexpr (std::is_base_of_v<cutlass::fmha::collective::ResidualMaskForBackward, Mask>) {
trailing_residual_masking = warp_uniform((iter_index == last_iter) || is_residual_k);
}
dispatch_bool(leading_causal_masking || trailing_residual_masking, [&](auto is_masked_tile) {
// compute P = softmax(S, LSE)
cute::copy(tiled_t2r, tTR_tST, tTR_rST);
if constexpr (decltype(is_masked_tile)::value) {
Mask{}.apply_mask(tTR_rST, [&](int i) {
auto c_transpose = tTR_cST(i);
return make_coord(get<1>(c_transpose) + iter_index * TileShapeQ{}, get<0>(c_transpose) + get<1>(blk_coord) * TileShapeK{});
}, problem_shape);
}
ElementAcc log2_e = static_cast<ElementAcc>(M_LOG2E);
float2 softmax_scale_log2_e;
softmax_scale_log2_e.x = mainloop_args.softmax_scale * log2_e;
softmax_scale_log2_e.y = mainloop_args.softmax_scale * log2_e;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(tTR_rST); i += 2) {
float2 acc;
float2 lse;
float2 out;
acc.x = tTR_rST(i);
acc.y = tTR_rST(i + 1);
lse.x = sLSE(get<1>(tTR_cST(i)), pipeline_load_compute_lse_consumer_state.index());
lse.y = sLSE(get<1>(tTR_cST(i+1)), pipeline_load_compute_lse_consumer_state.index());
cute::fma(out, softmax_scale_log2_e, acc, lse);
tTR_rST(i) = ::exp2f(out.x);
tTR_rST(i+1) = ::exp2f(out.y);
}
auto tRT_rST = quantize(tTR_rST);
auto tRT_rST_reshaped = make_tensor(tRT_rST.data(), shape(tRT_cST));
cutlass::arch::fence_view_async_tmem_load();
cutlass::arch::NamedBarrier(
kNumComputeWarps * NumThreadsPerWarp,
cutlass::arch::ReservedNamedBarriers::TransformBarrier
).arrive_and_wait();
cute::copy(tiled_r2t, tRT_rST_reshaped, tRT_tP);
});
// notify for P
cutlass::arch::fence_view_async_tmem_store();
pipeline_compute_mma_p.producer_commit(pipeline_compute_mma_p_producer_state);
++pipeline_compute_mma_p_producer_state;
// release S
pipeline_mma_compute_s.consumer_release(pipeline_mma_compute_s_consumer_state);
++pipeline_mma_compute_s_consumer_state;
// release LSE
pipeline_load_compute_lse.consumer_release(pipeline_load_compute_lse_consumer_state);
++pipeline_load_compute_lse_consumer_state;
// wait for OdO
pipeline_load_compute_sum_odo.consumer_wait(pipeline_load_compute_sum_odo_consumer_state);
// wait for dP
pipeline_mma_compute_dp.consumer_wait(pipeline_mma_compute_dp_consumer_state);
// wait for dS
// in principle, we could defer waiting for dS, and move in the freeing of dP
// however, that would force us to keep dS in registers longer
pipeline_compute_mma_ds.producer_acquire(pipeline_compute_mma_ds_producer_state);
// compute dS = dsoftmax(P, dP, sum_OdO)
cute::copy(tiled_t2r, tTR_tDPT, tTR_rDPT);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(tTR_rDPT); i += 2) {
float2 st;
st.x = tTR_rST(i);
st.y = tTR_rST(i+1);
float2 dpt;
dpt.x = tTR_rDPT(i);
dpt.y = tTR_rDPT(i+1);
float2 odo;
odo.x = sSumOdO(get<1>(tTR_cDPT(i)), pipeline_load_compute_sum_odo_consumer_state.index());
odo.y = sSumOdO(get<1>(tTR_cDPT(i+1)), pipeline_load_compute_sum_odo_consumer_state.index());
float2 dif;
// sum odo is negated during preprocess
cute::add(dif, dpt, odo);
float2 out;
cute::mul(out, dif, st);
tTR_rDPT(i) = out.x;
tTR_rDPT(i+1) = out.y;
}
auto tTR_rDST = quantize(tTR_rDPT);
// release dP
cutlass::arch::fence_view_async_tmem_load();
pipeline_mma_compute_dp.consumer_release(pipeline_mma_compute_dp_consumer_state);
++pipeline_mma_compute_dp_consumer_state;
Tensor sDS = make_tensor(make_smem_ptr((Element*) shared_tensors.smem_ds.begin()), SmemLayoutDS{})
(_, _, _, pipeline_compute_mma_ds_producer_state.index());
auto thread_layout = make_ordered_layout(
make_shape(_128{}, _128{}),
make_stride(_1{}, _0{})
);
auto sDS_pi = as_position_independent_swizzle_tensor(sDS);
auto sDS_pi_slice_p = sDS_pi.compose(thread_layout)(dp_idx, _).compose(make_layout(shape(tTR_cDPT_p)));
auto sDS_pi_slice = split_wg(sDS_pi_slice_p);
copy_aligned(tTR_rDST, sDS_pi_slice);
// notify for dS
cutlass::arch::fence_view_async_shared();
pipeline_compute_mma_ds.producer_commit(pipeline_compute_mma_ds_producer_state);
++pipeline_compute_mma_ds_producer_state;
// release OdO
pipeline_load_compute_sum_odo.consumer_release(pipeline_load_compute_sum_odo_consumer_state);
++pipeline_load_compute_sum_odo_consumer_state;
iter_count -= 1;
iter_index += 1;
}
epilogue(
blk_coord, blk_offset, problem_shape, mainloop_args, epilogue_args,
pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_consumer_state
);
}
template<class BlkCoord, class ProblemShape_>
CUTLASS_DEVICE void reduce(
BlkCoord const& blk_coord,
ProblemShape_ const& problem_shape,
int iter_index,
int iter_count,
MainloopArguments const& mainloop_args,
MainloopParams const& mainloop_params,
TensorStorage& shared_tensors,
PipelineMmaReduceDQ& pipeline_mma_reduce_dq,
typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_consumer_state,
PipelineReduceTmaStore& pipeline_reduce_tma_store,
typename PipelineReduceTmaStore::PipelineState& pipeline_reduce_tma_store_producer_state) {
using X = Underscore;
auto [Q, K, D, D_VO, HB] = problem_shape;
auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord;
// must match TileShapeDQ
auto load_op = SM100_TMEM_LOAD_32dp32b32x{};
auto tDQtDQ = partition_fragment_C(TiledMmaDSK{}, select<0,1>(TileShapeDSK{}))(make_coord(_,_),_0{},_0{});
tDQtDQ.data() = TmemAllocation::kDQ;
Tensor mDQ = mainloop_params.tma_red_dq.get_tma_tensor(make_shape(Q, D, HB));
auto gDQ = local_tile(mDQ, TileShapeKQ{}, make_coord(_,_,_), Step<X, _1, _1>{})
(_, _, _, _0{}, blk_coord_batch);
Tensor cDQ = make_identity_tensor(take<0,2>(TileShapeDSK{}));
Tensor sDQ = make_tensor(make_smem_ptr(shared_tensors.smem_dq.begin()), SmemLayoutDQ{});
int thread_idx = threadIdx.x % (kNumComputeWarps * NumThreadsPerWarp);
auto tiled_t2r = make_tmem_copy(load_op, tDQtDQ);
auto thread_t2r = tiled_t2r.get_slice(thread_idx);
Tensor tTR_cDQ = thread_t2r.partition_D(cDQ);
Tensor tTR_gDQ = thread_t2r.partition_D(gDQ);
Tensor tTR_sDQ = thread_t2r.partition_D(sDQ);
Tensor tTR_tDQ = thread_t2r.partition_S(tDQtDQ);
auto block_tma = mainloop_params.tma_red_dq.get_slice(_0{});
Tensor tDQsDQ = block_tma.partition_S(sDQ);
Tensor tDQcDQ = block_tma.partition_S(cDQ);
Tensor tDQgDQ = block_tma.partition_D(gDQ);
int lane_predicate = (threadIdx.x % (kNumReduceWarps * NumThreadsPerWarp)) == 0;
while (iter_count > 0) {
pipeline_mma_reduce_dq.consumer_wait(pipeline_mma_reduce_dq_consumer_state);
Tensor tTR_rDQ = make_tensor<ElementAcc>(shape(tTR_cDQ));
// load dQ from tmem to rmem
cute::copy(tiled_t2r, tTR_tDQ, tTR_rDQ);
cutlass::arch::fence_view_async_tmem_load();
pipeline_mma_reduce_dq.consumer_release(pipeline_mma_reduce_dq_consumer_state);
++pipeline_mma_reduce_dq_consumer_state;
// we don't have enough smem to dump it all to smem, so we do it in stages
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<2>(tTR_cDQ); i++) {
if (lane_predicate) {
pipeline_reduce_tma_store.producer_acquire(pipeline_reduce_tma_store_producer_state);
}
// wait in all threads for the acquire to complete
cutlass::arch::NamedBarrier(
kNumReduceWarps * NumThreadsPerWarp,
cutlass::arch::ReservedNamedBarriers::TransposeBarrier
).arrive_and_wait();
cute::copy(tTR_rDQ(_, _, i), tTR_sDQ(_, _, _0{}, pipeline_reduce_tma_store_producer_state.index()));
// wait for the stores to all be visible to the TMA
cutlass::arch::fence_view_async_shared();
cutlass::arch::NamedBarrier(
kNumReduceWarps * NumThreadsPerWarp,
cutlass::arch::ReservedNamedBarriers::TransposeBarrier
).arrive_and_wait();
if (lane_predicate) {
// launch tma store
copy(mainloop_params.tma_red_dq, tDQsDQ(_,_,_0{}, pipeline_reduce_tma_store_producer_state.index()), tDQgDQ(_,_,i,iter_index));
pipeline_reduce_tma_store.producer_commit(pipeline_reduce_tma_store_producer_state);
}
++pipeline_reduce_tma_store_producer_state;
}
iter_count -= 1;
iter_index += 1;
}
}
CUTLASS_DEVICE void operator()(Params const& params, char* smem) {
int warp_idx = cutlass::canonical_warp_idx_sync();
auto role = warp_idx_to_role(warp_idx);
uint32_t lane_predicate = cute::elect_one_sync();
if (role == WarpRole::Load && lane_predicate) {
prefetch_tma_descriptor(params.mainloop_params.tma_load_q.get_tma_descriptor());
prefetch_tma_descriptor(params.mainloop_params.tma_load_k.get_tma_descriptor());
prefetch_tma_descriptor(params.mainloop_params.tma_load_v.get_tma_descriptor());
prefetch_tma_descriptor(params.mainloop_params.tma_load_do.get_tma_descriptor());
}
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem);
int initializing_warp = 0;
typename PipelineLoadMmaQ::Params pipeline_load_mma_q_params;
if (role == WarpRole::Load) {
pipeline_load_mma_q_params.role = PipelineLoadMmaQ::ThreadCategory::Producer;
}
if (role == WarpRole::Mma) {
pipeline_load_mma_q_params.role = PipelineLoadMmaQ::ThreadCategory::Consumer;
}
pipeline_load_mma_q_params.is_leader = lane_predicate && (role == WarpRole::Load);
// Also loads K in the first iteration
pipeline_load_mma_q_params.transaction_bytes = kTransactionsBytesLoadQ;
pipeline_load_mma_q_params.initializing_warp = initializing_warp++;
PipelineLoadMmaQ pipeline_load_mma_q(shared_storage.pipelines.load_mma_q, pipeline_load_mma_q_params,
ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});
typename PipelineLoadMmaDO::Params pipeline_load_mma_do_params;
if (role == WarpRole::Load) {
pipeline_load_mma_do_params.role = PipelineLoadMmaDO::ThreadCategory::Producer;
}
if (role == WarpRole::Mma) {
pipeline_load_mma_do_params.role = PipelineLoadMmaDO::ThreadCategory::Consumer;
}
pipeline_load_mma_do_params.is_leader = lane_predicate && (role == WarpRole::Load);
// Also loads V in the first iteration
pipeline_load_mma_do_params.transaction_bytes = kTransactionsBytesLoadDO;
pipeline_load_mma_do_params.initializing_warp = initializing_warp++;
PipelineLoadMmaDO pipeline_load_mma_do(shared_storage.pipelines.load_mma_do, pipeline_load_mma_do_params,
ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});
typename PipelineLoadComputeLSE::Params pipeline_load_compute_lse_params;
if (role == WarpRole::Load) {
pipeline_load_compute_lse_params.role = PipelineLoadComputeLSE::ThreadCategory::Producer;
}
if (role == WarpRole::Compute) {
pipeline_load_compute_lse_params.role = PipelineLoadComputeLSE::ThreadCategory::Consumer;
}
pipeline_load_compute_lse_params.producer_arv_count = NumThreadsPerWarp;
pipeline_load_compute_lse_params.consumer_arv_count = kNumComputeWarps * NumThreadsPerWarp;
pipeline_load_compute_lse_params.initializing_warp = initializing_warp++;
PipelineLoadComputeLSE pipeline_load_compute_lse(
shared_storage.pipelines.load_compute_lse,
pipeline_load_compute_lse_params,
/*barrier init*/ cute::true_type{});
typename PipelineLoadComputeSumOdO::Params pipeline_load_compute_sum_odo_params;
if (role == WarpRole::Load) {
pipeline_load_compute_sum_odo_params.role = PipelineLoadComputeSumOdO::ThreadCategory::Producer;
}
if (role == WarpRole::Compute) {
pipeline_load_compute_sum_odo_params.role = PipelineLoadComputeSumOdO::ThreadCategory::Consumer;
}
pipeline_load_compute_sum_odo_params.producer_arv_count = NumThreadsPerWarp;
pipeline_load_compute_sum_odo_params.consumer_arv_count = kNumComputeWarps * NumThreadsPerWarp;
pipeline_load_compute_sum_odo_params.initializing_warp = initializing_warp++;
PipelineLoadComputeSumOdO pipeline_load_compute_sum_odo(
shared_storage.pipelines.load_compute_sum_odo,
pipeline_load_compute_sum_odo_params,
/*barrier init*/ cute::true_type{});
typename PipelineMmaComputeS::Params pipeline_mma_compute_s_params;
if (role == WarpRole::Mma) {
pipeline_mma_compute_s_params.role = PipelineMmaComputeS::ThreadCategory::Producer;
}
if (role == WarpRole::Compute) {
pipeline_mma_compute_s_params.role = PipelineMmaComputeS::ThreadCategory::Consumer;
}
pipeline_mma_compute_s_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp;
pipeline_mma_compute_s_params.initializing_warp = initializing_warp++;
PipelineMmaComputeS pipeline_mma_compute_s(
shared_storage.pipelines.mma_compute_s,
pipeline_mma_compute_s_params,
ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});
typename PipelineMmaComputeDP::Params pipeline_mma_compute_dp_params;
if (role == WarpRole::Mma) {
pipeline_mma_compute_dp_params.role = PipelineMmaComputeDP::ThreadCategory::Producer;
}
if (role == WarpRole::Compute) {
pipeline_mma_compute_dp_params.role = PipelineMmaComputeDP::ThreadCategory::Consumer;
}
pipeline_mma_compute_dp_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp;
pipeline_mma_compute_dp_params.initializing_warp = initializing_warp++;
PipelineMmaComputeDP pipeline_mma_compute_dp(
shared_storage.pipelines.mma_compute_dp,
pipeline_mma_compute_dp_params,
ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});
typename PipelineMmaReduceDQ::Params pipeline_mma_reduce_dq_params;
if (role == WarpRole::Mma) {
pipeline_mma_reduce_dq_params.role = PipelineMmaReduceDQ::ThreadCategory::Producer;
}
if (role == WarpRole::Reduce) {
pipeline_mma_reduce_dq_params.role = PipelineMmaReduceDQ::ThreadCategory::Consumer;
}
pipeline_mma_reduce_dq_params.consumer_arv_count = kNumReduceWarps * cutlass::NumThreadsPerWarp;
pipeline_mma_reduce_dq_params.initializing_warp = initializing_warp++;
PipelineMmaReduceDQ pipeline_mma_reduce_dq(
shared_storage.pipelines.mma_reduce_dq,
pipeline_mma_reduce_dq_params,
ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});
typename PipelineComputeMmaP::Params pipeline_compute_mma_p_params;
if (role == WarpRole::Mma) {
pipeline_compute_mma_p_params.role = PipelineComputeMmaP::ThreadCategory::Consumer;
}
if (role == WarpRole::Compute) {
pipeline_compute_mma_p_params.role = PipelineComputeMmaP::ThreadCategory::Producer;
}
pipeline_compute_mma_p_params.producer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp;
pipeline_compute_mma_p_params.consumer_arv_count = 1;
pipeline_compute_mma_p_params.initializing_warp = initializing_warp++;
PipelineComputeMmaP pipeline_compute_mma_p(
shared_storage.pipelines.compute_mma_p,
pipeline_compute_mma_p_params,
ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});
typename PipelineComputeMmaDS::Params pipeline_compute_mma_ds_params;
if (role == WarpRole::Mma) {
pipeline_compute_mma_ds_params.role = PipelineComputeMmaDS::ThreadCategory::Consumer;
}
if (role == WarpRole::Compute) {
pipeline_compute_mma_ds_params.role = PipelineComputeMmaDS::ThreadCategory::Producer;
}
pipeline_compute_mma_ds_params.producer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp;
pipeline_compute_mma_ds_params.consumer_arv_count = 1;
pipeline_compute_mma_ds_params.initializing_warp = initializing_warp++;
PipelineComputeMmaDS pipeline_compute_mma_ds(
shared_storage.pipelines.compute_mma_ds,
pipeline_compute_mma_ds_params,
ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});
typename PipelineMmaComputeDKDV::Params pipeline_mma_compute_dkdv_params;
if (role == WarpRole::Mma) {
pipeline_mma_compute_dkdv_params.role = PipelineMmaComputeDKDV::ThreadCategory::Producer;
}
if (role == WarpRole::Compute) {
pipeline_mma_compute_dkdv_params.role = PipelineMmaComputeDKDV::ThreadCategory::Consumer;
}
pipeline_mma_compute_dkdv_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp;
pipeline_mma_compute_dkdv_params.initializing_warp = initializing_warp++;
PipelineMmaComputeDKDV pipeline_mma_compute_dkdv(
shared_storage.pipelines.mma_compute_dkdv,
pipeline_mma_compute_dkdv_params,
ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});
PipelineReduceTmaStore pipeline_reduce_tma_store;
TmemAllocator tmem_allocator;
pipeline_init_arrive_relaxed(size(ClusterShape{}));
pipeline_load_mma_q.init_masks(ClusterShape{});
pipeline_load_mma_do.init_masks(ClusterShape{});
pipeline_mma_compute_s.init_masks(ClusterShape{});
pipeline_mma_compute_dp.init_masks(ClusterShape{});
pipeline_mma_reduce_dq.init_masks(ClusterShape{});
pipeline_compute_mma_p.init_masks(ClusterShape{});
pipeline_compute_mma_ds.init_masks(ClusterShape{});
pipeline_mma_compute_dkdv.init_masks(ClusterShape{});
typename decltype(pipeline_load_mma_q)::PipelineState pipeline_load_mma_q_consumer_state;
typename decltype(pipeline_load_mma_do)::PipelineState pipeline_load_mma_do_consumer_state;
typename decltype(pipeline_load_compute_lse)::PipelineState pipeline_load_compute_lse_consumer_state;
typename decltype(pipeline_load_compute_sum_odo)::PipelineState pipeline_load_compute_sum_odo_consumer_state;
typename decltype(pipeline_mma_compute_s)::PipelineState pipeline_mma_compute_s_consumer_state;
typename decltype(pipeline_mma_compute_dp)::PipelineState pipeline_mma_compute_dp_consumer_state;
typename decltype(pipeline_mma_reduce_dq)::PipelineState pipeline_mma_reduce_dq_consumer_state;
typename decltype(pipeline_compute_mma_p)::PipelineState pipeline_compute_mma_p_consumer_state;
typename decltype(pipeline_compute_mma_ds)::PipelineState pipeline_compute_mma_ds_consumer_state;
typename decltype(pipeline_mma_compute_dkdv)::PipelineState pipeline_mma_compute_dkdv_consumer_state;
auto pipeline_load_mma_q_producer_state = make_producer_start_state<decltype(pipeline_load_mma_q)>();
auto pipeline_load_mma_do_producer_state = make_producer_start_state<decltype(pipeline_load_mma_do)>();
auto pipeline_load_compute_lse_producer_state = make_producer_start_state<decltype(pipeline_load_compute_lse)>();
auto pipeline_load_compute_sum_odo_producer_state = make_producer_start_state<decltype(pipeline_load_compute_sum_odo)>();
auto pipeline_mma_compute_s_producer_state = make_producer_start_state<decltype(pipeline_mma_compute_s)>();
auto pipeline_mma_compute_dp_producer_state = make_producer_start_state<decltype(pipeline_mma_compute_dp)>();
auto pipeline_mma_reduce_dq_producer_state = make_producer_start_state<decltype(pipeline_mma_reduce_dq)>();
auto pipeline_compute_mma_p_producer_state = make_producer_start_state<decltype(pipeline_compute_mma_p)>();
auto pipeline_compute_mma_ds_producer_state = make_producer_start_state<decltype(pipeline_compute_mma_ds)>();
auto pipeline_mma_compute_dkdv_producer_state = make_producer_start_state<decltype(pipeline_mma_compute_dkdv)>();
auto pipeline_reduce_tma_store_producer_state = make_producer_start_state<decltype(pipeline_reduce_tma_store)>();
pipeline_init_wait(size(ClusterShape{}));
auto blk_coord = make_coord(_0{}, blockIdx.x, _0{}, _0{}, make_coord(blockIdx.y, blockIdx.z));
auto [problem_shape, blk_offset] = apply_variable_length_offset(
params.problem_shape,
blk_coord
);
int iter_count = ceil_div(get<0>(problem_shape), TileShapeQ{});
int iter_start = 0;
if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<true>, Mask>) {
iter_start = (get<1>(blk_coord) * TileShapeK{}) / TileShapeQ{};
} else if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<false>, Mask>) {
int offset = get<1>(problem_shape) - get<0>(problem_shape);
iter_start = max(0, (int(get<1>(blk_coord) * TileShapeK{}) - offset) / (int)TileShapeQ{});
}
if (get<1>(blk_coord) * TileShapeK{} >= get<1>(problem_shape)) {
return;
}
iter_count -= iter_start;
if (iter_count <= 0) {
epilogue_clear(
blk_coord,
blk_offset,
problem_shape,
params.mainloop,
params.epilogue
);
return;
}
if (role == WarpRole::Load) {
warpgroup_reg_set<RegisterAllocation::kLoad>();
load(
blk_coord,
blk_offset,
problem_shape,
iter_start,
iter_count,
params.mainloop,
params.mainloop_params,
shared_storage.tensors,
pipeline_load_mma_q, pipeline_load_mma_q_producer_state,
pipeline_load_mma_do, pipeline_load_mma_do_producer_state,
pipeline_load_compute_lse, pipeline_load_compute_lse_producer_state,
pipeline_load_compute_sum_odo, pipeline_load_compute_sum_odo_producer_state
);
}
else if (role == WarpRole::Mma) {
warpgroup_reg_set<RegisterAllocation::kMma>();
tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr);
__syncwarp();
mma(
blk_coord,
problem_shape,
iter_start,
iter_count,
params.mainloop,
shared_storage.tensors,
pipeline_load_mma_q, pipeline_load_mma_q_consumer_state,
pipeline_load_mma_do, pipeline_load_mma_do_consumer_state,
pipeline_mma_compute_s, pipeline_mma_compute_s_producer_state,
pipeline_mma_compute_dp, pipeline_mma_compute_dp_producer_state,
pipeline_mma_reduce_dq, pipeline_mma_reduce_dq_producer_state,
pipeline_compute_mma_p, pipeline_compute_mma_p_consumer_state,
pipeline_compute_mma_ds, pipeline_compute_mma_ds_consumer_state,
pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_producer_state
);
}
else if (role == WarpRole::Compute) {
warpgroup_reg_set<RegisterAllocation::kCompute>();
compute(
blk_coord,
blk_offset,
problem_shape,
iter_start,
iter_count,
params.mainloop,
params.epilogue,
shared_storage.tensors,
pipeline_load_compute_lse, pipeline_load_compute_lse_consumer_state,
pipeline_load_compute_sum_odo, pipeline_load_compute_sum_odo_consumer_state,
pipeline_mma_compute_s, pipeline_mma_compute_s_consumer_state,
pipeline_mma_compute_dp, pipeline_mma_compute_dp_consumer_state,
pipeline_compute_mma_p, pipeline_compute_mma_p_producer_state,
pipeline_compute_mma_ds, pipeline_compute_mma_ds_producer_state,
pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_consumer_state
);
cutlass::arch::NamedBarrier(
kNumComputeWarps * NumThreadsPerWarp,
cutlass::arch::ReservedNamedBarriers::EpilogueBarrier
).arrive_and_wait();
if (warp_idx % kNumComputeWarps == 0) {
uint32_t free_stage_ptr = shared_storage.tmem_base_ptr;
tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns);
}
}
else if (role == WarpRole::Reduce) {
warpgroup_reg_set<RegisterAllocation::kReduce>();
reduce(
blk_coord,
problem_shape,
iter_start,
iter_count,
params.mainloop,
params.mainloop_params,
shared_storage.tensors,
pipeline_mma_reduce_dq, pipeline_mma_reduce_dq_consumer_state,
pipeline_reduce_tma_store, pipeline_reduce_tma_store_producer_state
);
pipeline_reduce_tma_store.producer_tail(pipeline_reduce_tma_store_producer_state);
}
else {
warpgroup_reg_set<RegisterAllocation::kEmpty>();
/* no-op */
}
}
static dim3 get_block_shape() {
dim3 block(MaxThreadsPerBlock, 1, 1);
return block;
}
static dim3 get_grid_shape(Params const& params) {
auto [Q, K, D, D_VO, HB] = params.problem_shape;
auto [H, B] = HB;
dim3 grid(ceil_div(K, TileShapeK{}), H, B);
return grid;
}
};
} // namespace cutlass::fmha::kernel
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cute/arch/simd_sm100.hpp"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/memory_sm80.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "collective/fmha_common.hpp"
#include <cmath>
namespace cutlass::fmha::kernel {
using namespace cutlass::fmha::collective;
using namespace cute;
template<
class ProblemShape,
class Element,
class ElementAcc,
class TileShape,
class Mask
>
struct Sm100FmhaBwdMlaKernelTmaWarpSpecialized {
using TileShapeQ = decltype(get<0>(TileShape{}));
using TileShapeK = decltype(get<1>(TileShape{}));
using TileShapeDQK = decltype(get<2>(TileShape{}));
using TileShapeDVO = decltype(get<3>(TileShape{}));
using TmemAllocator = cute::TMEM::Allocator1Sm;
struct TmemAllocation {
static constexpr uint32_t kDK = 0; // TileShapeK x TileShapeDQK x acc
static constexpr uint32_t kDV = kDK + TileShapeDQK{}; // TileShapeK x TileShapeDVO x acc
static constexpr uint32_t kDQ = kDV + TileShapeDVO{}; // TileShapeQ x TileShapeDQK x acc
static constexpr uint32_t kDP = kDQ; // TileShapeK x TileShapeQ x inp
static constexpr uint32_t kS = kDQ + 65536 * 16;
static constexpr uint32_t kP = kS;
static constexpr uint32_t kTotal = kDQ + TileShapeDQK{};
};
static_assert(
static_cast<int>(TmemAllocation::kTotal) <= TmemAllocator::Sm100TmemCapacityColumns,
"using too much tmem"
);
enum class WarpRole {
Empty = 0x0, Load = 0x1, Mma = 0x2, Compute = 0x3, Reduce = 0x4
};
static constexpr unsigned long long kWarpAssignment = 0x12'3333'3333'4444ull;
static constexpr int kNumComputeWarps = 8;
static constexpr int kNumReduceWarps = 4;
static constexpr int kLoadPerThread = TileShapeQ{} / NumThreadsPerWarp;
static_assert(TileShapeQ{} % NumThreadsPerWarp == 0, "TileShapeQ must be divisible by NumThreadsPerWarp");
CUTLASS_DEVICE WarpRole warp_idx_to_role(int warp_idx) {
return static_cast<WarpRole>((kWarpAssignment >> (4 * warp_idx)) & 0xF);
}
struct RegisterAllocation {
static constexpr int kWarpgroup0 = 160-8;
static constexpr int kWarpgroup1 = 128;
static constexpr int kWarpgroup2 = 96;
static constexpr int kReduce = kWarpgroup0;
static constexpr int kCompute = kWarpgroup1;
static constexpr int kMma = kWarpgroup2;
static constexpr int kEmpty = kWarpgroup2;
static constexpr int kLoad = kWarpgroup2;
static_assert(kWarpgroup0 + 2 * kWarpgroup1 + kWarpgroup2 <= 512);
};
using ArchTag = cutlass::arch::Sm100;
using ClusterShape = Shape<_1, _1, _1>;
using Schedule = cutlass::gemm::KernelTmaWarpSpecialized1SmSm100;
static constexpr int MinBlocksPerMultiprocessor = 1;
static constexpr int kNumWarps = kNumComputeWarps + kNumReduceWarps + 4;
static constexpr int MaxThreadsPerBlock = NumThreadsPerWarp * kNumWarps;
static constexpr int Alignment = 128 / sizeof_bits_v<Element>;
static constexpr int kStages = 2;
using TensorStrideContiguousK = Stride<int, _1, Stride<int, int>>;
using TensorStrideContiguousMN = Stride<_1, int, Stride<int, int>>;
// compute S
using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
Element, TensorStrideContiguousK, Alignment,
Element, TensorStrideContiguousK, Alignment,
ElementAcc,
Shape<TileShapeQ, TileShapeK, TileShapeDQK>,
ClusterShape, cutlass::gemm::collective::StageCount<kStages>,
Schedule>::CollectiveOp;
using TileShapeQK = typename CollectiveMmaQK::TileShape;
using TiledMmaQK = typename CollectiveMmaQK::TiledMma;
// compute dP
using CollectiveMmaDOV = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
Element, TensorStrideContiguousK, Alignment,
Element, TensorStrideContiguousK, Alignment,
ElementAcc,
Shape<TileShapeQ, TileShapeK, TileShapeDVO>,
ClusterShape, cutlass::gemm::collective::StageCount<kStages>,
Schedule>::CollectiveOp;
using TileShapeDOV = typename CollectiveMmaDOV::TileShape;
using TiledMmaDOV = typename CollectiveMmaDOV::TiledMma;
// compute dV
using CollectiveMmaPDO = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
// needs to match ordering of S calculation
Element, TensorStrideContiguousK, Alignment,
Element, TensorStrideContiguousMN, Alignment,
ElementAcc,
Shape<TileShapeK, TileShapeDVO, TileShapeQ>,
ClusterShape, cutlass::gemm::collective::StageCount<kStages>,
Schedule>::CollectiveOp;
using TileShapePDO = typename CollectiveMmaPDO::TileShape;
using TiledMmaPDO = typename CollectiveMmaPDO::TiledMma;
// compute dK
using CollectiveMmaDSQ = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
// somewhat arbitrary since we dump to smem, need to agree with the next one
Element, TensorStrideContiguousK , Alignment,
Element, TensorStrideContiguousMN, Alignment,
ElementAcc,
Shape<TileShapeK, TileShapeDQK, TileShapeQ>,
ClusterShape, cutlass::gemm::collective::StageCount<kStages>,
Schedule>::CollectiveOp;
using TileShapeDSQ = typename CollectiveMmaDSQ::TileShape;
using TiledMmaDSQ = typename CollectiveMmaDSQ::TiledMma;
// compute dQ
using CollectiveMmaDSK = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
// somewhat arbitrary since we dump to smem, need to agree with the previous one
Element, TensorStrideContiguousMN, Alignment,
Element, TensorStrideContiguousMN, Alignment,
ElementAcc,
Shape<TileShapeQ, TileShapeDQK, TileShapeK>,
ClusterShape, cutlass::gemm::collective::StageCount<kStages>,
Schedule>::CollectiveOp;
using TileShapeDSK = typename CollectiveMmaDSK::TileShape;
using TiledMmaDSK = typename CollectiveMmaDSK::TiledMma;
// pipelines are named Pipeline<Producer><Consumer><Resource>
static constexpr int kStagesComputeSmem = 1;
using PipelineLoadMmaQ = PipelineTmaUmmaAsync<2, ClusterShape>;
using PipelineLoadMmaDO = PipelineTmaUmmaAsync<1, ClusterShape>;
using PipelineLoadComputeLSE = PipelineAsync<1>;
using PipelineLoadComputeSumOdO = PipelineAsync<1>;
using PipelineMmaComputeS = PipelineUmmaAsync<1>;
using PipelineMmaComputeDP = PipelineUmmaAsync<1>;
using PipelineMmaReduceDQ = PipelineUmmaAsync<1>;
using PipelineComputeMmaP = PipelineUmmaConsumerAsync<1>;
using PipelineComputeMmaDS = PipelineUmmaConsumerAsync<kStagesComputeSmem>;
using PipelineMmaComputeDKDV = PipelineUmmaAsync<2>;
static constexpr int kStagesReduceTmaStore = 2;
using PipelineReduceTmaStore = PipelineTmaStore<kStagesReduceTmaStore>;
struct PipelineStorage {
alignas(16) typename PipelineLoadMmaQ::SharedStorage load_mma_q;
alignas(16) typename PipelineLoadMmaDO::SharedStorage load_mma_do;
alignas(16) typename PipelineLoadComputeLSE::SharedStorage load_compute_lse;
alignas(16) typename PipelineLoadComputeSumOdO::SharedStorage load_compute_sum_odo;
alignas(16) typename PipelineMmaComputeS::SharedStorage mma_compute_s;
alignas(16) typename PipelineMmaComputeDP::SharedStorage mma_compute_dp;
alignas(16) typename PipelineMmaReduceDQ::SharedStorage mma_reduce_dq;
alignas(16) typename PipelineComputeMmaP::SharedStorage compute_mma_p;
alignas(16) typename PipelineComputeMmaDS::SharedStorage compute_mma_ds;
alignas(16) typename PipelineMmaComputeDKDV::SharedStorage mma_compute_dkdv;
};
template<class Layout, class Stages = _1>
static CUTE_DEVICE constexpr auto restage(Layout const& layout, Stages stages = {}) {
return composition(layout, make_tuple(_, _, _, make_layout(stages)));
}
using SmemLayoutK = decltype(restage(typename CollectiveMmaQK::SmemLayoutB{}));
using SmemLayoutV = decltype(restage(typename CollectiveMmaDOV::SmemLayoutB{}));
using SmemLayoutQ = decltype(restage(typename CollectiveMmaQK::SmemLayoutA{}, _2{}));
using SmemLayoutDO = decltype(restage(typename CollectiveMmaDOV::SmemLayoutA{}, _1{}));
using SmemLayoutDS = decltype(restage(typename CollectiveMmaDSK::SmemLayoutA{}, Int<kStagesComputeSmem>{}));
using SmemLayoutLSE = Layout<Shape<TileShapeQ, _1>>;
using SmemLayoutSumOdO = Layout<Shape<TileShapeQ, _1>>;
using SmemLayoutQT = decltype(restage(typename CollectiveMmaDSQ::SmemLayoutB{}, _2{}));
using SmemLayoutKT = decltype(restage(typename CollectiveMmaDSK::SmemLayoutB{}));
using SmemLayoutDST = decltype(restage(typename CollectiveMmaDSQ::SmemLayoutA{}, Int<kStagesComputeSmem>{}));
using SmemLayoutDOT = decltype(restage(typename CollectiveMmaPDO::SmemLayoutB{}, _1{}));
using SmemLayoutP = decltype(restage(typename CollectiveMmaPDO::SmemLayoutA{}, _1{}));
using SmemLayoutPT = decltype(restage(typename CollectiveMmaDSK::SmemLayoutA{}, _1{}));
using TileShapeDQ = _32;
using SmemAtomDQ = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<
cute::UMMA::Major::K, ElementAcc, TileShapeQ, TileShapeDQ
>());
using SmemShapeDQ = Shape<TileShapeQ, TileShapeDQ, Int<kStagesReduceTmaStore>>;
using SmemLayoutDQ = decltype(tile_to_shape(SmemAtomDQ{}, SmemShapeDQ{}, Step<_2, _1, _3>{}));
struct TensorStorage {
union {
alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutK>> smem_k;
alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutKT>> smem_k_t;
};
alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutV>> smem_v;
union {
alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutQT>> smem_q_t;
};
union {
alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutDO>> smem_do;
alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutDOT>> smem_do_t;
};
union {
alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutDS>> smem_ds;
alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutDST>> smem_ds_t;
};
union{
alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutP>> smem_p;
alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutPT>> smem_p_t;
};
alignas(1024) cute::array<ElementAcc, cute::cosize_v<SmemLayoutDQ>> smem_dq;
alignas(16) cute::array<ElementAcc, cute::cosize_v<SmemLayoutLSE>> smem_lse;
alignas(16) cute::array<ElementAcc, cute::cosize_v<SmemLayoutSumOdO>> smem_sum_odo;
};
static constexpr int kTransactionsBytesLoadQ = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutQ{})) * cute::sizeof_bits_v<Element>);
static constexpr int kTransactionsBytesLoadDO = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutDO{})) * cute::sizeof_bits_v<Element>);
static constexpr int kTransactionsBytesLoadK = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v<Element>);
static constexpr int kTransactionsBytesLoadV = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v<Element>);
struct SharedStorage {
TensorStorage tensors;
PipelineStorage pipelines;
uint32_t tmem_base_ptr;
};
// this is tight enough that it won't work with sizeof due to padding for alignment
static constexpr int SharedStorageSize = offsetof(SharedStorage, tmem_base_ptr) + sizeof(uint32_t);
static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "using too much smem");
using TensorStride = TensorStrideContiguousK; // S D (H B)
using RowTensorStride = Stride<_1, Stride<int, int>>; // S (H B)
struct MainloopArguments {
const Element* ptr_q;
TensorStride stride_q;
const Element* ptr_k;
TensorStride stride_k;
const Element* ptr_v;
TensorStride stride_v;
const Element* ptr_do;
TensorStride stride_do;
const ElementAcc* ptr_lse;
RowTensorStride stride_lse;
const ElementAcc* ptr_sum_odo;
RowTensorStride stride_sum_odo;
ElementAcc* ptr_dq_acc;
TensorStride stride_dq_acc;
ElementAcc softmax_scale = 1.0f / sqrtf(TileShapeDQK{});
};
using TMA_K = typename CollectiveMmaQK::Params::TMA_B;
using TMA_V = typename CollectiveMmaDOV::Params::TMA_B;
using TMA_Q = typename CollectiveMmaQK::Params::TMA_A;
using TMA_DO = typename CollectiveMmaDOV::Params::TMA_A;
using TMA_DQ = decltype(make_tma_copy(SM90_TMA_REDUCE_ADD{},
make_tensor((const ElementAcc*)nullptr, make_shape(1, 1, make_shape(1, 1)), TensorStride{}),
SmemLayoutDQ{}(_, _, _0{})
));
struct MainloopParams {
TMA_K tma_load_k;
TMA_V tma_load_v;
TMA_Q tma_load_q;
TMA_DO tma_load_do;
TMA_DQ tma_red_dq;
};
struct EpilogueArguments {
Element* ptr_dk;
TensorStride stride_dk;
Element* ptr_dv;
TensorStride stride_dv;
};
struct Arguments {
ProblemShape problem_shape;
MainloopArguments mainloop;
EpilogueArguments epilogue;
KernelHardwareInfo hw_info;
};
struct Params {
ProblemShape problem_shape;
MainloopArguments mainloop;
MainloopParams mainloop_params;
EpilogueArguments epilogue;
KernelHardwareInfo hw_info;
};
static bool can_implement(Arguments const& args) {
auto [Q, K, D, D_VO, HB] = args.problem_shape;
auto [H, B] = HB;
if (Q <= 0 || K <= 0 || D <= 0 || H <= 0 || B <= 0 || D_VO <= 0) {
return false;
}
if (D % Alignment != 0 || D_VO % Alignment != 0) {
return false;
}
return true;
}
static Status initialize_workspace(Arguments const&, void*, cudaStream_t) {
return Status::kSuccess;
}
static Params to_underlying_arguments(Arguments const& args, void*) {
auto [Q_, K_, D, D_VO, HB] = args.problem_shape;
int Q = Q_;
int K = K_;
if constexpr (is_variable_length_v<decltype(Q_)>) {
Q = Q_.total_length;
}
if constexpr (is_variable_length_v<decltype(K_)>) {
K = K_.total_length;
}
auto params_kq = CollectiveMmaQK::to_underlying_arguments(
make_shape(Q, K, D, HB),
typename CollectiveMmaQK::Arguments {
args.mainloop.ptr_q, args.mainloop.stride_q,
args.mainloop.ptr_k, args.mainloop.stride_k,
}, /*workspace=*/nullptr);
auto params_vdo = CollectiveMmaDOV::to_underlying_arguments(
make_shape(Q, K, D_VO, HB),
typename CollectiveMmaDOV::Arguments {
args.mainloop.ptr_do, args.mainloop.stride_do,
args.mainloop.ptr_v, args.mainloop.stride_v,
}, /*workspace=*/nullptr);
TMA_DQ tma_red_dq = make_tma_copy(
SM90_TMA_REDUCE_ADD{},
make_tensor(args.mainloop.ptr_dq_acc, make_shape(Q_, D, HB), args.mainloop.stride_dq_acc),
SmemLayoutDQ{}(_, _, _0{})
);
return Params{
args.problem_shape,
args.mainloop,
MainloopParams{
params_kq.tma_load_b,
params_vdo.tma_load_b,
params_kq.tma_load_a,
params_vdo.tma_load_a,
tma_red_dq
},
args.epilogue,
args.hw_info
};
}
template<class T>
static CUTLASS_DEVICE auto quantize(T const& input) {
constexpr int AlignmentS = 4;
auto output = make_tensor<Element>(shape(input));
auto input_vec = recast<Array<ElementAcc, AlignmentS>>(input);
auto output_vec = recast<Array<Element, AlignmentS>>(output);
cutlass::NumericArrayConverter<Element, ElementAcc, AlignmentS> epilogue_op;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(input_vec); i++) {
output_vec(i) = epilogue_op(input_vec(i));
}
return output;
}
template<class BlkCoord, class BlkOffset, class ProblemShape_>
CUTLASS_DEVICE void load(
BlkCoord const& blk_coord,
BlkOffset const& blk_offset,
ProblemShape_ const& problem_shape,
int iter_index,
int iter_count,
MainloopArguments const& mainloop_args,
MainloopParams const& mainloop_params,
TensorStorage& shared_tensors,
PipelineLoadMmaQ& pipeline_load_mma_q,
typename PipelineLoadMmaQ::PipelineState& pipeline_load_mma_q_producer_state,
PipelineLoadMmaDO& pipeline_load_mma_do,
typename PipelineLoadMmaDO::PipelineState& pipeline_load_mma_do_producer_state,
PipelineLoadComputeLSE& pipeline_load_compute_lse,
typename PipelineLoadComputeLSE::PipelineState& pipeline_load_compute_lse_producer_state,
PipelineLoadComputeSumOdO& pipeline_load_compute_sum_odo,
typename PipelineLoadComputeSumOdO::PipelineState& pipeline_load_compute_sum_odo_producer_state) {
auto [Q, K, D, D_VO, HB] = problem_shape;
using X = Underscore;
uint16_t mcast_mask = 0;
auto mK_in = mainloop_params.tma_load_k.get_tma_tensor(make_shape(K, D, HB));
auto mV_in = mainloop_params.tma_load_v.get_tma_tensor(make_shape(K, D_VO, HB));
auto mQ_in = mainloop_params.tma_load_q.get_tma_tensor(make_shape(Q, D, HB));
auto mDO_in = mainloop_params.tma_load_do.get_tma_tensor(make_shape(Q, D_VO, HB));
auto mK = domain_offset(select<1,2,4>(blk_offset), mK_in);
auto mV = domain_offset(select<1,3,4>(blk_offset), mV_in);
auto mQ = domain_offset(select<0,2,4>(blk_offset), mQ_in);
auto mDO = domain_offset(select<0,3,4>(blk_offset), mDO_in);
auto gK = local_tile(mK, TileShapeQK{}, make_coord(_,_,_), Step<X, _1, _1>{});
auto gQ = local_tile(mQ, TileShapeQK{}, make_coord(_,_,_), Step<_1, X, _1>{});
auto gV = local_tile(mV, TileShapeDOV{}, make_coord(_,_,_), Step<X, _1, _1>{});
auto gDO = local_tile(mDO, TileShapeDOV{}, make_coord(_,_,_), Step<_1, X, _1>{});
ThrMMA cta_mma_kq = TiledMmaQK{}.get_slice(_0{});
ThrMMA cta_mma_vdo = TiledMmaDOV{}.get_slice(_0{});
auto tSTgK = cta_mma_kq.partition_B(gK);
auto tSTgQ = cta_mma_kq.partition_A(gQ);
auto tDPTgV = cta_mma_vdo.partition_B(gV);
auto tDPTgDO = cta_mma_vdo.partition_A(gDO);
auto sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{});
auto sK = make_tensor(make_smem_ptr(shared_tensors.smem_k.begin()), SmemLayoutK{});
auto sV = make_tensor(make_smem_ptr(shared_tensors.smem_v.begin()), SmemLayoutV{});
auto sDO = make_tensor(make_smem_ptr(shared_tensors.smem_do.begin()), SmemLayoutDO{});
auto [tKgK_mkl, tKsK] = tma_partition(
mainloop_params.tma_load_k, _0{}, make_layout(_1{}),
group_modes<0,3>(sK), group_modes<0,3>(tSTgK));
auto [tQgQ_mkl, tQsQ] = tma_partition(
mainloop_params.tma_load_q, _0{}, make_layout(_1{}),
group_modes<0,3>(sQ), group_modes<0,3>(tSTgQ));
auto [tVgV_mkl, tVsV] = tma_partition(
mainloop_params.tma_load_v, _0{}, make_layout(_1{}),
group_modes<0,3>(sV), group_modes<0,3>(tDPTgV));
auto [tDOgDO_mkl, tDOsDO] = tma_partition(
mainloop_params.tma_load_do, _0{}, make_layout(_1{}),
group_modes<0,3>(sDO), group_modes<0,3>(tDPTgDO));
// set up lse and sum_odo
auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord;
pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state);
auto tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state);
pipeline_load_mma_q.producer_expect_transaction(pipeline_load_mma_q_producer_state, kTransactionsBytesLoadK);
// load K
if (cute::elect_one_sync()) {
cute::copy(
mainloop_params.tma_load_k.with(*tma_barrier, mcast_mask),
tKgK_mkl(_, blk_coord_k, _0{}, blk_coord_batch),
tKsK(_, _0{})
);
}
// load Q
if (cute::elect_one_sync()) {
cute::copy(
mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask),
tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch),
tQsQ(_, pipeline_load_mma_q_producer_state.index())
);
}
++pipeline_load_mma_q_producer_state;
pipeline_load_compute_lse.producer_acquire(pipeline_load_compute_lse_producer_state);
// load LSE
// 32 threads loading kLoadPerThread * 32 values of 32b each
int thread_idx = threadIdx.x % NumThreadsPerWarp;
int smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * kLoadPerThread;
int gmem_idx = TileShapeQ{} * iter_index + thread_idx * kLoadPerThread;
auto mLSE = make_tensor(mainloop_args.ptr_lse, make_shape(Q, HB), mainloop_args.stride_lse);
for (int i = 0; i < kLoadPerThread; i++) {
cutlass::arch::cp_async_zfill<4>(
shared_tensors.smem_lse.begin() + smem_idx + i,
&mLSE(gmem_idx + i, blk_coord_batch),
gmem_idx + i < Q
);
}
pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive);
++pipeline_load_compute_lse_producer_state;
pipeline_load_mma_do.producer_acquire(pipeline_load_mma_do_producer_state);
tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state);
pipeline_load_mma_do.producer_expect_transaction(pipeline_load_mma_do_producer_state, kTransactionsBytesLoadV);
// load V
if (cute::elect_one_sync()) {
cute::copy(
mainloop_params.tma_load_v.with(*tma_barrier, mcast_mask),
tVgV_mkl(_, blk_coord_k, _0{}, blk_coord_batch),
tVsV(_, _0{})
);
}
// load dO
if (cute::elect_one_sync()) {
cute::copy(
mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask),
tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch),
tDOsDO(_, pipeline_load_mma_do_producer_state.index())
);
}
++pipeline_load_mma_do_producer_state;
pipeline_load_compute_sum_odo.producer_acquire(pipeline_load_compute_sum_odo_producer_state);
// load sum_OdO
smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * kLoadPerThread;
gmem_idx = TileShapeQ{} * iter_index + thread_idx * kLoadPerThread;
auto mSumOdO = make_tensor(mainloop_args.ptr_sum_odo, make_shape(Q, HB), mainloop_args.stride_sum_odo);
for (int i = 0; i < kLoadPerThread; i++) {
cutlass::arch::cp_async_zfill<4>(
shared_tensors.smem_sum_odo.begin() + smem_idx + i,
&mSumOdO(gmem_idx + i, blk_coord_batch),
gmem_idx + i < Q
);
}
pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive);
++pipeline_load_compute_sum_odo_producer_state;
iter_count -= 1;
iter_index += 1;
while (iter_count > 0) {
pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state);
tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state);
// load Q
if (cute::elect_one_sync()) {
cute::copy(
mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask),
tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch),
tQsQ(_, pipeline_load_mma_q_producer_state.index())
);
}
++pipeline_load_mma_q_producer_state;
pipeline_load_compute_lse.producer_acquire(pipeline_load_compute_lse_producer_state);
// load LSE
smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * kLoadPerThread;
gmem_idx = TileShapeQ{} * iter_index + thread_idx * kLoadPerThread;
for (int i = 0; i < kLoadPerThread; i++) {
cutlass::arch::cp_async_zfill<4>(
shared_tensors.smem_lse.begin() + smem_idx + i,
&mLSE(gmem_idx + i, blk_coord_batch),
gmem_idx + i < Q
);
}
pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive);
++pipeline_load_compute_lse_producer_state;
pipeline_load_mma_do.producer_acquire(pipeline_load_mma_do_producer_state);
tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state);
// load dO
if (cute::elect_one_sync()) {
cute::copy(
mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask),
tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch),
tDOsDO(_, pipeline_load_mma_do_producer_state.index())
);
}
++pipeline_load_mma_do_producer_state;
pipeline_load_compute_sum_odo.producer_acquire(pipeline_load_compute_sum_odo_producer_state);
// load sum_OdO
smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * kLoadPerThread;
gmem_idx = TileShapeQ{} * iter_index + thread_idx * kLoadPerThread;
for (int i = 0; i < kLoadPerThread; i++) {
cutlass::arch::cp_async_zfill<4>(
shared_tensors.smem_sum_odo.begin() + smem_idx + i,
&mSumOdO(gmem_idx + i, blk_coord_batch),
gmem_idx + i < Q
);
}
pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive);
++pipeline_load_compute_sum_odo_producer_state;
iter_count -= 1;
iter_index += 1;
}
}
template<class BlkCoord, class ProblemShape_>
CUTLASS_DEVICE void mma(
BlkCoord const& blk_coord,
ProblemShape_ const& problem_shape,
int iter_index,
int iter_count,
MainloopArguments const& mainloop_args,
TensorStorage& shared_tensors,
PipelineLoadMmaQ& pipeline_load_mma_q,
typename PipelineLoadMmaQ::PipelineState& pipeline_load_mma_q_consumer_state,
PipelineLoadMmaDO& pipeline_load_mma_do,
typename PipelineLoadMmaDO::PipelineState& pipeline_load_mma_do_consumer_state,
PipelineMmaComputeS& pipeline_mma_compute_s,
typename PipelineMmaComputeS::PipelineState& pipeline_mma_compute_s_producer_state,
PipelineMmaComputeDP& pipeline_mma_compute_dp,
typename PipelineMmaComputeDP::PipelineState& pipeline_mma_compute_dp_producer_state,
PipelineMmaReduceDQ& pipeline_mma_reduce_dq,
typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_producer_state,
PipelineComputeMmaP& pipeline_compute_mma_p,
typename PipelineComputeMmaP::PipelineState& pipeline_compute_mma_p_consumer_state,
PipelineComputeMmaDS& pipeline_compute_mma_ds,
typename PipelineComputeMmaDS::PipelineState& pipeline_compute_mma_ds_consumer_state,
PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv,
typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_producer_state) {
auto [Q, K, D, D_VO, HB] = problem_shape;
auto sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{});
auto sK = make_tensor(make_smem_ptr(shared_tensors.smem_k.begin()), SmemLayoutK{});
auto sV = make_tensor(make_smem_ptr(shared_tensors.smem_v.begin()), SmemLayoutV{});
auto sDO = make_tensor(make_smem_ptr(shared_tensors.smem_do.begin()), SmemLayoutDO{});
auto sQT = make_tensor(make_smem_ptr(shared_tensors.smem_q_t.begin()), SmemLayoutQT{});
auto sKT = make_tensor(make_smem_ptr(shared_tensors.smem_k_t.begin()), SmemLayoutKT{});
auto sDS = make_tensor(make_smem_ptr(shared_tensors.smem_ds.begin()), SmemLayoutDS{});
auto sDST = make_tensor(make_smem_ptr(shared_tensors.smem_ds_t.begin()), SmemLayoutDST{});
auto sP = make_tensor(make_smem_ptr(shared_tensors.smem_p.begin()), SmemLayoutP{});
auto sDOT = make_tensor(make_smem_ptr(shared_tensors.smem_do_t.begin()), SmemLayoutDOT{});
Tensor tSTrK = TiledMmaQK::make_fragment_B(sK);
Tensor tSTrQ = TiledMmaQK::make_fragment_A(sQ);
Tensor tDPTrV = TiledMmaDOV::make_fragment_B(sV);
Tensor tDPTrDO = TiledMmaDOV::make_fragment_A(sDO);
Tensor tDQrDS = TiledMmaDSK::make_fragment_A(sDS);
Tensor tDQrKT = TiledMmaDSK::make_fragment_B(sKT);
Tensor tDKrDST = TiledMmaDSQ::make_fragment_A(sDST);
Tensor tDKrQT = TiledMmaDSQ::make_fragment_B(sQT);
Tensor tDVrP = TiledMmaPDO::make_fragment_A(sP);
Tensor tDVrDOT = TiledMmaPDO::make_fragment_B(sDOT);
TiledMmaQK tiled_mma_qk;
TiledMmaDOV tiled_mma_dov;
TiledMmaDSK tiled_mma_dsk;
TiledMmaDSQ tiled_mma_dsq;
TiledMmaPDO tiled_mma_pdo;
tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::Zero;
tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::Zero;
Tensor tSTtST = partition_fragment_C(tiled_mma_qk, select<0,1>(TileShapeQK{}));
tSTtST.data() = TmemAllocation::kS;
Tensor tDPTtDPT = partition_fragment_C(tiled_mma_dov, select<0,1>(TileShapeDOV{}));
tDPTtDPT.data() = TmemAllocation::kDP;
Tensor tDQtDQ = partition_fragment_C(tiled_mma_dsk, select<0,1>(TileShapeDSK{}));
tDQtDQ.data() = TmemAllocation::kDQ;
Tensor tDKtDK = partition_fragment_C(tiled_mma_dsq, select<0,1>(TileShapeDSQ{}));
tDKtDK.data() = TmemAllocation::kDK;
Tensor tDVtDV = partition_fragment_C(tiled_mma_pdo, select<0,1>(TileShapePDO{}));
tDVtDV.data() = TmemAllocation::kDV;
auto pipeline_load_mma_q_release_state = pipeline_load_mma_q_consumer_state;
pipeline_load_mma_q.consumer_wait(pipeline_load_mma_q_consumer_state);
pipeline_mma_compute_s.producer_acquire(pipeline_mma_compute_s_producer_state);
// S = Q*K
tiled_mma_qk.accumulate_ = UMMA::ScaleOut::Zero;
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tSTrQ); ++k_block) {
cute::gemm(tiled_mma_qk,
tSTrQ(_,_,k_block,pipeline_load_mma_q_consumer_state.index()),
tSTrK(_,_,k_block,_0{}),
tSTtST);
tiled_mma_qk.accumulate_ = UMMA::ScaleOut::One;
}
++pipeline_load_mma_q_consumer_state;
pipeline_mma_compute_s.producer_commit(pipeline_mma_compute_s_producer_state);
++pipeline_mma_compute_s_producer_state;
pipeline_load_mma_do.consumer_wait(pipeline_load_mma_do_consumer_state);
pipeline_mma_compute_dp.producer_acquire(pipeline_mma_compute_dp_producer_state);
pipeline_mma_reduce_dq.producer_acquire(pipeline_mma_reduce_dq_producer_state);
// dP = dO*V
tiled_mma_dov.accumulate_ = UMMA::ScaleOut::Zero;
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tDPTrV); ++k_block) {
cute::gemm(tiled_mma_dov,
tDPTrDO(_,_,k_block,pipeline_load_mma_do_consumer_state.index()),
tDPTrV(_,_,k_block,_0{}),
tDPTtDPT);
tiled_mma_dov.accumulate_ = UMMA::ScaleOut::One;
}
pipeline_mma_compute_dp.producer_commit(pipeline_mma_compute_dp_producer_state);
++pipeline_mma_compute_dp_producer_state;
pipeline_compute_mma_p.consumer_wait(pipeline_compute_mma_p_consumer_state);
// dV = P*dO
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tDVrP); ++k_block) {
cute::gemm(tiled_mma_pdo,
tDVrP(_,_,k_block,_0{}),
tDVrDOT(_,_,k_block,pipeline_load_mma_do_consumer_state.index()),
tDVtDV);
tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::One;
}
pipeline_compute_mma_p.consumer_release(pipeline_compute_mma_p_consumer_state);
++pipeline_compute_mma_p_consumer_state;
pipeline_load_mma_do.consumer_release(pipeline_load_mma_do_consumer_state);
++pipeline_load_mma_do_consumer_state;
iter_count -= 1;
// in tmem, S & P overlap
// and dP and dQ overlap
// so we need to acquire dQ and dP at the same time
while (iter_count > 0) {
pipeline_load_mma_q.consumer_wait(pipeline_load_mma_q_consumer_state);
pipeline_mma_compute_s.producer_acquire(pipeline_mma_compute_s_producer_state);
// S = Q*K
tiled_mma_qk.accumulate_ = UMMA::ScaleOut::Zero;
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tSTrQ); ++k_block) {
cute::gemm(tiled_mma_qk,
tSTrQ(_,_,k_block,pipeline_load_mma_q_consumer_state.index()),
tSTrK(_,_,k_block,_0{}),
tSTtST);
tiled_mma_qk.accumulate_ = UMMA::ScaleOut::One;
}
++pipeline_load_mma_q_consumer_state;
pipeline_mma_compute_s.producer_commit(pipeline_mma_compute_s_producer_state);
++pipeline_mma_compute_s_producer_state;
pipeline_compute_mma_ds.consumer_wait(pipeline_compute_mma_ds_consumer_state);
// we need to acquire dP here, because tmem dQ == tmem dP
pipeline_mma_compute_dp.producer_acquire(pipeline_mma_compute_dp_producer_state);
// dQ = dS*K
tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::Zero;
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tDQrDS); ++k_block) {
cute::gemm(tiled_mma_dsk,
tDQrDS(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()),
tDQrKT(_,_,k_block,_0{}),
tDQtDQ);
tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::One;
}
pipeline_mma_reduce_dq.producer_commit(pipeline_mma_reduce_dq_producer_state);
++pipeline_mma_reduce_dq_producer_state;
// dK = dS*Q
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tDKrDST); ++k_block) {
cute::gemm(tiled_mma_dsq,
tDKrDST(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()),
tDKrQT(_,_,k_block,pipeline_load_mma_q_release_state.index()),
tDKtDK);
tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::One;
}
pipeline_load_mma_q.consumer_release(pipeline_load_mma_q_release_state);
++pipeline_load_mma_q_release_state;
pipeline_compute_mma_ds.consumer_release(pipeline_compute_mma_ds_consumer_state);
++pipeline_compute_mma_ds_consumer_state;
// we grab dq here, because in tmem dq == dp
pipeline_mma_reduce_dq.producer_acquire(pipeline_mma_reduce_dq_producer_state);
pipeline_load_mma_do.consumer_wait(pipeline_load_mma_do_consumer_state);
// dP = dO*V
tiled_mma_dov.accumulate_ = UMMA::ScaleOut::Zero;
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tDPTrV); ++k_block) {
cute::gemm(tiled_mma_dov,
tDPTrDO(_,_,k_block,pipeline_load_mma_do_consumer_state.index()),
tDPTrV(_,_,k_block,_0{}),
tDPTtDPT);
tiled_mma_dov.accumulate_ = UMMA::ScaleOut::One;
}
pipeline_mma_compute_dp.producer_commit(pipeline_mma_compute_dp_producer_state);
++pipeline_mma_compute_dp_producer_state;
pipeline_compute_mma_p.consumer_wait(pipeline_compute_mma_p_consumer_state);
// dV = P*dO
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tDVrP); ++k_block) {
cute::gemm(tiled_mma_pdo,
tDVrP(_,_,k_block,_0{}),
tDVrDOT(_,_,k_block,pipeline_load_mma_do_consumer_state.index()),
tDVtDV);
tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::One;
}
pipeline_compute_mma_p.consumer_release(pipeline_compute_mma_p_consumer_state);
++pipeline_compute_mma_p_consumer_state;
pipeline_load_mma_do.consumer_release(pipeline_load_mma_do_consumer_state);
++pipeline_load_mma_do_consumer_state;
iter_count -= 1;
}
// signal to the epilogue that dV is ready
pipeline_mma_compute_dkdv.producer_acquire(pipeline_mma_compute_dkdv_producer_state);
pipeline_mma_compute_dkdv.producer_commit(pipeline_mma_compute_dkdv_producer_state);
++pipeline_mma_compute_dkdv_producer_state;
pipeline_mma_compute_dkdv.producer_acquire(pipeline_mma_compute_dkdv_producer_state);
pipeline_compute_mma_ds.consumer_wait(pipeline_compute_mma_ds_consumer_state);
// dK = dS*Q
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tDKrDST); ++k_block) {
cute::gemm(tiled_mma_dsq,
tDKrDST(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()),
tDKrQT(_,_,k_block,pipeline_load_mma_q_release_state.index()),
tDKtDK);
tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::One;
}
// signal to epilgue that dK is ready
pipeline_mma_compute_dkdv.producer_commit(pipeline_mma_compute_dkdv_producer_state);
++pipeline_mma_compute_dkdv_producer_state;
// we've already acquired mma_reduce_dq in the loop
// dQ = dS*K
tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::Zero;
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tDQrDS); ++k_block) {
cute::gemm(tiled_mma_dsk,
tDQrDS(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()),
tDQrKT(_,_,k_block,_0{}),
tDQtDQ);
tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::One;
}
pipeline_mma_reduce_dq.producer_commit(pipeline_mma_reduce_dq_producer_state);
++pipeline_mma_reduce_dq_producer_state;
pipeline_load_mma_q.consumer_release(pipeline_load_mma_q_release_state);
++pipeline_load_mma_q_release_state;
pipeline_compute_mma_ds.consumer_release(pipeline_compute_mma_ds_consumer_state);
++pipeline_compute_mma_ds_consumer_state;
}
template<class TensorG, class TensorR, class TensorC, class TensorShape>
CUTLASS_DEVICE void store(
TensorG gmem,
TensorR const& regs,
TensorC const& coord,
TensorShape const& tensor_shape) {
//TODO Performance of FlashMLA on hopper is dropped with latest cutlass, so here revert the to the old version.
// Tensor preds = cute::lazy::transform(coord, [&](auto const& c) { return elem_less(c, tensor_shape); });
auto copy_op = make_cotiled_copy(
Copy_Atom<UniversalCopy<uint128_t>, Element>{},
make_layout(make_shape(_1{}, Int<sizeof(uint128_t) / sizeof(Element)>{})),
regs.layout()
);
auto thr_copy = copy_op.get_slice(_0{});
Tensor quantized_regs = quantize(regs);
auto tCg = thr_copy.partition_D(gmem);
auto tCr = thr_copy.partition_S(quantize(regs));
auto tCc = thr_copy.partition_D(coord);
constexpr int R = decltype(tCr.layout())::rank;
auto tCg_v = group_modes<1, R>(tCg);
auto tCr_v = group_modes<1, R>(tCr);
auto tCc_v = group_modes<1, R>(tCc);
auto tCp_v = make_tensor<bool>(shape<1>(tCc_v));
for (int i = 0; i < size(tCp_v); ++i) {
tCp_v(i) = elem_less(tCc_v(_0{},i), tensor_shape);
}
copy_if(copy_op, tCp_v, tCr_v, tCg_v);
}
template<class BlkCoord, class BlkOffset, class ProblemShape_>
CUTLASS_DEVICE void epilogue_clear(
BlkCoord const& blk_coord,
BlkOffset const& blk_offset,
ProblemShape_ const& problem_shape,
MainloopArguments const& mainloop_args,
EpilogueArguments const& epilogue_args) {
auto [Q, K, D, D_VO, HB] = problem_shape;
auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord;
auto mDK_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk);
auto mDK = domain_offset(select<1,2,4>(blk_offset), mDK_in);
auto gDK = local_tile(mDK, TileShapeDSQ{}, make_coord(_,_,_), Step<_1, _1, X>{})
(_, _, blk_coord_k, _0{}, blk_coord_batch);
Tensor cDK = domain_offset(
make_coord(get<1>(blk_coord) * TileShapeK{}, _0{}),
make_identity_tensor(take<0,2>(TileShapeDSQ{}))
);
auto mDV_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv);
auto mDV = domain_offset(select<1,3,4>(blk_offset), mDV_in);
auto gDV = local_tile(mDV, TileShapePDO{}, make_coord(_,_,_), Step<_1, _1, X>{})
(_, _, blk_coord_k, _0{}, blk_coord_batch);
Tensor cDV = domain_offset(
make_coord(blk_coord_k * TileShapeK{}, _0{}),
make_identity_tensor(take<0,2>(TileShapePDO{}))
);
for (int i = threadIdx.x; i < size(gDK); i += blockDim.x) {
if (elem_less(cDK(i), select<1,2>(problem_shape))) {
gDK(i) = Element(0);
}
}
for (int i = threadIdx.x; i < size(gDV); i += blockDim.x) {
if (elem_less(cDV(i), select<1,3>(problem_shape))) {
gDV(i) = Element(0);
}
}
}
template<class BlkCoord, class BlkOffset, class ProblemShape_>
CUTLASS_DEVICE void epilogue(
BlkCoord const& blk_coord,
BlkOffset const& blk_offset,
ProblemShape_ const& problem_shape,
MainloopArguments const& mainloop_args,
EpilogueArguments const& epilogue_args,
PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv,
typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) {
auto [Q, K, D, D_VO, HB] = problem_shape;
auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord;
auto load_op = SM100_TMEM_LOAD_32dp32b16x{};
auto tDKtDK = partition_fragment_C(TiledMmaDSQ{}, select<0,1>(TileShapeDSQ{}))(make_coord(_,_),_0{},_0{});
tDKtDK.data() = TmemAllocation::kDK;
auto mDK_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk);
auto mDK = domain_offset(select<1,2,4>(blk_offset), mDK_in);
auto gDK = local_tile(mDK, TileShapeDSQ{}, make_coord(_,_,_), Step<_1, _1, X>{})
(_, _, blk_coord_k, _0{}, blk_coord_batch);
Tensor cDK = domain_offset(
make_coord(get<1>(blk_coord) * TileShapeK{}, _0{}),
make_identity_tensor(take<0,2>(TileShapeDSQ{}))
);
constexpr int kNumWarpgroups = kNumComputeWarps / 4;
int dp_idx = threadIdx.x % 128;
int wg_idx = (threadIdx.x % (kNumComputeWarps * NumThreadsPerWarp)) / 128;
auto split_wg = [&](auto const& t) {
if constexpr (decltype(rank(t))::value == 3) {
auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), make_shape(Int<kNumWarpgroups>{}, size<2>(t) / Int<kNumWarpgroups>{}))));
return p(_, _, make_coord(wg_idx, _));
}
else {
auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), size<2>(t), make_shape(Int<kNumWarpgroups>{}, size<3>(t) / Int<kNumWarpgroups>{}))));
return p(_, _, _, make_coord(wg_idx, _));
}
};
auto tiled_t2r_dk = make_tmem_copy(load_op, tDKtDK);
auto thread_t2r_dk = tiled_t2r_dk.get_slice(dp_idx);
Tensor tTR_cDK = split_wg(thread_t2r_dk.partition_D(cDK));
Tensor tTR_gDK = split_wg(thread_t2r_dk.partition_D(gDK));
Tensor tTR_rDK = make_tensor<ElementAcc>(shape(tTR_cDK));
Tensor tTR_tDK = split_wg(thread_t2r_dk.partition_S(tDKtDK));
auto tDVtDV = partition_fragment_C(TiledMmaPDO{}, select<0,1>(TileShapePDO{}))(make_coord(_,_),_0{},_0{});
tDVtDV.data() = TmemAllocation::kDV;
auto mDV_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv);
auto mDV = domain_offset(select<1,3,4>(blk_offset), mDV_in);
auto gDV = local_tile(mDV, TileShapePDO{}, make_coord(_,_,_), Step<_1, _1, X>{})
(_, _, blk_coord_k, _0{}, blk_coord_batch);
Tensor cDV = domain_offset(
make_coord(blk_coord_k * TileShapeK{}, _0{}),
make_identity_tensor(take<0,2>(TileShapePDO{}))
);
auto tiled_t2r_dv = make_tmem_copy(load_op, tDVtDV);
auto thread_t2r_dv = tiled_t2r_dv.get_slice(dp_idx);
Tensor tTR_cDV = split_wg(thread_t2r_dv.partition_D(cDV));
Tensor tTR_gDV = split_wg(thread_t2r_dv.partition_D(gDV));
Tensor tTR_rDV = make_tensor<ElementAcc>(shape(tTR_cDV));
Tensor tTR_tDV = split_wg(thread_t2r_dv.partition_S(tDVtDV));
pipeline_mma_compute_dkdv.consumer_wait(pipeline_mma_compute_dkdv_consumer_state);
// load tDVtDV
cute::copy(tiled_t2r_dv, tTR_tDV, tTR_rDV);
// store tDVgDV
store(tTR_gDV, tTR_rDV, tTR_cDV, select<1,3>(problem_shape));
cutlass::arch::fence_view_async_tmem_load();
pipeline_mma_compute_dkdv.consumer_release(pipeline_mma_compute_dkdv_consumer_state);
++pipeline_mma_compute_dkdv_consumer_state;
pipeline_mma_compute_dkdv.consumer_wait(pipeline_mma_compute_dkdv_consumer_state);
// load tDKtDK
cute::copy(tiled_t2r_dk, tTR_tDK, tTR_rDK);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(tTR_rDK); i++) {
tTR_rDK(i) = mainloop_args.softmax_scale * tTR_rDK(i);
}
// store tDKgDK
store(tTR_gDK, tTR_rDK, tTR_cDK, select<1,2>(problem_shape));
cutlass::arch::fence_view_async_tmem_load();
pipeline_mma_compute_dkdv.consumer_release(pipeline_mma_compute_dkdv_consumer_state);
++pipeline_mma_compute_dkdv_consumer_state;
}
template<class BlkCoord, class BlkOffset, class ProblemShape_>
CUTLASS_DEVICE void compute(
BlkCoord const& blk_coord,
BlkOffset const& blk_offset,
ProblemShape_ const& problem_shape,
int iter_index,
int iter_count,
MainloopArguments const& mainloop_args,
EpilogueArguments const& epilogue_args,
TensorStorage& shared_tensors,
PipelineLoadComputeLSE& pipeline_load_compute_lse,
typename PipelineLoadComputeLSE::PipelineState& pipeline_load_compute_lse_consumer_state,
PipelineLoadComputeSumOdO& pipeline_load_compute_sum_odo,
typename PipelineLoadComputeSumOdO::PipelineState& pipeline_load_compute_sum_odo_consumer_state,
PipelineMmaComputeS& pipeline_mma_compute_s,
typename PipelineMmaComputeS::PipelineState& pipeline_mma_compute_s_consumer_state,
PipelineMmaComputeDP& pipeline_mma_compute_dp,
typename PipelineMmaComputeDP::PipelineState& pipeline_mma_compute_dp_consumer_state,
PipelineComputeMmaP& pipeline_compute_mma_p,
typename PipelineComputeMmaP::PipelineState& pipeline_compute_mma_p_producer_state,
PipelineComputeMmaDS& pipeline_compute_mma_ds,
typename PipelineComputeMmaDS::PipelineState& pipeline_compute_mma_ds_producer_state,
PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv,
typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) {
auto [Q, K, D, D_VO, HB] = problem_shape;
// in tmem, S & P overlap
// and dP and dQ overlap
// there are two compute wg's that cooperatively compute softmax
// they are striped by this tmem atom, i.e. wg0 has 16 elems, then wg1 etc
auto load_op = SM100_TMEM_LOAD_16dp32b32x{};
Tensor tSTtST = partition_fragment_C(TiledMmaQK{}, select<0,1>(TileShapeQK{}))(make_coord(_,_),_0{},_0{});
tSTtST.data() = TmemAllocation::kS;
Tensor tDPTtDPT = partition_fragment_C(TiledMmaDOV{}, select<0,1>(TileShapeDOV{}))(make_coord(_,_),_0{},_0{});
tDPTtDPT.data() = TmemAllocation::kDP;
Tensor cST = make_identity_tensor(take<0,2>(TileShapeQK{}));
Tensor cDPT = make_identity_tensor(take<0,2>(TileShapeDOV{}));
Tensor cPT = make_identity_tensor(take<0,2>(TileShapeQK{}));
constexpr int kNumWarpgroups = kNumComputeWarps / 4;
int dp_idx = threadIdx.x % 128;
int wg_idx = (threadIdx.x % (kNumComputeWarps * NumThreadsPerWarp)) / 128;
auto tiled_t2r = make_tmem_copy(load_op, tSTtST);
auto thread_t2r = tiled_t2r.get_slice(dp_idx);
auto split_wg = [&](auto const& t) {
if constexpr (decltype(size<1>(t))::value > 1) {
if constexpr (decltype(rank(t))::value == 3) {
auto p = t.compose(make_layout(make_shape(size<0>(t), make_shape(Int<kNumWarpgroups>{}, size<1>(t) / Int<kNumWarpgroups>{}), size<2>(t))));
return p(_, make_coord(wg_idx, _), _);
}
else {
auto p = t.compose(make_layout(make_shape(size<0>(t), make_shape(Int<kNumWarpgroups>{}, size<1>(t) / Int<kNumWarpgroups>{}), size<2>(t), size<3>(t))));
return p(_, make_coord(wg_idx, _), _, _);
}
}
else {
if constexpr (decltype(rank(t))::value == 3) {
auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), make_shape(Int<kNumWarpgroups>{}, size<2>(t) / Int<kNumWarpgroups>{}))));
return p(_, _, make_coord(wg_idx, _));
}
else {
auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), size<2>(t), make_shape(Int<kNumWarpgroups>{}, size<3>(t) / Int<kNumWarpgroups>{}))));
return p(_, _, _, make_coord(wg_idx, _));
}
}
};
Tensor tTR_cST_p = thread_t2r.partition_D(cST);
Tensor tTR_cST = split_wg(tTR_cST_p);
Tensor tTR_rST = make_tensor<ElementAcc>(shape(tTR_cST));
Tensor tTR_tST = split_wg(thread_t2r.partition_S(tSTtST));
Tensor tTR_cDPT_p = thread_t2r.partition_D(cDPT);
Tensor tTR_cPT_p = thread_t2r.partition_D(cPT);
Tensor tTR_cDPT = split_wg(tTR_cDPT_p);
Tensor tTR_rDPT = make_tensor<ElementAcc>(shape(tTR_cDPT));
Tensor tTR_tDPT = split_wg(thread_t2r.partition_S(tDPTtDPT));
Tensor sLSE = make_tensor(make_smem_ptr(shared_tensors.smem_lse.begin()), SmemLayoutLSE{});
Tensor sSumOdO = make_tensor(make_smem_ptr(shared_tensors.smem_sum_odo.begin()), SmemLayoutSumOdO{});
bool is_residual_k = get<1>(blk_coord) * TileShapeK{} + TileShapeK{} >= get<1>(problem_shape);
int last_iter = iter_count - 1 + iter_index;
CUTLASS_PRAGMA_NO_UNROLL
while (iter_count > 0) {
// wait for S and P
pipeline_mma_compute_s.consumer_wait(pipeline_mma_compute_s_consumer_state);
pipeline_compute_mma_p.producer_acquire(pipeline_compute_mma_p_producer_state);
// wait for LSE
pipeline_load_compute_lse.consumer_wait(pipeline_load_compute_lse_consumer_state);
auto dispatch_bool = [](bool b, auto fn) {
if (b) {
fn(cute::true_type{});
}
else {
fn(cute::false_type{});
}
};
bool leading_causal_masking = false;
if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<true>, Mask>) {
leading_causal_masking = warp_uniform(iter_index == get<1>(blk_coord));
} else if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<false>, Mask>) {
int offset = get<1>(problem_shape) - get<0>(problem_shape);
int kv_left = get<1>(blk_coord) * TileShapeK{};
int kv_right = kv_left + TileShapeK{} - 1;
int q_left = iter_index * TileShapeQ{} + offset;
int q_right = q_left + TileShapeQ{} - 1;
leading_causal_masking = warp_uniform(!((q_left > kv_right) || (q_right < kv_left)));
}
bool trailing_residual_masking = false;
if constexpr (std::is_base_of_v<cutlass::fmha::collective::ResidualMaskForBackward, Mask>) {
trailing_residual_masking = warp_uniform((iter_index == last_iter) || is_residual_k);
}
dispatch_bool(leading_causal_masking || trailing_residual_masking, [&](auto is_masked_tile) {
// compute P = softmax(S, LSE)
cute::copy(tiled_t2r, tTR_tST, tTR_rST);
if constexpr (decltype(is_masked_tile)::value) {
Mask{}.apply_mask(tTR_rST, [&](int i) {
auto c_transpose = tTR_cST(i);
return make_coord(get<0>(c_transpose) + iter_index * TileShapeQ{}, get<1>(c_transpose) + get<1>(blk_coord) * TileShapeK{});
}, problem_shape);
}
ElementAcc log2_e = static_cast<ElementAcc>(M_LOG2E);
float2 softmax_scale_log2_e;
softmax_scale_log2_e.x = mainloop_args.softmax_scale * log2_e;
softmax_scale_log2_e.y = mainloop_args.softmax_scale * log2_e;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(tTR_rST); i += 2) {
float2 acc;
float2 lse;
float2 out;
acc.x = tTR_rST(i);
acc.y = tTR_rST(i + 1);
lse.x = sLSE(get<0>(tTR_cST(i)), pipeline_load_compute_lse_consumer_state.index());
lse.y = sLSE(get<0>(tTR_cST(i+1)), pipeline_load_compute_lse_consumer_state.index());
cute::fma(out, softmax_scale_log2_e, acc, lse);
tTR_rST(i) = ::exp2f(out.x);
tTR_rST(i+1) = ::exp2f(out.y);
}
auto tRT_rST = quantize(tTR_rST);
Tensor sP = make_tensor(make_smem_ptr((Element*) shared_tensors.smem_p.begin()), SmemLayoutP{})
(_, _, _, pipeline_compute_mma_p_producer_state.index());
cutlass::arch::fence_view_async_tmem_load();
cutlass::arch::NamedBarrier(
kNumComputeWarps * NumThreadsPerWarp,
cutlass::arch::ReservedNamedBarriers::TransformBarrier
).arrive_and_wait();
auto sP_pi = as_position_independent_swizzle_tensor(sP);
auto thread_layout = make_ordered_layout(
make_shape(_64{}, _32{}, _2{}, _2{}),
make_stride(_3{}, _0{}, _1{}, _2{})
);
auto sP_pi_slice_p = sP_pi.compose(thread_layout)(((dp_idx/32) * 16) + (dp_idx % 16) , _, (dp_idx % 32 / 16), _).compose(make_layout(shape(tTR_cPT_p)));
auto sP_pi_slice = split_wg(sP_pi_slice_p);
copy_aligned(tRT_rST, sP_pi_slice);
});
// notify for P
cutlass::arch::fence_view_async_shared();
pipeline_compute_mma_p.producer_commit(pipeline_compute_mma_p_producer_state);
++pipeline_compute_mma_p_producer_state;
// release S
pipeline_mma_compute_s.consumer_release(pipeline_mma_compute_s_consumer_state);
++pipeline_mma_compute_s_consumer_state;
// release LSE
pipeline_load_compute_lse.consumer_release(pipeline_load_compute_lse_consumer_state);
++pipeline_load_compute_lse_consumer_state;
// wait for OdO
pipeline_load_compute_sum_odo.consumer_wait(pipeline_load_compute_sum_odo_consumer_state);
// wait for dP
pipeline_mma_compute_dp.consumer_wait(pipeline_mma_compute_dp_consumer_state);
// wait for dS
// in principle, we could defer waiting for dS, and move in the freeing of dP
// however, that would force us to keep dS in registers longer
pipeline_compute_mma_ds.producer_acquire(pipeline_compute_mma_ds_producer_state);
// compute dS = dsoftmax(P, dP, sum_OdO)
cute::copy(tiled_t2r, tTR_tDPT, tTR_rDPT);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(tTR_rDPT); i += 2) {
float2 st;
st.x = tTR_rST(i);
st.y = tTR_rST(i+1);
float2 dpt;
dpt.x = tTR_rDPT(i);
dpt.y = tTR_rDPT(i+1);
float2 odo;
odo.x = sSumOdO(get<0>(tTR_cDPT(i)), pipeline_load_compute_sum_odo_consumer_state.index());
odo.y = sSumOdO(get<0>(tTR_cDPT(i+1)), pipeline_load_compute_sum_odo_consumer_state.index());
float2 dif;
// sum odo is negated during preprocess
cute::add(dif, dpt, odo);
float2 out;
cute::mul(out, dif, st);
tTR_rDPT(i) = out.x;
tTR_rDPT(i+1) = out.y;
}
auto tTR_rDST = quantize(tTR_rDPT);
// release dP
cutlass::arch::fence_view_async_tmem_load();
pipeline_mma_compute_dp.consumer_release(pipeline_mma_compute_dp_consumer_state);
++pipeline_mma_compute_dp_consumer_state;
Tensor sDS = make_tensor(make_smem_ptr((Element*) shared_tensors.smem_ds_t.begin()), SmemLayoutDST{})
(_, _, _, pipeline_compute_mma_ds_producer_state.index());
auto thread_layout = make_ordered_layout(
make_shape(_64{}, _32{}, _2{}, _2{}),
make_stride(_3{}, _0{}, _1{}, _2{})
);
auto sDS_pi = as_position_independent_swizzle_tensor(sDS);
auto sDS_pi_slice_p = sDS_pi.compose(thread_layout)(((dp_idx/32) * 16) + (dp_idx % 16) , _, (dp_idx % 32 / 16), _).compose(make_layout(shape (tTR_cDPT_p)));
auto sDS_pi_slice = split_wg(sDS_pi_slice_p);
copy_aligned(tTR_rDST, sDS_pi_slice);
// notify for dS
cutlass::arch::fence_view_async_shared();
pipeline_compute_mma_ds.producer_commit(pipeline_compute_mma_ds_producer_state);
++pipeline_compute_mma_ds_producer_state;
// release OdO
pipeline_load_compute_sum_odo.consumer_release(pipeline_load_compute_sum_odo_consumer_state);
++pipeline_load_compute_sum_odo_consumer_state;
iter_count -= 1;
iter_index += 1;
}
epilogue(
blk_coord, blk_offset, problem_shape, mainloop_args, epilogue_args,
pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_consumer_state
);
}
template<class BlkCoord, class ProblemShape_>
CUTLASS_DEVICE void reduce(
BlkCoord const& blk_coord,
ProblemShape_ const& problem_shape,
int iter_index,
int iter_count,
MainloopArguments const& mainloop_args,
MainloopParams const& mainloop_params,
TensorStorage& shared_tensors,
PipelineMmaReduceDQ& pipeline_mma_reduce_dq,
typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_consumer_state,
PipelineReduceTmaStore& pipeline_reduce_tma_store,
typename PipelineReduceTmaStore::PipelineState& pipeline_reduce_tma_store_producer_state) {
using X = Underscore;
auto [Q, K, D, D_VO, HB] = problem_shape;
auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord;
// must match TileShapeDQ
auto load_op = SM100_TMEM_LOAD_16dp32b16x{};
auto tDQtDQ = partition_fragment_C(TiledMmaDSK{}, select<0,1>(TileShapeDSK{}))(make_coord(_,_),_0{},_0{});
tDQtDQ.data() = TmemAllocation::kDQ;
Tensor mDQ = mainloop_params.tma_red_dq.get_tma_tensor(make_shape(Q, D, HB));
auto gDQ = local_tile(mDQ, TileShapeQK{}, make_coord(_,_,_), Step<_1, X, _1>{})
(_, _, _, _0{}, blk_coord_batch);
Tensor cDQ = make_identity_tensor(take<0,2>(TileShapeDSK{}));
Tensor sDQ = make_tensor(make_smem_ptr(shared_tensors.smem_dq.begin()), SmemLayoutDQ{});
int thread_idx = threadIdx.x % (kNumReduceWarps * NumThreadsPerWarp);
auto tiled_t2r = make_tmem_copy(load_op, tDQtDQ);
auto thread_t2r = tiled_t2r.get_slice(thread_idx);
Tensor tTR_cDQ = thread_t2r.partition_D(cDQ);
Tensor tTR_gDQ = thread_t2r.partition_D(gDQ);
Tensor tTR_sDQ = thread_t2r.partition_D(sDQ);
Tensor tTR_tDQ = thread_t2r.partition_S(tDQtDQ);
auto block_tma = mainloop_params.tma_red_dq.get_slice(_0{});
Tensor tDQsDQ = block_tma.partition_S(sDQ);
Tensor tDQcDQ = block_tma.partition_S(cDQ);
Tensor tDQgDQ = block_tma.partition_D(gDQ);
int lane_predicate = (threadIdx.x % (kNumReduceWarps * NumThreadsPerWarp)) == 0;
while (iter_count > 0) {
pipeline_mma_reduce_dq.consumer_wait(pipeline_mma_reduce_dq_consumer_state);
Tensor tTR_rDQ = make_tensor<ElementAcc>(shape(tTR_cDQ));
// load dQ from tmem to rmem
cute::copy(tiled_t2r, tTR_tDQ, tTR_rDQ);
cutlass::arch::fence_view_async_tmem_load();
pipeline_mma_reduce_dq.consumer_release(pipeline_mma_reduce_dq_consumer_state);
++pipeline_mma_reduce_dq_consumer_state;
// we don't have enough smem to dump it all to smem, so we do it in stages
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<2>(tTR_cDQ); i++) {
if (lane_predicate) {
pipeline_reduce_tma_store.producer_acquire(pipeline_reduce_tma_store_producer_state);
}
// wait in all threads for the acquire to complete
cutlass::arch::NamedBarrier(
kNumReduceWarps * NumThreadsPerWarp,
cutlass::arch::ReservedNamedBarriers::TransposeBarrier
).arrive_and_wait();
cute::copy(tTR_rDQ(_, _, i), tTR_sDQ(_, _, _0{}, pipeline_reduce_tma_store_producer_state.index()));
// wait for the stores to all be visible to the TMA
cutlass::arch::fence_view_async_shared();
cutlass::arch::NamedBarrier(
kNumReduceWarps * NumThreadsPerWarp,
cutlass::arch::ReservedNamedBarriers::TransposeBarrier
).arrive_and_wait();
if (lane_predicate) {
// launch tma store
copy(mainloop_params.tma_red_dq, tDQsDQ(_,_,_0{}, pipeline_reduce_tma_store_producer_state.index()), tDQgDQ(_,_,i,iter_index));
pipeline_reduce_tma_store.producer_commit(pipeline_reduce_tma_store_producer_state);
}
++pipeline_reduce_tma_store_producer_state;
}
iter_count -= 1;
iter_index += 1;
}
}
CUTLASS_DEVICE void operator()(Params const& params, char* smem) {
int warp_idx = cutlass::canonical_warp_idx_sync();
auto role = warp_idx_to_role(warp_idx);
uint32_t lane_predicate = cute::elect_one_sync();
if (role == WarpRole::Load && lane_predicate) {
prefetch_tma_descriptor(params.mainloop_params.tma_load_q.get_tma_descriptor());
prefetch_tma_descriptor(params.mainloop_params.tma_load_k.get_tma_descriptor());
prefetch_tma_descriptor(params.mainloop_params.tma_load_v.get_tma_descriptor());
prefetch_tma_descriptor(params.mainloop_params.tma_load_do.get_tma_descriptor());
}
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem);
int initializing_warp = 0;
typename PipelineLoadMmaQ::Params pipeline_load_mma_q_params;
if (role == WarpRole::Load) {
pipeline_load_mma_q_params.role = PipelineLoadMmaQ::ThreadCategory::Producer;
}
if (role == WarpRole::Mma) {
pipeline_load_mma_q_params.role = PipelineLoadMmaQ::ThreadCategory::Consumer;
}
pipeline_load_mma_q_params.is_leader = lane_predicate && (role == WarpRole::Load);
// Also loads K in the first iteration
pipeline_load_mma_q_params.transaction_bytes = kTransactionsBytesLoadQ;
pipeline_load_mma_q_params.initializing_warp = initializing_warp++;
PipelineLoadMmaQ pipeline_load_mma_q(shared_storage.pipelines.load_mma_q, pipeline_load_mma_q_params,
ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});
typename PipelineLoadMmaDO::Params pipeline_load_mma_do_params;
if (role == WarpRole::Load) {
pipeline_load_mma_do_params.role = PipelineLoadMmaDO::ThreadCategory::Producer;
}
if (role == WarpRole::Mma) {
pipeline_load_mma_do_params.role = PipelineLoadMmaDO::ThreadCategory::Consumer;
}
pipeline_load_mma_do_params.is_leader = lane_predicate && (role == WarpRole::Load);
// Also loads V in the first iteration
pipeline_load_mma_do_params.transaction_bytes = kTransactionsBytesLoadDO;
pipeline_load_mma_do_params.initializing_warp = initializing_warp++;
PipelineLoadMmaDO pipeline_load_mma_do(shared_storage.pipelines.load_mma_do, pipeline_load_mma_do_params,
ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});
typename PipelineLoadComputeLSE::Params pipeline_load_compute_lse_params;
if (role == WarpRole::Load) {
pipeline_load_compute_lse_params.role = PipelineLoadComputeLSE::ThreadCategory::Producer;
}
if (role == WarpRole::Compute) {
pipeline_load_compute_lse_params.role = PipelineLoadComputeLSE::ThreadCategory::Consumer;
}
pipeline_load_compute_lse_params.producer_arv_count = NumThreadsPerWarp;
pipeline_load_compute_lse_params.consumer_arv_count = kNumComputeWarps * NumThreadsPerWarp;
pipeline_load_compute_lse_params.initializing_warp = initializing_warp++;
PipelineLoadComputeLSE pipeline_load_compute_lse(
shared_storage.pipelines.load_compute_lse,
pipeline_load_compute_lse_params,
/*barrier init*/ cute::true_type{});
typename PipelineLoadComputeSumOdO::Params pipeline_load_compute_sum_odo_params;
if (role == WarpRole::Load) {
pipeline_load_compute_sum_odo_params.role = PipelineLoadComputeSumOdO::ThreadCategory::Producer;
}
if (role == WarpRole::Compute) {
pipeline_load_compute_sum_odo_params.role = PipelineLoadComputeSumOdO::ThreadCategory::Consumer;
}
pipeline_load_compute_sum_odo_params.producer_arv_count = NumThreadsPerWarp;
pipeline_load_compute_sum_odo_params.consumer_arv_count = kNumComputeWarps * NumThreadsPerWarp;
pipeline_load_compute_sum_odo_params.initializing_warp = initializing_warp++;
PipelineLoadComputeSumOdO pipeline_load_compute_sum_odo(
shared_storage.pipelines.load_compute_sum_odo,
pipeline_load_compute_sum_odo_params,
/*barrier init*/ cute::true_type{});
typename PipelineMmaComputeS::Params pipeline_mma_compute_s_params;
if (role == WarpRole::Mma) {
pipeline_mma_compute_s_params.role = PipelineMmaComputeS::ThreadCategory::Producer;
}
if (role == WarpRole::Compute) {
pipeline_mma_compute_s_params.role = PipelineMmaComputeS::ThreadCategory::Consumer;
}
pipeline_mma_compute_s_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp;
pipeline_mma_compute_s_params.initializing_warp = initializing_warp++;
PipelineMmaComputeS pipeline_mma_compute_s(
shared_storage.pipelines.mma_compute_s,
pipeline_mma_compute_s_params,
ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});
typename PipelineMmaComputeDP::Params pipeline_mma_compute_dp_params;
if (role == WarpRole::Mma) {
pipeline_mma_compute_dp_params.role = PipelineMmaComputeDP::ThreadCategory::Producer;
}
if (role == WarpRole::Compute) {
pipeline_mma_compute_dp_params.role = PipelineMmaComputeDP::ThreadCategory::Consumer;
}
pipeline_mma_compute_dp_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp;
pipeline_mma_compute_dp_params.initializing_warp = initializing_warp++;
PipelineMmaComputeDP pipeline_mma_compute_dp(
shared_storage.pipelines.mma_compute_dp,
pipeline_mma_compute_dp_params,
ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});
typename PipelineMmaReduceDQ::Params pipeline_mma_reduce_dq_params;
if (role == WarpRole::Mma) {
pipeline_mma_reduce_dq_params.role = PipelineMmaReduceDQ::ThreadCategory::Producer;
}
if (role == WarpRole::Reduce) {
pipeline_mma_reduce_dq_params.role = PipelineMmaReduceDQ::ThreadCategory::Consumer;
}
pipeline_mma_reduce_dq_params.consumer_arv_count = kNumReduceWarps * cutlass::NumThreadsPerWarp;
pipeline_mma_reduce_dq_params.initializing_warp = initializing_warp++;
PipelineMmaReduceDQ pipeline_mma_reduce_dq(
shared_storage.pipelines.mma_reduce_dq,
pipeline_mma_reduce_dq_params,
ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});
typename PipelineComputeMmaP::Params pipeline_compute_mma_p_params;
if (role == WarpRole::Mma) {
pipeline_compute_mma_p_params.role = PipelineComputeMmaP::ThreadCategory::Consumer;
}
if (role == WarpRole::Compute) {
pipeline_compute_mma_p_params.role = PipelineComputeMmaP::ThreadCategory::Producer;
}
pipeline_compute_mma_p_params.producer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp;
pipeline_compute_mma_p_params.consumer_arv_count = 1;
pipeline_compute_mma_p_params.initializing_warp = initializing_warp++;
PipelineComputeMmaP pipeline_compute_mma_p(
shared_storage.pipelines.compute_mma_p,
pipeline_compute_mma_p_params,
ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});
typename PipelineComputeMmaDS::Params pipeline_compute_mma_ds_params;
if (role == WarpRole::Mma) {
pipeline_compute_mma_ds_params.role = PipelineComputeMmaDS::ThreadCategory::Consumer;
}
if (role == WarpRole::Compute) {
pipeline_compute_mma_ds_params.role = PipelineComputeMmaDS::ThreadCategory::Producer;
}
pipeline_compute_mma_ds_params.producer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp;
pipeline_compute_mma_ds_params.consumer_arv_count = 1;
pipeline_compute_mma_ds_params.initializing_warp = initializing_warp++;
PipelineComputeMmaDS pipeline_compute_mma_ds(
shared_storage.pipelines.compute_mma_ds,
pipeline_compute_mma_ds_params,
ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});
typename PipelineMmaComputeDKDV::Params pipeline_mma_compute_dkdv_params;
if (role == WarpRole::Mma) {
pipeline_mma_compute_dkdv_params.role = PipelineMmaComputeDKDV::ThreadCategory::Producer;
}
if (role == WarpRole::Compute) {
pipeline_mma_compute_dkdv_params.role = PipelineMmaComputeDKDV::ThreadCategory::Consumer;
}
pipeline_mma_compute_dkdv_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp;
pipeline_mma_compute_dkdv_params.initializing_warp = initializing_warp++;
PipelineMmaComputeDKDV pipeline_mma_compute_dkdv(
shared_storage.pipelines.mma_compute_dkdv,
pipeline_mma_compute_dkdv_params,
ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});
PipelineReduceTmaStore pipeline_reduce_tma_store;
TmemAllocator tmem_allocator;
pipeline_init_arrive_relaxed(size(ClusterShape{}));
pipeline_load_mma_q.init_masks(ClusterShape{});
pipeline_load_mma_do.init_masks(ClusterShape{});
pipeline_mma_compute_s.init_masks(ClusterShape{});
pipeline_mma_compute_dp.init_masks(ClusterShape{});
pipeline_mma_reduce_dq.init_masks(ClusterShape{});
pipeline_compute_mma_p.init_masks(ClusterShape{});
pipeline_compute_mma_ds.init_masks(ClusterShape{});
pipeline_mma_compute_dkdv.init_masks(ClusterShape{});
typename decltype(pipeline_load_mma_q)::PipelineState pipeline_load_mma_q_consumer_state;
typename decltype(pipeline_load_mma_do)::PipelineState pipeline_load_mma_do_consumer_state;
typename decltype(pipeline_load_compute_lse)::PipelineState pipeline_load_compute_lse_consumer_state;
typename decltype(pipeline_load_compute_sum_odo)::PipelineState pipeline_load_compute_sum_odo_consumer_state;
typename decltype(pipeline_mma_compute_s)::PipelineState pipeline_mma_compute_s_consumer_state;
typename decltype(pipeline_mma_compute_dp)::PipelineState pipeline_mma_compute_dp_consumer_state;
typename decltype(pipeline_mma_reduce_dq)::PipelineState pipeline_mma_reduce_dq_consumer_state;
typename decltype(pipeline_compute_mma_p)::PipelineState pipeline_compute_mma_p_consumer_state;
typename decltype(pipeline_compute_mma_ds)::PipelineState pipeline_compute_mma_ds_consumer_state;
typename decltype(pipeline_mma_compute_dkdv)::PipelineState pipeline_mma_compute_dkdv_consumer_state;
auto pipeline_load_mma_q_producer_state = make_producer_start_state<decltype(pipeline_load_mma_q)>();
auto pipeline_load_mma_do_producer_state = make_producer_start_state<decltype(pipeline_load_mma_do)>();
auto pipeline_load_compute_lse_producer_state = make_producer_start_state<decltype(pipeline_load_compute_lse)>();
auto pipeline_load_compute_sum_odo_producer_state = make_producer_start_state<decltype(pipeline_load_compute_sum_odo)>();
auto pipeline_mma_compute_s_producer_state = make_producer_start_state<decltype(pipeline_mma_compute_s)>();
auto pipeline_mma_compute_dp_producer_state = make_producer_start_state<decltype(pipeline_mma_compute_dp)>();
auto pipeline_mma_reduce_dq_producer_state = make_producer_start_state<decltype(pipeline_mma_reduce_dq)>();
auto pipeline_compute_mma_p_producer_state = make_producer_start_state<decltype(pipeline_compute_mma_p)>();
auto pipeline_compute_mma_ds_producer_state = make_producer_start_state<decltype(pipeline_compute_mma_ds)>();
auto pipeline_mma_compute_dkdv_producer_state = make_producer_start_state<decltype(pipeline_mma_compute_dkdv)>();
auto pipeline_reduce_tma_store_producer_state = make_producer_start_state<decltype(pipeline_reduce_tma_store)>();
pipeline_init_wait(size(ClusterShape{}));
auto blk_coord = make_coord(_0{}, blockIdx.x, _0{}, _0{}, make_coord(blockIdx.y, blockIdx.z));
auto [problem_shape, blk_offset] = apply_variable_length_offset(
params.problem_shape,
blk_coord
);
int iter_count = ceil_div(get<0>(problem_shape), TileShapeQ{});
int iter_start = 0;
if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<true>, Mask>) {
iter_start = (get<1>(blk_coord) * TileShapeK{}) / TileShapeQ{};
} else if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<false>, Mask>) {
int offset = get<1>(problem_shape) - get<0>(problem_shape);
iter_start = max(0, (int(get<1>(blk_coord) * TileShapeK{}) - offset) / (int)TileShapeQ{});
}
if (get<1>(blk_coord) * TileShapeK{} >= get<1>(problem_shape)) {
return;
}
iter_count -= iter_start;
if (iter_count <= 0) {
epilogue_clear(
blk_coord,
blk_offset,
problem_shape,
params.mainloop,
params.epilogue
);
return;
}
if (role == WarpRole::Load) {
warpgroup_reg_set<RegisterAllocation::kLoad>();
load(
blk_coord,
blk_offset,
problem_shape,
iter_start,
iter_count,
params.mainloop,
params.mainloop_params,
shared_storage.tensors,
pipeline_load_mma_q, pipeline_load_mma_q_producer_state,
pipeline_load_mma_do, pipeline_load_mma_do_producer_state,
pipeline_load_compute_lse, pipeline_load_compute_lse_producer_state,
pipeline_load_compute_sum_odo, pipeline_load_compute_sum_odo_producer_state
);
}
else if (role == WarpRole::Mma) {
warpgroup_reg_set<RegisterAllocation::kMma>();
tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr);
__syncwarp();
mma(
blk_coord,
problem_shape,
iter_start,
iter_count,
params.mainloop,
shared_storage.tensors,
pipeline_load_mma_q, pipeline_load_mma_q_consumer_state,
pipeline_load_mma_do, pipeline_load_mma_do_consumer_state,
pipeline_mma_compute_s, pipeline_mma_compute_s_producer_state,
pipeline_mma_compute_dp, pipeline_mma_compute_dp_producer_state,
pipeline_mma_reduce_dq, pipeline_mma_reduce_dq_producer_state,
pipeline_compute_mma_p, pipeline_compute_mma_p_consumer_state,
pipeline_compute_mma_ds, pipeline_compute_mma_ds_consumer_state,
pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_producer_state
);
}
else if (role == WarpRole::Compute) {
warpgroup_reg_set<RegisterAllocation::kCompute>();
compute(
blk_coord,
blk_offset,
problem_shape,
iter_start,
iter_count,
params.mainloop,
params.epilogue,
shared_storage.tensors,
pipeline_load_compute_lse, pipeline_load_compute_lse_consumer_state,
pipeline_load_compute_sum_odo, pipeline_load_compute_sum_odo_consumer_state,
pipeline_mma_compute_s, pipeline_mma_compute_s_consumer_state,
pipeline_mma_compute_dp, pipeline_mma_compute_dp_consumer_state,
pipeline_compute_mma_p, pipeline_compute_mma_p_producer_state,
pipeline_compute_mma_ds, pipeline_compute_mma_ds_producer_state,
pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_consumer_state
);
cutlass::arch::NamedBarrier(
kNumComputeWarps * NumThreadsPerWarp,
cutlass::arch::ReservedNamedBarriers::EpilogueBarrier
).arrive_and_wait();
if (warp_idx % kNumComputeWarps == 0) {
uint32_t free_stage_ptr = shared_storage.tmem_base_ptr;
tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns);
}
}
else if (role == WarpRole::Reduce) {
warpgroup_reg_set<RegisterAllocation::kReduce>();
reduce(
blk_coord,
problem_shape,
iter_start,
iter_count,
params.mainloop,
params.mainloop_params,
shared_storage.tensors,
pipeline_mma_reduce_dq, pipeline_mma_reduce_dq_consumer_state,
pipeline_reduce_tma_store, pipeline_reduce_tma_store_producer_state
);
pipeline_reduce_tma_store.producer_tail(pipeline_reduce_tma_store_producer_state);
}
else {
warpgroup_reg_set<RegisterAllocation::kEmpty>();
/* no-op */
}
}
static dim3 get_block_shape() {
dim3 block(MaxThreadsPerBlock, 1, 1);
return block;
}
static dim3 get_grid_shape(Params const& params) {
auto [Q, K, D, D_VO, HB] = params.problem_shape;
auto [H, B] = HB;
dim3 grid(ceil_div(K, TileShapeK{}), H, B);
return grid;
}
};
} // namespace cutlass::fmha::kernel
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cute/layout.hpp"
#include "cutlass/arch/arch.h"
#include "cutlass/kernel_hardware_info.h"
#include "cutlass/pipeline/pipeline.hpp"
#include "cute/arch/tmem_allocator_sm100.hpp"
#include "kernel/fmha_options.hpp"
#include "kernel/fmha_tile_scheduler.hpp"
#include "kernel/fmha_causal_tile_scheduler.hpp"
#include "collective/fmha_fusion.hpp"
#include "collective/fmha_common.hpp"
namespace cutlass::fmha::kernel {
using namespace cute;
using namespace cutlass::fmha::collective;
struct Sm100FmhaCtxKernelWarpspecializedSchedule {
enum class WarpRole {
Softmax0,
Softmax1,
Correction,
MMA,
Load,
Epilogue,
Empty
};
static constexpr WarpRole warp_idx_to_WarpRole(int warp_idx) {
int wg_idx = warp_idx / 4; // warp_idx
if (wg_idx == 0) return WarpRole::Softmax0; // 0 - 3
if (wg_idx == 1) return WarpRole::Softmax1; // 4 - 7
if (wg_idx == 2) return WarpRole::Correction; // 8 - 11
if (warp_idx == 12) return WarpRole::MMA; // 12
if (warp_idx == 13) return WarpRole::Load; // 13
if (warp_idx == 14) return WarpRole::Epilogue; // 14
return WarpRole::Empty; // 15
}
static const int NumWarpsSoftmax = 4;
static const int NumWarpsCorrection = 4;
static const int NumWarpsEpilogue = 1;
static const int NumWarpsLoad = 1;
static const bool kDebugUsingPrintf = false;
static const int NumRegsSoftmax = 192;
static const int NumRegsCorrection = 96 - (kDebugUsingPrintf ? 16 : 0);
static const int NumRegsOther = 32 + (kDebugUsingPrintf ? 16 : 0);
static const int NumRegsEmpty = 24;
static const int NumWarps = 16;
};
struct Sm100MlaFwdCtxKernelWarpspecializedSchedule {
enum class WarpRole {
Softmax0,
Softmax1,
Correction,
MMA,
Load,
Epilogue,
Empty
};
static constexpr WarpRole warp_idx_to_WarpRole(int warp_idx) {
int wg_idx = warp_idx / 4; // warp_idx
if (wg_idx == 0) return WarpRole::Softmax0; // 0 - 3
if (wg_idx == 1) return WarpRole::Softmax1; // 4 - 7
if (wg_idx == 2) return WarpRole::Correction; // 8 - 11
if (warp_idx == 12) return WarpRole::MMA; // 12
if (warp_idx == 13) return WarpRole::Load; // 13
if (warp_idx == 14) return WarpRole::Epilogue; // 14
return WarpRole::Empty; // 15
}
static const int NumWarpsSoftmax = 4;
static const int NumWarpsCorrection = 4;
static const int NumWarpsEpilogue = 1;
static const int NumWarpsLoad = 1;
static const bool kDebugUsingPrintf = false;
static const int NumRegsSoftmax = 184;
static const int NumRegsCorrection = 96 - (kDebugUsingPrintf ? 16 : 0);
static const int NumRegsOther = 48 + (kDebugUsingPrintf ? 16 : 0);
static const int NumRegsEmpty = 24;
static const int NumWarps = 16;
};
template<
class ProblemShapeIn,
class CollectiveMainloop,
class CollectiveEpilogue,
class TileScheduler,
class KernelSchedule = Sm100FmhaCtxKernelWarpspecializedSchedule
>
struct Sm100FmhaFwdKernelTmaWarpspecialized {
using TileShape = typename CollectiveMainloop::TileShape;
using ProblemShape = ProblemShapeIn;
using WarpRole = typename KernelSchedule::WarpRole;
constexpr WarpRole warp_idx_to_WarpRole(int warp_idx) {
return KernelSchedule::warp_idx_to_WarpRole(warp_idx);
}
static const int NumWarpsSoftmax = KernelSchedule::NumWarpsSoftmax;
static const int NumWarpsCorrection = KernelSchedule::NumWarpsCorrection;
static const int NumWarpsEpilogue = KernelSchedule::NumWarpsEpilogue;
static const int NumWarpsLoad = KernelSchedule::NumWarpsLoad;
static_assert(NumWarpsEpilogue == CollectiveEpilogue::NumWarpsEpilogue);
static_assert(NumWarpsLoad == CollectiveEpilogue::NumWarpsLoad);
static const int NumRegsSoftmax = KernelSchedule::NumRegsSoftmax;
static const int NumRegsCorrection = KernelSchedule::NumRegsCorrection;
static const int NumRegsOther = KernelSchedule::NumRegsOther;
static const int NumRegsEmpty = 24;
static const int NumWarps = KernelSchedule::NumWarps;
static constexpr bool IsMla = std::is_same_v<KernelSchedule, Sm100MlaFwdCtxKernelWarpspecializedSchedule>;
using ClusterShape = typename CollectiveMainloop::ClusterShape;
using TmemAllocator = cute::TMEM::Allocator1Sm;
struct SharedStorage {
using UnionType = union {
typename CollectiveMainloop::TensorStorage mainloop;
typename CollectiveEpilogue::TensorStorage epilogue;
};
using StructType = struct {
typename CollectiveMainloop::TensorStorage mainloop;
typename CollectiveEpilogue::TensorStorage epilogue;
};
static constexpr bool IsPersistent = std::is_same_v<TileScheduler, PersistentTileScheduler> || std::is_same_v<TileScheduler, CausalPersistentTileScheduler>;
using MainloopEpilogueStorage = std::conditional_t<IsPersistent,
std::conditional_t<IsMla,
std::conditional_t<CollectiveMainloop::IsOrderLoadEpilogue, UnionType, StructType>,
StructType>,
UnionType>;
MainloopEpilogueStorage mainloop_epilogue;
struct PipelineStorage {
alignas(16) typename CollectiveMainloop::PipelineQ::SharedStorage load_q;
alignas(16) typename CollectiveMainloop::PipelineKV::SharedStorage load_kv;
alignas(16) typename CollectiveMainloop::PipelineS::SharedStorage mma_s0;
alignas(16) typename CollectiveMainloop::PipelineS::SharedStorage mma_s1;
alignas(16) typename CollectiveMainloop::PipelineC::SharedStorage s0_corr;
alignas(16) typename CollectiveMainloop::PipelineC::SharedStorage s1_corr;
alignas(16) typename CollectiveMainloop::PipelineO::SharedStorage mma_corr;
alignas(16) typename CollectiveMainloop::PipelineE::SharedStorage corr_epi;
alignas(16) typename CollectiveMainloop::OrderBarrierSoftmax::SharedStorage order_s01;
} pipelines;
uint32_t tmem_base_ptr;
};
static constexpr int SharedStorageSize = sizeof(SharedStorage);
struct Arguments {
ProblemShape problem_shape;
typename CollectiveMainloop::Arguments mainloop;
typename CollectiveEpilogue::Arguments epilogue;
cutlass::KernelHardwareInfo hw_info;
};
struct Params {
ProblemShape problem_shape;
typename CollectiveMainloop::Params mainloop;
typename CollectiveEpilogue::Params epilogue;
typename TileScheduler::Params tile_scheduler;
};
static const int MinBlocksPerMultiprocessor = 1;
static const int MaxThreadsPerBlock = NumWarps * cutlass::NumThreadsPerWarp;
using ArchTag = cutlass::arch::Sm100;
static size_t get_workspace_size(Arguments const& args) { return 0; }
static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) {
return cutlass::Status::kSuccess;
}
static bool can_implement(Arguments const& args) {
return CollectiveMainloop::can_implement(args.problem_shape, args.mainloop);
}
static dim3 get_grid_shape(Params const& params) {
return TileScheduler::get_grid_shape(params.tile_scheduler);
}
static dim3 get_block_shape() {
dim3 block(MaxThreadsPerBlock, 1, 1);
return block;
}
static Params to_underlying_arguments(Arguments const& args, void* workspace) {
return Params{
args.problem_shape,
CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace),
CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace),
TileScheduler::to_underlying_arguments(args.problem_shape, args.hw_info, ClusterShape{}, TileShape{})
};
}
CUTLASS_DEVICE auto apply_batch(const Params &params, ProblemShape const& problem_shape, int batch_idx) {
return apply_variable_length(params.problem_shape, batch_idx);
}
CUTLASS_DEVICE void operator()(const Params &params, char* smem) {
TileScheduler tile_scheduler{params.tile_scheduler};
int warp_idx = cutlass::canonical_warp_idx_sync();
auto role = warp_idx_to_WarpRole(warp_idx);
uint32_t lane_predicate = cute::elect_one_sync();
if (role == WarpRole::Load && lane_predicate) {
CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
}
if (role == WarpRole::Epilogue && lane_predicate) {
CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue);
}
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem);
auto get_epilogue_storage = [&]() {
if constexpr (IsMla && CollectiveMainloop::IsOrderLoadEpilogue) {
return reinterpret_cast<typename CollectiveEpilogue::TensorStorage *>(shared_storage.mainloop_epilogue.mainloop.smem_o.data());
} else {
return &shared_storage.mainloop_epilogue.epilogue;
}
};
typename CollectiveEpilogue::TensorStorage & epilogue_storage = *get_epilogue_storage();
typename CollectiveMainloop::PipelineQ::Params pipeline_load_q_params;
if (role == WarpRole::Load) {
pipeline_load_q_params.role = CollectiveMainloop::PipelineQ::ThreadCategory::Producer;
}
if (role == WarpRole::MMA) {
pipeline_load_q_params.role = CollectiveMainloop::PipelineQ::ThreadCategory::Consumer;
}
pipeline_load_q_params.is_leader = lane_predicate && (role == WarpRole::Load);
pipeline_load_q_params.transaction_bytes = CollectiveMainloop::TransactionBytesLoadQ;
typename CollectiveMainloop::PipelineQ pipeline_load_q(
shared_storage.pipelines.load_q,
pipeline_load_q_params,
ClusterShape{}, cute::true_type{}, /*mask calc*/cute::false_type{});
typename CollectiveMainloop::PipelineKV::Params pipeline_load_kv_params;
if (role == WarpRole::Load) {
pipeline_load_kv_params.role = CollectiveMainloop::PipelineKV::ThreadCategory::Producer;
}
if (role == WarpRole::MMA) {
pipeline_load_kv_params.role = CollectiveMainloop::PipelineKV::ThreadCategory::Consumer;
}
pipeline_load_kv_params.is_leader = lane_predicate && (role == WarpRole::Load);
pipeline_load_kv_params.transaction_bytes = CollectiveMainloop::TransactionBytesLoadK;
typename CollectiveMainloop::PipelineKV pipeline_load_kv(
shared_storage.pipelines.load_kv,
pipeline_load_kv_params,
ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});
typename CollectiveMainloop::PipelineS::Params pipeline_mma_s0_params;
if (role == WarpRole::MMA) {
pipeline_mma_s0_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Producer;
}
if (role == WarpRole::Softmax0) {
pipeline_mma_s0_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Consumer;
}
pipeline_mma_s0_params.consumer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp;
typename CollectiveMainloop::PipelineS pipeline_mma_s0(
shared_storage.pipelines.mma_s0,
pipeline_mma_s0_params,
ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});
typename CollectiveMainloop::PipelineS::Params pipeline_mma_s1_params;
if (role == WarpRole::MMA) {
pipeline_mma_s1_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Producer;
}
if (role == WarpRole::Softmax1) {
pipeline_mma_s1_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Consumer;
}
pipeline_mma_s1_params.consumer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp;
typename CollectiveMainloop::PipelineS pipeline_mma_s1(
shared_storage.pipelines.mma_s1,
pipeline_mma_s1_params,
ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});
typename CollectiveMainloop::PipelineC::Params pipeline_s0_corr_params;
if (role == WarpRole::Softmax0) {
pipeline_s0_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Producer;
}
if (role == WarpRole::Correction) {
pipeline_s0_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Consumer;
}
pipeline_s0_corr_params.producer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp;
pipeline_s0_corr_params.consumer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp;
typename CollectiveMainloop::PipelineC pipeline_s0_corr(
shared_storage.pipelines.s0_corr,
pipeline_s0_corr_params,
/*barrier init*/ cute::true_type{});
typename CollectiveMainloop::PipelineC::Params pipeline_s1_corr_params;
if (role == WarpRole::Softmax1) {
pipeline_s1_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Producer;
}
if (role == WarpRole::Correction) {
pipeline_s1_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Consumer;
}
pipeline_s1_corr_params.producer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp;
pipeline_s1_corr_params.consumer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp;
typename CollectiveMainloop::PipelineC pipeline_s1_corr(
shared_storage.pipelines.s1_corr,
pipeline_s1_corr_params,
/*barrier init*/ cute::true_type{});
typename CollectiveMainloop::PipelineO::Params pipeline_mma_corr_params;
if (role == WarpRole::MMA) {
pipeline_mma_corr_params.role = CollectiveMainloop::PipelineO::ThreadCategory::Producer;
}
if (role == WarpRole::Correction) {
pipeline_mma_corr_params.role = CollectiveMainloop::PipelineO::ThreadCategory::Consumer;
}
pipeline_mma_corr_params.consumer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp;
typename CollectiveMainloop::PipelineO pipeline_mma_corr(
shared_storage.pipelines.mma_corr,
pipeline_mma_corr_params,
ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});
typename CollectiveMainloop::PipelineE::Params pipeline_corr_epi_params;
if (role == WarpRole::Correction) {
pipeline_corr_epi_params.role = CollectiveMainloop::PipelineE::ThreadCategory::Producer;
}
if (role == WarpRole::Epilogue) {
pipeline_corr_epi_params.role = CollectiveMainloop::PipelineE::ThreadCategory::Consumer;
}
pipeline_corr_epi_params.producer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp;
pipeline_corr_epi_params.consumer_arv_count = NumWarpsEpilogue * cutlass::NumThreadsPerWarp;
typename CollectiveMainloop::PipelineE pipeline_corr_epi(
shared_storage.pipelines.corr_epi,
pipeline_corr_epi_params,
/*barrier init*/ cute::true_type{});
typename CollectiveMainloop::OrderBarrierSoftmax::Params params_order_s01;
params_order_s01.group_id = role == WarpRole::Softmax1 ? 1 : 0;
params_order_s01.group_size = NumWarpsSoftmax * cutlass::NumThreadsPerWarp;
typename CollectiveMainloop::OrderBarrierSoftmax order_s01(
shared_storage.pipelines.order_s01, params_order_s01);
TmemAllocator tmem_allocator;
__syncthreads();
pipeline_load_q.init_masks(ClusterShape{});
pipeline_load_kv.init_masks(ClusterShape{});
pipeline_mma_s0.init_masks(ClusterShape{});
pipeline_mma_s1.init_masks(ClusterShape{});
pipeline_mma_corr.init_masks(ClusterShape{});
typename CollectiveMainloop::PipelineQ::PipelineState pipeline_load_q_consumer_state;
typename CollectiveMainloop::PipelineQ::PipelineState pipeline_load_q_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineQ>();
typename CollectiveMainloop::PipelineKV::PipelineState pipeline_load_kv_consumer_state;
typename CollectiveMainloop::PipelineKV::PipelineState pipeline_load_kv_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineKV>();
typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s0_consumer_state;
typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s0_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineS>();
typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s1_consumer_state;
typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s1_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineS>();
typename CollectiveMainloop::PipelineC::PipelineState pipeline_s0_corr_consumer_state;
typename CollectiveMainloop::PipelineC::PipelineState pipeline_s0_corr_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineC>();
typename CollectiveMainloop::PipelineC::PipelineState pipeline_s1_corr_consumer_state;
typename CollectiveMainloop::PipelineC::PipelineState pipeline_s1_corr_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineC>();
typename CollectiveMainloop::PipelineE::PipelineState pipeline_corr_epi_consumer_state;
typename CollectiveMainloop::PipelineE::PipelineState pipeline_corr_epi_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineE>();
typename CollectiveMainloop::PipelineO::PipelineState pipeline_mma_corr_consumer_state;
typename CollectiveMainloop::PipelineO::PipelineState pipeline_mma_corr_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineO>();
CollectiveMainloop mainloop;
CollectiveEpilogue epilogue{params.epilogue};
if (role == WarpRole::Softmax0 || role == WarpRole::Softmax1) {
warpgroup_reg_set<NumRegsSoftmax>();
CUTLASS_PRAGMA_NO_UNROLL
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
auto blk_coord = tile_scheduler.get_block_coord();
auto logical_problem_shape = apply_batch(params,
params.problem_shape, get<2,1>(blk_coord));
if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) {
continue;
}
if (get<1>(logical_problem_shape) == 0) {
continue;
}
bool is_softmax_0 = role == WarpRole::Softmax0;
mainloop.softmax(
is_softmax_0 ? 0 : 1, blk_coord,
params.mainloop, logical_problem_shape,
is_softmax_0 ? pipeline_mma_s0 : pipeline_mma_s1,
is_softmax_0 ? pipeline_mma_s0_consumer_state : pipeline_mma_s1_consumer_state,
is_softmax_0 ? pipeline_s0_corr : pipeline_s1_corr,
is_softmax_0 ? pipeline_s0_corr_producer_state : pipeline_s1_corr_producer_state,
order_s01
);
}
}
else if (role == WarpRole::Correction) {
cutlass::arch::warpgroup_reg_dealloc<NumRegsCorrection>();
CUTLASS_PRAGMA_NO_UNROLL
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
auto blk_coord = tile_scheduler.get_block_coord();
auto logical_problem_shape = apply_batch(params,
params.problem_shape, get<2,1>(blk_coord));
if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) {
continue;
}
if (get<1>(logical_problem_shape) == 0) {
mainloop.correction_empty(
blk_coord,
params.mainloop, logical_problem_shape,
params.problem_shape,
epilogue_storage,
pipeline_corr_epi, pipeline_corr_epi_producer_state,
epilogue
);
continue;
}
mainloop.correction(
blk_coord,
params.mainloop, logical_problem_shape,
params.problem_shape,
epilogue_storage,
pipeline_s0_corr, pipeline_s0_corr_consumer_state,
pipeline_s1_corr, pipeline_s1_corr_consumer_state,
pipeline_mma_corr, pipeline_mma_corr_consumer_state,
pipeline_corr_epi, pipeline_corr_epi_producer_state,
epilogue
);
}
if constexpr (NumWarpsEpilogue == 0) {
static_assert(NumWarpsCorrection == 1);
uint32_t free_stage_ptr = shared_storage.tmem_base_ptr;
tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns);
}
}
else if (role == WarpRole::MMA) {
warpgroup_reg_set<NumRegsOther>();
tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr);
__syncwarp();
CUTLASS_PRAGMA_NO_UNROLL
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
auto blk_coord = tile_scheduler.get_block_coord();
auto logical_problem_shape = apply_batch(params,
params.problem_shape, get<2,1>(blk_coord));
if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) {
continue;
}
if (get<1>(logical_problem_shape) == 0) {
continue;
}
mainloop.mma(
blk_coord,
params.mainloop, logical_problem_shape,
shared_storage.mainloop_epilogue.mainloop,
pipeline_load_q, pipeline_load_q_consumer_state,
pipeline_load_kv, pipeline_load_kv_consumer_state,
pipeline_mma_s0, pipeline_mma_s0_producer_state,
pipeline_mma_s1, pipeline_mma_s1_producer_state,
pipeline_mma_corr, pipeline_mma_corr_producer_state
);
}
}
else if (role == WarpRole::Load) {
warpgroup_reg_set<NumRegsOther>();
if constexpr (IsMla && CollectiveMainloop::IsOrderLoadEpilogue) {
cutlass::arch::NamedBarrier::arrive((NumWarpsLoad + NumWarpsEpilogue) * NumThreadsPerWarp,
cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
}
CUTLASS_PRAGMA_NO_UNROLL
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
auto blk_coord = tile_scheduler.get_block_coord();
auto logical_problem_shape = apply_batch(params,
params.problem_shape, get<2,1>(blk_coord));
if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) {
continue;
}
if (get<1>(logical_problem_shape) == 0) {
continue;
}
mainloop.load(
blk_coord, logical_problem_shape,
params.mainloop, params.problem_shape,
shared_storage.mainloop_epilogue.mainloop,
pipeline_load_q, pipeline_load_q_producer_state,
pipeline_load_kv, pipeline_load_kv_producer_state
);
}
}
else if (role == WarpRole::Epilogue) {
warpgroup_reg_set<NumRegsOther>();
CUTLASS_PRAGMA_NO_UNROLL
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
auto blk_coord = tile_scheduler.get_block_coord();
auto logical_problem_shape = apply_batch(params,
params.problem_shape, get<2,1>(blk_coord));
if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) {
continue;
}
epilogue.store(
blk_coord, logical_problem_shape,
params.epilogue, params.problem_shape,
epilogue_storage,
pipeline_corr_epi, pipeline_corr_epi_consumer_state
);
}
static_assert(NumWarpsEpilogue <= 1);
if constexpr (NumWarpsEpilogue == 1) {
uint32_t free_stage_ptr = shared_storage.tmem_base_ptr;
tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns);
}
}
else if (role == WarpRole::Empty) {
warpgroup_reg_set<NumRegsEmpty>();
/* no-op, donate regs and exit */
}
}
};
} // namespace cutlass::fmha::kernel
#include <torch/python.h>
void FMHACutlassSM100FwdRun(at::Tensor workspace_buffer, at::Tensor q, at::Tensor k, at::Tensor v,
at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv,
at::Tensor o, at::Tensor lse,
int mask_mode_code, float softmax_scale, int max_seqlen_q, int max_seqlen_kv, bool is_varlen);
void FMHACutlassSM100BwdRun(at::Tensor workspace_buffer, at::Tensor d_o, at::Tensor q, at::Tensor k,
at::Tensor v, at::Tensor o, at::Tensor lse,
at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv,
at::Tensor dq, at::Tensor dk, at::Tensor dv,
int mask_mode_code, float softmax_scale, int max_seqlen_q, int max_seqlen_kv, bool is_varlen);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("fwd", &FMHACutlassSM100FwdRun);
m.def("bwd", &FMHACutlassSM100BwdRun);
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment