Unverified Commit 41b611f7 authored by Zeyu WANG's avatar Zeyu WANG Committed by GitHub
Browse files

Add more GPU architectures support (#76)



* Add more GPU architectures support

* Merge fmha and mla runner

* add varlen & non-varlen support, and add non-contiguous tensor support

* update readme

* add varlen api

---------
Co-authored-by: default avatardianzhangc <dianzhangc@nvidia.com>
parent 9edee0c0
...@@ -28,13 +28,21 @@ Currently released: ...@@ -28,13 +28,21 @@ Currently released:
### Install ### Install
```bash ```bash
python setup.py install pip install -v .
``` ```
### Benchmark ### Benchmark
#### Testing MLA Decoding
```bash
python tests/test_flash_mla_sm90.py
```
#### Testing MLA Forward/Backward
```bash ```bash
python tests/test_flash_mla.py python tests/test_fmha_sm100.py
``` ```
It achieves up to 3000 GB/s in a memory-bound configuration and 660 TFLOPS in a compute-bound configuration on H800 SXM5, using CUDA 12.8. It achieves up to 3000 GB/s in a memory-bound configuration and 660 TFLOPS in a compute-bound configuration on H800 SXM5, using CUDA 12.8.
......
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/kernel_hardware_info.h"
#include "cutlass/arch/reg_reconfig.h"
#include "cute/tensor.hpp"
namespace cutlass::fmha::collective {
using namespace cute;
// Issues one cute::gemm per k-block (mode 2) of the rank-3 operands tA / tB,
// writing into tC. The atom's accumulate_ flag is used as-is for the first
// k-block and is set to One after every k-block.
template<typename Atom, typename TA, typename TB, typename TC>
CUTE_DEVICE void gemm_reset_zero_acc(Atom& atom, TA const& tA, TB const& tB, TC&& tC) {
  // All three operands must be rank-3 tensors.
  static_assert(decltype(rank(tA))::value == 3);
  static_assert(decltype(rank(tB))::value == 3);
  static_assert(decltype(rank(tC))::value == 3);

  using ScaleOut = decltype(atom.accumulate_);

  CUTLASS_PRAGMA_UNROLL
  for (int k = 0; k < size<2>(tA); ++k) {
    cute::gemm(atom, tA(_,_,k), tB(_,_,k), tC);
    atom.accumulate_ = ScaleOut::One;
  }
}
// Sets the atom's accumulate_ flag to Zero and then runs the k-block loop of
// gemm_reset_zero_acc on the same operands.
template<typename Atom, typename TA, typename TB, typename TC>
CUTE_DEVICE void gemm_zero_acc(Atom& atom, TA const& tA, TB const& tB, TC&& tC) {
  using ScaleOut = decltype(atom.accumulate_);
  atom.accumulate_ = ScaleOut::Zero;
  gemm_reset_zero_acc(atom, tA, tB, tC);
}
// Composes `layout` with a tiler that is `_` in every leading mode and
// make_layout(stages) in the last one. With the default Stages = _1 the last
// mode is restricted to a single element.
// NOTE(review): presumably used to re-stage a collective's smem layout to a
// different stage count -- confirm against callers.
template<class Layout, class Stages = _1>
CUTE_DEVICE constexpr auto unstageSmemLayout(Layout const& layout, Stages stages = {}) {
  constexpr int kRank = decltype(rank(layout))::value;
  auto stage_mode = make_layout(stages);
  return composition(layout, prepend<kRank>(stage_mode, _));
}
// Returns lane 0's copy of `a` to every lane named in the full-warp mask.
// All 32 lanes must participate in the call (mask is 0xffffffff).
template<class T>
CUTE_DEVICE T warp_uniform(T a) {
  constexpr unsigned kFullWarpMask = 0xffffffff;
  constexpr int kSourceLane = 0;
  return __shfl_sync(kFullWarpMask, a, kSourceLane);
}
// Type-level conversion: maps a TiledMMA built on the SM100_MMA_F8F6F4_SS atom
// to the TiledMMA built on SM100_MMA_F8F6F4_TS with the same element types,
// M/N, majors and negate flags, and UMMA::Saturate::False. Only the argument's
// type is used; the remaining atom (TAs...) and tiled-MMA (TMs...) parameters
// are forwarded unchanged.
template <class a_type, class b_type, class c_type,
          int M, int N, UMMA::Major a_major, UMMA::Major b_major,
          UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg, class... TAs, class... TMs>
CUTE_HOST_DEVICE constexpr
auto
to_tiled_mma_sm100_ts(
    TiledMMA<MMA_Atom<
      MMA_Traits<SM100_MMA_F8F6F4_SS, a_type, b_type, c_type,
                 cute::C<M>, cute::C<N>,
                 cute::integral_constant<UMMA::Major, a_major>,
                 cute::integral_constant<UMMA::Major, b_major>,
                 cute::integral_constant<UMMA::ScaleIn, a_neg>,
                 cute::integral_constant<UMMA::ScaleIn, b_neg>>,
      TAs...>, TMs...>) {
  using OpTS = SM100_MMA_F8F6F4_TS<a_type, b_type, c_type,
                                   M, N,
                                   a_major, b_major,
                                   a_neg, b_neg, UMMA::Saturate::False>;
  using AtomTS = MMA_Atom<MMA_Traits<OpTS>, TAs...>;
  return TiledMMA<AtomTS, TMs...>{};
}
// Type-level conversion: maps a TiledMMA built on SM100_MMA_F16BF16_SS to the
// TiledMMA built on SM100_MMA_F16BF16_TS with the same element types, M/N,
// majors and negate flags, and UMMA::Saturate::False. Only the argument's type
// is used; TAs... and TMs... are forwarded unchanged.
template <class a_type, class b_type, class c_type,
          int M, int N, UMMA::Major a_major, UMMA::Major b_major,
          UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg, class... TAs, class... TMs>
CUTE_HOST_DEVICE constexpr
auto
to_tiled_mma_sm100_ts(
    TiledMMA<MMA_Atom<
      SM100_MMA_F16BF16_SS<a_type, b_type, c_type,
                           M, N,
                           a_major,
                           b_major,
                           a_neg,
                           b_neg>,
      TAs...>, TMs...>) {
  using OpTS = SM100_MMA_F16BF16_TS<a_type, b_type, c_type,
                                    M, N,
                                    a_major, b_major,
                                    a_neg, b_neg, UMMA::Saturate::False>;
  using AtomTS = MMA_Atom<OpTS, TAs...>;
  return TiledMMA<AtomTS, TMs...>{};
}
// Compile-time dispatch between the two warpgroup register reconfiguration
// calls: RegCount >= 128 requests warpgroup_reg_alloc<RegCount>, anything
// smaller requests warpgroup_reg_dealloc<RegCount>.
template<uint32_t RegCount>
CUTLASS_DEVICE
void warpgroup_reg_set() {
  constexpr bool kUseAlloc = (RegCount >= 128);
  if constexpr (kUseAlloc) {
    cutlass::arch::warpgroup_reg_alloc<RegCount>();
  }
  else {
    cutlass::arch::warpgroup_reg_dealloc<RegCount>();
  }
}
} // namespace cutlass::fmha::collective
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
namespace cutlass::fmha::collective {
using namespace cute;
// Mask policy that masks nothing: every tile along mode 1 of the problem is
// visited and none of them needs apply_mask.
struct NoMask {

  // Number of tiles along mode 1: ceil(problem_size[1] / tile_shape[1]).
  template<class BlkCoord, class TileShape, class ProblemSize>
  CUTLASS_DEVICE
  int get_trip_count(
      BlkCoord const& blk_coord,
      TileShape const& tile_shape,
      ProblemSize const& problem_size) {

    auto const extent_k = get<1>(problem_size);
    auto const tile_k = get<1>(tile_shape);
    return ceil_div(extent_k, tile_k);
  }

  // No tile requires masking.
  template<class BlkCoord, class TileShape, class ProblemSize>
  CUTLASS_DEVICE
  int get_masked_trip_count(
      BlkCoord const& blk_coord,
      TileShape const& tile_shape,
      ProblemSize const& problem_size) {

    return 0;
  }

  // Every tile counted by get_trip_count is unmasked.
  template<class BlkCoord, class TileShape, class ProblemSize>
  CUTLASS_DEVICE
  int get_unmasked_trip_count(
      BlkCoord const& blk_coord,
      TileShape const& tile_shape,
      ProblemSize const& problem_size) {

    int const total_tiles = get_trip_count(blk_coord, tile_shape, problem_size);
    return total_tiles;
  }

  // Intentionally a no-op: acc_qk is left untouched.
  template<class AccQK, class IndexQK, class ProblemSize>
  CUTLASS_DEVICE
  void apply_mask(
      AccQK& acc_qk,
      IndexQK const& index_qk,
      ProblemSize const& problem_size) {
  }
};
// Mask policy for a problem whose mode-1 extent is not a multiple of the tile:
// only the last (partial) tile is masked, and only on the mode-1 coordinate.
struct ResidualMask : NoMask {

  using Base = NoMask;

  // 1 when problem_size[1] leaves a partial tile, otherwise 0.
  template <class BlkCoord, class TileShape, class ProblemSize>
  CUTLASS_DEVICE int get_masked_trip_count(
      BlkCoord const& blk_coord,
      TileShape const& tile_shape,
      ProblemSize const& problem_size) {

    bool const has_partial_tile = (get<1>(problem_size) % get<1>(tile_shape) != 0);
    return has_partial_tile ? 1 : 0;
  }

  // All tiles except the partial one (if there is one).
  template<class BlkCoord, class TileShape, class ProblemSize>
  CUTLASS_DEVICE
  int get_unmasked_trip_count(
      BlkCoord const& blk_coord,
      TileShape const& tile_shape,
      ProblemSize const& problem_size) {

    int const total_tiles = get_trip_count(blk_coord, tile_shape, problem_size);
    // the sequence length does not divide the tile size evenly -> drop the tail tile
    bool const has_partial_tile = (get<1>(problem_size) % get<1>(tile_shape) != 0);
    return has_partial_tile ? total_tiles - 1 : total_tiles;
  }

  // Writes -INFINITY into every accumulator element whose mode-1 coordinate
  // (taken from index_qk) is >= problem_size[1].
  template<class AccQK, class IndexQK, class ProblemSize>
  CUTLASS_DEVICE
  void apply_mask(
      AccQK& acc_qk,
      IndexQK const& index_qk,
      ProblemSize const& problem_size) {

    // This is useful if seqlen_k % kBlockN != 0 since it masks
    // the remaining elements out from softmax.
    // d % kHeadDim != 0 or seqlen_q % kBlockM do not suffer from similar
    // issues as they are transparently taken care of by TMA and the
    // epilogue, if it is instantiated with predication support.
    CUTLASS_PRAGMA_UNROLL
    for (int elem = 0; elem < size(acc_qk); ++elem) {
      auto coord = index_qk(elem);
      bool const past_end_k = (get<1>(coord) >= get<1>(problem_size));
      if (past_end_k) {
        acc_qk(elem) = -INFINITY;
      }
    }
  }
};
// Residual mask variant used by the backward pass: trip counts match
// ResidualMask, but apply_mask bounds-checks both coordinates of each element
// against modes 0 and 1 of the problem size.
struct ResidualMaskForBackward : NoMask {

  using Base = NoMask;

  // 1 when problem_size[1] leaves a partial tile, otherwise 0.
  template <class BlkCoord, class TileShape, class ProblemSize>
  CUTLASS_DEVICE int get_masked_trip_count(
      BlkCoord const& blk_coord,
      TileShape const& tile_shape,
      ProblemSize const& problem_size) {

    bool const has_partial_tile = (get<1>(problem_size) % get<1>(tile_shape) != 0);
    return has_partial_tile ? 1 : 0;
  }

  // All tiles except the partial one (if there is one).
  template<class BlkCoord, class TileShape, class ProblemSize>
  CUTLASS_DEVICE
  int get_unmasked_trip_count(
      BlkCoord const& blk_coord,
      TileShape const& tile_shape,
      ProblemSize const& problem_size) {

    int const total_tiles = get_trip_count(blk_coord, tile_shape, problem_size);
    // the sequence length does not divide the tile size evenly -> drop the tail tile
    bool const has_partial_tile = (get<1>(problem_size) % get<1>(tile_shape) != 0);
    return has_partial_tile ? total_tiles - 1 : total_tiles;
  }

  // Writes -INFINITY into every accumulator element whose coordinate is not
  // elem_less than (problem_size[0], problem_size[1]).
  template<class AccQK, class IndexQK, class ProblemSize>
  CUTLASS_DEVICE
  void apply_mask(
      AccQK& acc_qk,
      IndexQK const& index_qk,
      ProblemSize const& problem_size) {

    // This is useful if seqlen_k % kBlockN != 0 since it masks
    // the remaining elements out from softmax.
    // d % kHeadDim != 0 or seqlen_q % kBlockM do not suffer from similar
    // issues as they are transparently taken care of by TMA and the
    // epilogue, if it is instantiated with predication support.
    CUTLASS_PRAGMA_UNROLL
    for (int elem = 0; elem < size(acc_qk); ++elem) {
      auto coord = index_qk(elem);
      bool const in_bounds = elem_less(coord, select<0,1>(problem_size));
      if (!in_bounds) {
        acc_qk(elem) = -INFINITY;
      }
    }
  }
};
// There are two ways to do causal if N_Q != N_K
// (1) The Q is at the beginning of the matrix
// (2) The Q is at the end of the matrix
// Causal mask policy. problem_size[0] is the Q extent (N_Q) and problem_size[1]
// the K extent (N_K). kIsQBegin selects where the Q rows sit relative to K when
// N_Q != N_K:
//   kIsQBegin == true  : row q may attend to columns k <= q
//   kIsQBegin == false : row q may attend to columns k <= q + (N_K - N_Q)
// Columns k >= N_K are always masked.
template<bool kIsQBegin = true>
struct CausalMask : NoMask {

  using Base = NoMask;

  static constexpr bool IsQBegin = kIsQBegin;

  // Number of K tiles that the Q tile blk_coord[0] has to visit: the tiles up
  // to and including the one holding the diagonal of its last row, clamped to
  // the total number of K tiles.
  template<class BlkCoord, class TileShape, class ProblemSize>
  CUTLASS_DEVICE
  int get_trip_count(
      BlkCoord const& blk_coord,
      TileShape const& tile_shape,
      ProblemSize const& problem_size) {

    // See note below on different ways to think about causal attention
    // Again, we'd add the offset_q into the max_blocks_q calculation
    int max_blocks_k = Base::get_trip_count(blk_coord, tile_shape, problem_size);
    if constexpr (IsQBegin) {
      int max_blocks_q = ceil_div((get<0>(blk_coord) + 1) * get<0>(tile_shape), get<1>(tile_shape));
      return std::min(max_blocks_k, max_blocks_q);
    } else {
      const int offset_q = get<1>(problem_size) - get<0>(problem_size);
      int max_blocks_q = ceil_div((get<0>(blk_coord) + 1) * get<0>(tile_shape) + offset_q, get<1>(tile_shape));
      return std::min(max_blocks_k, max_blocks_q);
    }
  }

  // Number of trailing K tiles (out of get_trip_count) that must go through
  // apply_mask because the causal diagonal of this Q tile can cross them.
  template<class BlkCoord, class TileShape, class ProblemSize>
  CUTLASS_DEVICE
  int get_masked_trip_count(
      BlkCoord const& blk_coord,
      TileShape const& tile_shape,
      ProblemSize const& problem_size) {

    int trip_count = get_trip_count(blk_coord, tile_shape, problem_size);
    if constexpr (IsQBegin) {
      return std::min(trip_count, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape))));
    } else {
      // The diagonal is shifted right by offset_q = N_K - N_Q columns, so its
      // position inside a K tile is offset_q modulo the K tile size. Using
      // N_K % tile_K here is only correct when N_Q is a multiple of the K tile;
      // otherwise a tile crossed by the diagonal would be treated as unmasked.
      const int tile_k = get<1>(tile_shape);
      const int offset_q = get<1>(problem_size) - get<0>(problem_size);
      // Keep the remainder non-negative (N_K < N_Q). Counting one masked tile
      // too many is harmless because apply_mask is exact per element.
      const int offset_tile_q = ((offset_q % tile_k) + tile_k) % tile_k;
      return std::min(trip_count, int(ceil_div(get<0>(tile_shape) + offset_tile_q, get<1>(tile_shape))));
    }
  }

  // Remaining K tiles, which lie entirely below the diagonal.
  template<class BlkCoord, class TileShape, class ProblemSize>
  CUTLASS_DEVICE
  int get_unmasked_trip_count(
      BlkCoord const& blk_coord,
      TileShape const& tile_shape,
      ProblemSize const& problem_size) {

    return get_trip_count(blk_coord, tile_shape, problem_size) - get_masked_trip_count(blk_coord, tile_shape, problem_size);
  }

  // Writes -INFINITY into every accumulator element that lies above the causal
  // diagonal or whose column index is >= N_K. Coordinates come from index_qk.
  template<class AccQK, class IndexQK, class ProblemSize>
  CUTLASS_DEVICE
  void apply_mask(
      AccQK& acc_qk,
      IndexQK const& index_qk,
      ProblemSize const& problem_size) {

    // There are two ways to do causal if N_Q != N_K
    // (1) is to assume that the Q is at the beginning of the matrix
    //    - this is the default setting.
    // (2) is that it is at the end of the matrix
    //    - this is usually what we want for inference settings
    //      where we only compute the next row and use cache for the rest
    //    - if you'd like this, you only need to set kIsQBegin=false

    if constexpr (IsQBegin) {
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < size(acc_qk); i++) {
        auto pos = index_qk(i);
        if ((get<0>(pos) < get<1>(pos)) || (get<1>(pos) >= get<1>(problem_size))) {
          acc_qk(i) = -INFINITY;
        }
      }
    } else {
      const auto offset_q = get<1>(problem_size) - get<0>(problem_size);
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < size(acc_qk); i++) {
        auto pos = index_qk(i);
        if ((get<0>(pos) + offset_q < get<1>(pos)) || (get<1>(pos) >= get<1>(problem_size))) {
          acc_qk(i) = -INFINITY;
        }
      }
    }
  }
};
// Causal mask for the backward pass. Derives from both CausalMask<kIsQBegin>
// and ResidualMaskForBackward and replaces apply_mask with a version that
// applies the causal test together with an elem_less bounds test.
template<bool kIsQBegin = true>
struct CausalForBackwardMask : CausalMask<kIsQBegin>, ResidualMaskForBackward {

  using Base = CausalMask<kIsQBegin>;

  // Writes -INFINITY into every accumulator element that is above the causal
  // diagonal (shifted by N_K - N_Q when kIsQBegin is false) or that is not
  // elem_less than problem_size.
  template<class AccQK, class IndexQK, class ProblemSize>
  CUTLASS_DEVICE
  void apply_mask(
      AccQK& acc_qk,
      IndexQK const& index_qk,
      ProblemSize const& problem_size) {

    // There are two ways to do causal if N_Q != N_K
    // (1) Q sits at the beginning of the matrix (kIsQBegin == true):
    //     the diagonal is not shifted.
    // (2) Q sits at the end of the matrix (kIsQBegin == false):
    //     usually what inference wants, where only the newest rows are
    //     computed and the rest comes from cache; the diagonal is shifted by
    //     problem_size[1] - problem_size[0].
    int diagonal_shift = 0;
    if constexpr (!kIsQBegin) {
      diagonal_shift = get<1>(problem_size) - get<0>(problem_size);
    }

    CUTLASS_PRAGMA_UNROLL
    for (int elem = 0; elem < size(acc_qk); ++elem) {
      auto coord = index_qk(elem);
      bool const above_diagonal = (get<0>(coord) + diagonal_shift < get<1>(coord));
      bool const out_of_bounds = !elem_less(coord, problem_size);
      if (above_diagonal || out_of_bounds) {
        acc_qk(elem) = -INFINITY;
      }
    }
  }
};
// Extent of a sequence mode whose real length differs per batch entry.
// Converts implicitly to int, yielding max_length, so it can stand in for a
// plain integer extent inside a shape.
struct VariableLength {
// Value returned by the int conversion below.
int max_length;
// Prefix-sum style array: apply_variable_length reads
// cumulative_length[idx+1] - cumulative_length[idx] as the length of entry idx,
// and cumulative_length[idx] as its starting offset. May be nullptr; the
// forward epilogue checks for that before using it.
int* cumulative_length = nullptr;
// Not read by the code visible here; presumably the sum of all lengths -- TODO confirm.
int total_length = -1;
CUTE_HOST_DEVICE operator int() const {
return max_length;
}
};
// Type trait: true only for VariableLength. is_variable_length_v strips
// cv-qualifiers and references before testing, so it can be applied directly to
// decltype(expr) of a (const) reference.
template<class T> struct is_variable_length_impl : std::false_type {};
template<> struct is_variable_length_impl<VariableLength> : std::true_type {};
template<class T> constexpr bool is_variable_length_v = is_variable_length_impl<remove_cvref_t<T>>::value;
// Returns a copy of `shape` in which every VariableLength leaf is replaced by
// the concrete length of entry `idx`
// (cumulative_length[idx+1] - cumulative_length[idx]); all other leaves are
// passed through unchanged. cumulative_length is dereferenced unconditionally.
template<class Shape, class Idx>
CUTE_HOST_DEVICE
constexpr auto
apply_variable_length(Shape const& shape, Idx const& idx) {
  auto resolve_extent = [&](auto const& extent) {
    if constexpr (!is_variable_length_v<decltype(extent)>) {
      return extent;
    }
    else {
      return extent.cumulative_length[idx+1] - extent.cumulative_length[idx];
    }
  };
  return transform_leaf(shape, resolve_extent);
}
// Returns make_tuple(new_shape, new_coord):
//  - new_shape is apply_variable_length(shape, idx);
//  - new_coord is `coord` with every leaf that corresponds to a VariableLength
//    leaf of `shape` replaced by the pair (c, cumulative_length[idx]); other
//    leaves are passed through unchanged.
template<class Shape, class Coord, class Idx>
CUTE_HOST_DEVICE
constexpr auto
apply_variable_length(Shape const& shape, Coord const& coord, Idx const& idx) {
  auto attach_offset = [&](auto const& extent, auto const& c) {
    if constexpr (!is_variable_length_v<decltype(extent)>) {
      return c;
    }
    else {
      return cute::make_tuple(c, extent.cumulative_length[idx]);
    }
  };
  auto resolved_shape = apply_variable_length(shape, idx);
  auto resolved_coord = transform_leaf(shape, coord, attach_offset);
  return cute::make_tuple(resolved_shape, resolved_coord);
}
// Takes the entry index from the last element of the last mode of `coord` and
// returns make_tuple(result_shape, result_offset):
//  - result_shape is `shape` with every VariableLength leaf replaced by that
//    entry's length (cumulative_length[idx+1] - cumulative_length[idx]);
//  - result_offset has the structure of `coord`, holding
//    cumulative_length[idx] where `shape` has a VariableLength leaf and the
//    static _0 everywhere else.
template<class Shape, class Coord>
CUTE_HOST_DEVICE
constexpr auto
apply_variable_length_offset(Shape const& shape, Coord const& coord) {
  auto entry_idx = back(back(coord));

  // Same per-leaf rule as the two-argument apply_variable_length.
  auto resolved_shape = apply_variable_length(shape, entry_idx);

  auto pick_offset = [&](auto const& /*c*/, auto const& extent) {
    if constexpr (!is_variable_length_v<decltype(extent)>) {
      return _0{};
    }
    else {
      return extent.cumulative_length[entry_idx];
    }
  };
  auto resolved_offset = transform_leaf(coord, shape, pick_offset);

  return cute::make_tuple(resolved_shape, resolved_offset);
}
} // namespace cutlass::fmha::collective
namespace cute {

// Opts VariableLength into cute::is_integral so CuTe treats it as an integral
// leaf of a shape rather than as an unknown type.
template<>
struct is_integral<cutlass::fmha::collective::VariableLength> : true_type {};

// Debug printer: emits "Varlen<max_length, cumulative_length pointer>".
CUTE_HOST_DEVICE
void print(cutlass::fmha::collective::VariableLength a) {
  // %p expects a pointer to void, so the int* is cast explicitly.
  printf("Varlen<%d, %p>", a.max_length, static_cast<void const*>(a.cumulative_length));
}

}
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cute/layout.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
namespace cutlass::fmha::collective {
template<
class Element,
class ElementAcc,
class TileShape, // Q, D, _
class StrideO, // Q, D, B
class StrideLSE_, // Q, B
class OrderLoadEpilogue = cute::false_type
>
struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
using Pipeline = cutlass::PipelineAsync<2>;
// using SmemLayoutO = decltypa(make_layout(append<3>(select<0,1>(TileShape_WG{}), _2{})));
using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<
cute::UMMA::Major::K, Element, tuple_element_t<0, TileShape>, tuple_element_t<1, TileShape>>());
// using SmemLayoutAtomO = decltype(make_ordered_layout(select<0,1>(TileShape{}), Step<_1, _0>{}));
using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, replace<2>(TileShape{}, _2{}), Step<_2, _1, _3>{}));
using SmemLayoutO_ = SmemLayoutO;
using StrideLSE = StrideLSE_;
using ElementOut = Element;
static const int NumWarpsEpilogue = 1;
static const int NumWarpsLoad = 1;
struct TensorStorage {
using SmemLayoutO = SmemLayoutO_;
cute::array_aligned<Element, cute::cosize_v<SmemLayoutO>> smem_o;
};
struct Arguments {
Element* ptr_O;
StrideO dO;
ElementAcc* ptr_LSE;
StrideLSE dLSE;
};
using TMA_O = decltype(make_tma_copy(
SM90_TMA_STORE{},
make_tensor((Element*) nullptr, repeat_like(StrideO{}, 0), StrideO{}),
SmemLayoutO{}(_,_,_0{})
));
struct Params {
TMA_O tma_store_o;
ElementAcc* ptr_LSE;
StrideLSE dLSE;
};
// FMHA and MLA have different input ProblemShapes;
// get problem_shape_O according to the input ProblemShape.
template<class ProblemShape>
CUTLASS_DEVICE static constexpr
auto get_problem_shape_O (
ProblemShape const& problem_shape) {
if constexpr (rank_v<decltype(get<2>(ProblemShape{}))> == 2) {
return replace<1>(select<0,2,3>(problem_shape), get<2, 0>(problem_shape));
} else {
return select<0,2,3>(problem_shape);
}
}
template<class ProblemShape>
static Params to_underlying_arguments(
ProblemShape const& problem_shape,
Arguments const& args,
void* workspace = nullptr) {
auto ptr_O = args.ptr_O;
StrideO dO = args.dO;
auto problem_shape_O = get_problem_shape_O(problem_shape);
if constexpr (is_variable_length_v<tuple_element_t<0, ProblemShape>>) {
auto cumulative_length_q = get<0>(problem_shape).cumulative_length;
if (cumulative_length_q != nullptr) {
int max_length_q = get<0>(problem_shape).max_length;
// for variable sequence lenght, the batch is in units of row_stride
get<2,1>(dO) = get<0>(dO);
get<2,1>(problem_shape_O) = max_length_q * (1 + get<2,1>(problem_shape_O));
// offset ptr by the amount we add back in later
ptr_O -= max_length_q * get<0>(dO);
}
}
auto tma_store_o = make_tma_copy(
SM90_TMA_STORE{},
make_tensor(ptr_O, problem_shape_O, dO),
SmemLayoutO{}(_,_,_0{})
);
return {
tma_store_o,
args.ptr_LSE,
args.dLSE
};
}
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const& params) {
cute::prefetch_tma_descriptor(params.tma_store_o.get_tma_descriptor());
}
const Params& params;
CUTLASS_DEVICE Sm100FmhaFwdEpilogueTmaWarpspecialized(const Params& params) : params(params) {}
template<class BlkCoord, class ProblemShape, class ParamsProblemShape>
CUTLASS_DEVICE auto
store(
BlkCoord const& blk_coord_in, ProblemShape const& problem_shape,
Params const& params, ParamsProblemShape const& params_problem_shape,
TensorStorage& shared_storage,
Pipeline& pipeline, typename Pipeline::PipelineState& pipeline_consumer_state) {
BlkCoord blk_coord = blk_coord_in;
uint32_t lane_predicate = cute::elect_one_sync();
using X = Underscore;
int o0_index = 2 * get<0>(blk_coord);
int o1_index = 2 * get<0>(blk_coord) + 1;
Tensor mO_qdl_p = params.tma_store_o.get_tma_tensor(get_problem_shape_O(problem_shape));
// offset mode 0 by (max_length - real_length)
// offset mode 3,1 by cumulative_length + real_length
// the ptr is already offset by - max_length
// so in total this achieves
int offs_0 = 0;
int offs_2_1 = 0;
if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
auto cumulative_length_q = get<0>(params_problem_shape).cumulative_length;
if (cumulative_length_q != nullptr) {
int max_length_q = get<0>(params_problem_shape).max_length;
offs_0 = max_length_q - get<0>(problem_shape);
offs_2_1 = cumulative_length_q[get<2,1>(blk_coord)] + get<0>(problem_shape);
get<2,1>(blk_coord) = 0;
}
}
Tensor mO_qdl = domain_offset(make_coord(offs_0, _0{}, make_coord(_0{}, offs_2_1)), mO_qdl_p);
Tensor gO_qdl = local_tile(mO_qdl, TileShape{}, make_coord(_, _, _), Step<_1, _1, X>{});
Tensor gO = gO_qdl(_, _, _, _0{}, get<2>(blk_coord));
Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutO{});
auto block_tma = params.tma_store_o.get_slice(0);
Tensor tOsO = block_tma.partition_S(sO);
Tensor tOgO = block_tma.partition_D(gO);
auto pipeline_release_state = pipeline_consumer_state;
// O1 O2
// one pipeline: O
// wait from corr, issue tma store on smem
pipeline.consumer_wait(pipeline_consumer_state);
++pipeline_consumer_state;
if (lane_predicate) {
copy(params.tma_store_o, tOsO(_,_,_,_0{}), tOgO(_,_,_,o0_index));
}
tma_store_arrive();
pipeline.consumer_wait(pipeline_consumer_state);
++pipeline_consumer_state;
if (lane_predicate) {
copy(params.tma_store_o, tOsO(_,_,_,_1{}), tOgO(_,_,_,o1_index));
}
tma_store_arrive();
tma_store_wait<1>();
pipeline.consumer_release(pipeline_release_state);
++pipeline_release_state;
tma_store_wait<0>();
if constexpr (cute::is_same_v<OrderLoadEpilogue, cute::true_type>) {
cutlass::arch::NamedBarrier::arrive((NumWarpsLoad + NumWarpsEpilogue) * NumThreadsPerWarp,
cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
}
pipeline.consumer_release(pipeline_release_state);
++pipeline_release_state;
}
};
} // namespace cutlass::fmha::collective
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/arch/memory_sm80.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cute/arch/simd_sm100.hpp"
#include "cute/tensor.hpp"
#include "cute/layout.hpp"
#include "collective/fmha_common.hpp"
#include "collective/fmha_fusion.hpp"
#include "collective/sm100_fmha_load_tma_warpspecialized.hpp"
namespace cutlass::fmha::collective {
using namespace cute;
template<
class Element_,
class ElementQK_,
class ElementPV_,
class TileShape_,
class StrideQ_,
class StrideK_,
class StrideV_,
class Mask_,
// shape here is QG K H
// and refers to the two softmax warps
// (2, 1, 1) means that they are stacked (best for large Q since it loads the least K/V)
// (1, 2, 1) means they sit side by side (best for small Q / large K)
class ThreadShape = Shape<_2, _1, _1>,
// Since shared memory is sufficient for FMHA, there is no need to reuse shared memory.
class OrderLoadEpilogue = cute::false_type
>
struct Sm100FmhaFwdMainloopTmaWarpspecialized {
using Element = Element_;
using ElementQK = ElementQK_;
using ElementPV = ElementPV_;
using TileShape = TileShape_;
using StrideQ = StrideQ_;
using StrideK = StrideK_;
using StrideV = StrideV_;
using Mask = Mask_;
static constexpr int StageCountQ = 2;
static constexpr int StageCountKV = sizeof(Element_) == 1 ? 4 : 3;
using StagesQ = cutlass::gemm::collective::StageCount<StageCountQ>;
using StagesKV = cutlass::gemm::collective::StageCount<StageCountKV>;
using ClusterShape = Shape<_1, _1, _1>;
static const int Alignment = 128 / sizeof_bits_v<Element>;
using TileShapeQK = decltype(shape_div(TileShape{}, ThreadShape{}));
using TileShapePV = decltype(select<0,2,1>(TileShapeQK{}));
using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
Element, StrideQ, Alignment,
Element, StrideK, Alignment,
ElementQK,
TileShapeQK, ClusterShape, cutlass::gemm::collective::StageCount<3> /* we change it later anyways*/,
cutlass::gemm::KernelTmaWarpSpecialized1SmSm100>::CollectiveOp;
using CollectiveMmaPV = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
// the stride for A does not matter since we do not load from smem at all
Element, StrideK, Alignment,
Element, decltype(select<1,0,2>(StrideV{})), Alignment,
ElementPV,
TileShapePV, ClusterShape, cutlass::gemm::collective::StageCount<3> /* we change it later anyways*/,
cutlass::gemm::KernelTmaWarpSpecialized1SmSm100>::CollectiveOp;
using SmemLayoutQ = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutA{}, Int<StageCountQ>{}));
using SmemLayoutK = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutB{}, Int<StageCountKV>{}));
using SmemLayoutV = decltype(unstageSmemLayout(typename CollectiveMmaPV::SmemLayoutB{}, Int<StageCountKV>{}));
// Reuse shared memory for V and O.
static constexpr bool IsOrderLoadEpilogue = std::is_same_v<OrderLoadEpilogue, cute::true_type>;
// Shared-memory buffers of the mainloop.
// smem_q holds the staged Q tiles (SmemLayoutQ). smem_k and smem_v are members
// of one anonymous union, i.e. K and V occupy the same bytes; the
// static_assert on TransactionBytesLoadK/V further down requires both per-stage
// sizes to be equal.
struct TensorStorage {
cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
union {
cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
};
};
// Sizes and start offsets of the S / V / P / O regions.
// NOTE(review): presumably column offsets into tensor memory (TMEM) -- confirm units.
// Resulting values: S0 = 0, S1 = 128, V0 = 0, V1 = 128, P0 = 32, P1 = 160,
// O0 = 256, O1 = 384, kEnd = 512. V0/V1 alias the start of S0/S1, and P0/P1
// start kSizeP past the start of S0/S1, so V and P lie inside the S regions.
enum class TmemAllocation : uint32_t {
kSizeS = 128,
kSizeO = 128,
kSizeP = 32,
S0 = 0,
S1 = S0 + kSizeS,
V0 = S0, // stats storage from softmax to correction
V1 = S1,
P0 = S0 + kSizeP,
P1 = S1 + kSizeP,
O0 = S1 + kSizeS,
O1 = O0 + kSizeO,
kEnd = O1 + kSizeO
};
// indices for V0 / V1
// Two naming views of the same two slots: slot 0 is kIdxOldRowMax or
// kIdxFinalRowSum, slot 1 is kIdxNewRowMax or kIdxFinalRowMax.
// NOTE(review): presumably the "Final" names apply after the last K tile -- confirm.
enum : int {
kIdxOldRowMax = 0,
kIdxNewRowMax = 1,
kIdxFinalRowSum = 0,
kIdxFinalRowMax = 1
};
// from load to mma warp, protects q in smem
using PipelineQ = cutlass::PipelineTmaUmmaAsync<
StageCountQ,
typename CollectiveMmaQK::AtomThrShapeMNK
>;
// from load to mma warp, protects k/v in smem
using PipelineKV = cutlass::PipelineTmaUmmaAsync<
StageCountKV,
typename CollectiveMmaQK::AtomThrShapeMNK
>;
// from mma to softmax0/1 warp, protects S in tmem
// (not sure yet about the reverse direction)
// there is one pipe per softmax warp, and the mma warp alternates between them
using PipelineS = cutlass::PipelineUmmaAsync<1>;
// from softmax0/1/ to correction wg
using PipelineC = cutlass::PipelineAsync<1>;
// from mma to correction
using PipelineO = cutlass::PipelineUmmaAsync<2>;
// from corr to epilogue
using PipelineE = cutlass::PipelineAsync<2>;
using OrderBarrierSoftmax = cutlass::OrderedSequenceBarrier<
/*stages*/ 1, /*groups*/ 2>;
static const int TransactionBytesLoadQ = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutQ{})) * cute::sizeof_bits_v<Element>);
static const int TransactionBytesLoadK = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v<Element>);
static const int TransactionBytesLoadV = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v<Element>);
static_assert(TransactionBytesLoadK == TransactionBytesLoadV, "K and V smem layouts must be of equal size");
using Load = Sm100FmhaLoadTmaWarpspecialized<
Element, StrideQ, StrideK, StrideV,
CollectiveMmaQK, CollectiveMmaPV,
SmemLayoutQ, SmemLayoutK, SmemLayoutV,
TensorStorage, PipelineQ, PipelineKV, Mask, TileShape
>;
// User-facing arguments; converted to Params by to_underlying_arguments().
struct Arguments {
// arguments forwarded to the Load collective
typename Load::Arguments load;
// if zero, defaults to 1/sqrt(D)
// (D is read as get<2>(problem_shape) in to_underlying_arguments)
float scale_softmax = 0.0f;
// scaling factors to dequantize QKV
float scale_q = 1.0f;
float scale_k = 1.0f;
float scale_v = 1.0f;
// scaling factor to quantize O
float inv_scale_o = 1.0f;
};
// Kernel-side parameters produced by to_underlying_arguments().
// The struct is brace-initialised positionally there, so member order matters.
struct Params {
// parameters of the Load collective
typename Load::Params load;
// scale_q * scale_k * softmax scale
float scale_softmax;
// scale_softmax multiplied by log2(e), for use with exp2f
float scale_softmax_log2;
// scale_v * inv_scale_o, applied to the output in the correction epilogue
float scale_output;
};
// Reports whether this collective supports the given problem.
// Currently accepts every problem shape / argument combination unconditionally;
// neither parameter is inspected.
template<class ProblemShape>
static bool can_implement(ProblemShape const& problem_shape, Arguments const& args) {
return true;
}
// Converts user-facing Arguments into the kernel-side Params.
//
// problem_shape : problem extents; get<2> is used as D for the default scale
// args          : user arguments (softmax scale and (de)quantisation scales)
// workspace     : forwarded untouched to the Load collective
//
// Returns Params holding the Load parameters, the combined softmax scale
// (plain and pre-multiplied by log2(e)) and the combined output scale.
template<class ProblemShape>
static Params to_underlying_arguments(
    ProblemShape const& problem_shape,
    Arguments const& args,
    void* workspace) {
  // A softmax scale of exactly zero selects the default 1/sqrt(D).
  float const softmax_scale = (args.scale_softmax == 0.0f)
      ? 1.0f / (float) std::sqrt(get<2>(problem_shape))
      : args.scale_softmax;

  // log2(e): lets the device code evaluate exp(x) as exp2(x * log2(e)).
  float const log2_of_e = static_cast<float>(std::log2(std::exp(1.0)));

  // Dequantisation factor shared by both softmax scale variants.
  float const dequant_qk = args.scale_q * args.scale_k;

  return Params{
      Load::to_underlying_arguments(problem_shape, args.load, workspace),
      dequant_qk * softmax_scale,
      dequant_qk * log2_of_e * softmax_scale,
      args.scale_v * args.inv_scale_o
  };
}
// Forwards to the Load collective's prefetch_tma_descriptors with the
// load portion of the params.
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const& params) {
Load::prefetch_tma_descriptors(params.load);
}
// Load-warp entry point: hands the work for one block coordinate to the Load
// collective, passing it the load-specific params, the shared tensor storage,
// and the Q and K/V pipelines together with their producer states.
template<class BlkCoord, class ProblemShape, class ParamsProblemShape>
CUTLASS_DEVICE void
load(
    BlkCoord const& blk_coord, ProblemShape const& problem_shape,
    Params const& params, ParamsProblemShape const& params_problem_shape,
    TensorStorage& storage,
    PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_producer_state,
    PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_producer_state) {
  // The loader is a default-constructed object; everything it needs is passed per call.
  Load tma_loader;
  tma_loader.load(
      blk_coord, problem_shape,
      params.load, params_problem_shape,
      storage,
      pipeline_q, pipeline_q_producer_state,
      pipeline_kv, pipeline_kv_producer_state);
}
// MMA-warp main loop.
//
// Issues the Q*K GEMMs into the two S accumulators (S0/S1) and the P*V GEMMs
// into the two O accumulators (O0/O1) in tmem, alternating between the two
// softmax stages. Operand tiles are obtained through pipeline_q / pipeline_kv
// (consumer side); results are handed to the softmax groups through
// pipeline_s0 / pipeline_s1 and to the correction group through pipeline_corr
// (producer side). All pipeline states are passed by reference and advanced.
//
// The number of K/V tiles processed is Mask{}.get_trip_count(...).
// The order of acquire / gemm / commit / release calls below is the
// synchronisation protocol with the other warp roles; do not reorder.
template<class BlkCoord, class ProblemShape>
CUTLASS_DEVICE auto
mma(
BlkCoord const& blk_coord,
Params const& params, ProblemShape const& problem_shape,
TensorStorage& storage,
PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_consumer_state,
PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_consumer_state,
PipelineS& pipeline_s0, typename PipelineS::PipelineState& pipeline_s0_producer_state,
PipelineS& pipeline_s1, typename PipelineS::PipelineState& pipeline_s1_producer_state,
PipelineO& pipeline_corr, typename PipelineO::PipelineState& pipeline_corr_producer_state) {
// Separate copies of the consumer states are used for releasing, because
// buffers are released later than they are waited on.
auto pipeline_q_release_state = pipeline_q_consumer_state;
auto pipeline_kv_release_state = pipeline_kv_consumer_state;
int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape);
typename CollectiveMmaQK::TiledMma mma_qk;
ThrMMA thr_mma_qk = mma_qk.get_slice(0);
typename CollectiveMmaPV::TiledMma mma_pv;
// variant of the PV mma produced by to_tiled_mma_sm100_ts; its A operand (P)
// is addressed with TmemAllocation offsets below
TiledMMA mma_pv_ts = to_tiled_mma_sm100_ts(mma_pv);
ThrMMA thr_mma_pv = mma_pv_ts.get_slice(0);
Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{});
Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{});
Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{});
Tensor tSrQ = thr_mma_qk.make_fragment_A(sQ);
Tensor tSrK = thr_mma_qk.make_fragment_B(sK);
Tensor tOrV = thr_mma_pv.make_fragment_B(sV);
// tmem layout is
// S0 S1 O0 O1
// sequential in memory, where S overlaps with P and V
Tensor tStS = partition_fragment_C(mma_qk, select<0,1>(TileShapeQK{}));
Tensor tOtO = partition_fragment_C(mma_pv_ts, select<0,1>(TileShapePV{}));
// Per-stage accumulator views: same layout, base offset by the TmemAllocation entry.
Tensor tStS0 = tStS;
tStS0.data() = tStS.data().get() + uint32_t(TmemAllocation::S0);
Tensor tStS1 = tStS;
tStS1.data() = tStS.data().get() + uint32_t(TmemAllocation::S1);
Tensor tOtO0 = tOtO;
tOtO0.data() = tOtO.data().get() + uint32_t(TmemAllocation::O0);
Tensor tOtO1 = tOtO;
tOtO1.data() = tOtO.data().get() + uint32_t(TmemAllocation::O1);
// sP only provides the layout for the P operand fragment (its pointer is null);
// the fragment base is then replaced by the P0 / P1 offsets.
Tensor sP = make_tensor(make_smem_ptr((Element*)nullptr), typename CollectiveMmaPV::SmemLayoutA{});
Tensor tOrP = thr_mma_pv.make_fragment_A(sP)(_, _, _, _0{}); // slice out staging
Tensor tOrP0 = tOrP;
tOrP0.data() = tOrP0.data().get() + uint32_t(TmemAllocation::P0);
Tensor tOrP1 = tOrP;
tOrP1.data() = tOrP1.data().get() + uint32_t(TmemAllocation::P1);
// smem stage indices of the currently held K, V and Q tiles
int k_index = 0;
int v_index = 0;
int q_index = 0;
// ---- prologue: first K/V tile ----
// wait for Q1
q_index = pipeline_q_consumer_state.index();
pipeline_q.consumer_wait(pipeline_q_consumer_state);
++pipeline_q_consumer_state;
Tensor tSrQ0 = tSrQ(_,_,_,q_index);
// wait for K1
k_index = pipeline_kv_consumer_state.index();
pipeline_kv.consumer_wait(pipeline_kv_consumer_state);
++pipeline_kv_consumer_state;
// gemm Q1 * K1 -> S1
pipeline_s0.producer_acquire(pipeline_s0_producer_state);
gemm_zero_acc(mma_qk, tSrQ0, tSrK(_,_,_,k_index), tStS0);
pipeline_s0.producer_commit(pipeline_s0_producer_state);
++pipeline_s0_producer_state;
// release K1
// (only when get<1>(ThreadShape) > 1; otherwise the same K tile is reused for Q2 below)
if constexpr (get<1>(ThreadShape{}) > 1) {
pipeline_kv.consumer_release(pipeline_kv_release_state);
++pipeline_kv_release_state;
}
// wait for Q2
// (when this branch is not compiled, q_index is unchanged and tSrQ1 refers to the same tile as tSrQ0)
if constexpr (get<0>(ThreadShape{}) > 1 || get<2>(ThreadShape{}) > 1) {
q_index = pipeline_q_consumer_state.index();
pipeline_q.consumer_wait(pipeline_q_consumer_state);
++pipeline_q_consumer_state;
}
Tensor tSrQ1 = tSrQ(_,_,_,q_index);
if constexpr (get<1>(ThreadShape{}) > 1) {
k_index = pipeline_kv_consumer_state.index();
pipeline_kv.consumer_wait(pipeline_kv_consumer_state);
++pipeline_kv_consumer_state;
}
pipeline_s1.producer_acquire(pipeline_s1_producer_state);
// gemm Q2 * K1 -> S2
gemm_zero_acc(mma_qk, tSrQ1, tSrK(_,_,_,k_index), tStS1);
pipeline_s1.producer_commit(pipeline_s1_producer_state);
++pipeline_s1_producer_state;
// release K1
pipeline_kv.consumer_release(pipeline_kv_release_state);
++pipeline_kv_release_state;
// wait for V1
v_index = pipeline_kv_consumer_state.index();
pipeline_kv.consumer_wait(pipeline_kv_consumer_state);
++pipeline_kv_consumer_state;
// this acquire returns the ownership of all of S0 to the mma warp
// including the P0 part
// acquire corr first to take it out of the critical
// path since softmax takes longer
pipeline_corr.producer_acquire(pipeline_corr_producer_state);
pipeline_s0.producer_acquire(pipeline_s0_producer_state);
// gemm P1 * V1 -> O1
gemm_zero_acc(mma_pv_ts, tOrP0, tOrV(_,_,_,v_index), tOtO0);
pipeline_corr.producer_commit(pipeline_corr_producer_state);
++pipeline_corr_producer_state;
if constexpr (get<1>(ThreadShape{}) > 1) {
pipeline_kv.consumer_release(pipeline_kv_release_state);
++pipeline_kv_release_state;
}
mma_pv_ts.accumulate_ = UMMA::ScaleOut::Zero;
// ---- steady state: one iteration per remaining K/V tile ----
// loop:
// (the first tile was handled by the prologue above, hence the decrement)
mask_tile_count -= 1;
for (; mask_tile_count > 0; mask_tile_count -= 1) {
// wait for Ki
k_index = (pipeline_kv_consumer_state.index());
pipeline_kv.consumer_wait(pipeline_kv_consumer_state);
++pipeline_kv_consumer_state;
// gemm Q1 * Ki -> S1
// (S0 was already acquired before the preceding P1 * V gemm, so only a commit follows)
gemm_zero_acc(mma_qk, tSrQ0, tSrK(_,_,_,k_index), tStS0);
pipeline_s0.producer_commit(pipeline_s0_producer_state);
++pipeline_s0_producer_state;
if constexpr (get<1>(ThreadShape{}) > 1) {
pipeline_kv.consumer_release(pipeline_kv_release_state);
++pipeline_kv_release_state;
}
// gemm P2 * V(i-1) -> O2
if constexpr (get<1>(ThreadShape{}) > 1) {
v_index = pipeline_kv_consumer_state.index();
pipeline_kv.consumer_wait(pipeline_kv_consumer_state);
++pipeline_kv_consumer_state;
}
pipeline_corr.producer_acquire(pipeline_corr_producer_state);
pipeline_s1.producer_acquire(pipeline_s1_producer_state);
gemm_reset_zero_acc(mma_pv_ts, tOrP1, tOrV(_,_,_,v_index), tOtO1);
pipeline_corr.producer_commit(pipeline_corr_producer_state);
++pipeline_corr_producer_state;
// release V(i-1)
pipeline_kv.consumer_release(pipeline_kv_release_state);
++pipeline_kv_release_state;
if constexpr (get<1>(ThreadShape{}) > 1) {
k_index = (pipeline_kv_consumer_state.index());
pipeline_kv.consumer_wait(pipeline_kv_consumer_state);
++pipeline_kv_consumer_state;
}
// gemm Q2 * Ki -> S2
gemm_zero_acc(mma_qk, tSrQ1, tSrK(_,_,_,k_index), tStS1);
pipeline_s1.producer_commit(pipeline_s1_producer_state);
++pipeline_s1_producer_state;
// release Ki
pipeline_kv.consumer_release(pipeline_kv_release_state);
++pipeline_kv_release_state;
// wait for Vi
v_index = (pipeline_kv_consumer_state.index());
pipeline_kv.consumer_wait(pipeline_kv_consumer_state);
++pipeline_kv_consumer_state;
// gemm P1 * Vi -> O1
pipeline_corr.producer_acquire(pipeline_corr_producer_state);
pipeline_s0.producer_acquire(pipeline_s0_producer_state);
gemm_reset_zero_acc(mma_pv_ts, tOrP0, tOrV(_,_,_,v_index), tOtO0);
pipeline_corr.producer_commit(pipeline_corr_producer_state);
++pipeline_corr_producer_state;
if constexpr (get<1>(ThreadShape{}) > 1) {
pipeline_kv.consumer_release(pipeline_kv_release_state);
++pipeline_kv_release_state;
}
}
// ---- epilogue: no more Q*K gemms are issued, so Q can be released ----
// release Q1
pipeline_q.consumer_release(pipeline_q_release_state);
++pipeline_q_release_state;
// release Q2
// NOTE(review): the matching wait above is also taken when get<2>(ThreadShape) > 1,
// whereas this release only checks get<0> -- confirm this asymmetry is intended.
if constexpr (get<0>(ThreadShape{}) > 1) {
pipeline_q.consumer_release(pipeline_q_release_state);
++pipeline_q_release_state;
}
// wait for Vi
if constexpr (get<1>(ThreadShape{}) > 1) {
v_index = pipeline_kv_consumer_state.index();
pipeline_kv.consumer_wait(pipeline_kv_consumer_state);
++pipeline_kv_consumer_state;
}
// gemm P2 * Vi -> O2
pipeline_corr.producer_acquire(pipeline_corr_producer_state);
pipeline_s1.producer_acquire(pipeline_s1_producer_state);
gemm_reset_zero_acc(mma_pv_ts, tOrP1, tOrV(_,_,_,v_index), tOtO1);
pipeline_corr.producer_commit(pipeline_corr_producer_state);
++pipeline_corr_producer_state;
// release Vi
pipeline_kv.consumer_release(pipeline_kv_release_state);
++pipeline_kv_release_state;
// Final commits on both S pipes without a preceding gemm; the softmax side
// waits once more on pipeline_s in its final step (see softmax_step).
pipeline_s0.producer_commit(pipeline_s0_producer_state);
++pipeline_s0_producer_state;
pipeline_s1.producer_commit(pipeline_s1_producer_state);
++pipeline_s1_producer_state;
// T0 S00 B1, T0 S10 B1, T0 S00 B2, T0 S01 B1, T0 S10 B2, T0 S11 B1, T0 S01 B2, T1 S00 B1, T0 S11 B2, ...
// Q1 * K1 , Q2 * K1 , S11 * V1 , Q1 * K2 , S21 * V1 , Q2 * K2 , S12 * V2 , Q1 * K3 , S22 * V2 , ...
}
// One online-softmax step for one stage (0 or 1) and one K/V tile.
//
// Sequence, as implemented below:
//   1. wait on pipeline_s, load the S tile from tmem into registers
//   2. optionally apply the mask (need_apply_mask)
//   3. update the running row max; store {old max, new max} to the V stats slot
//      and commit pipeline_c so the correction group can read it
//   4. compute exp2(scale * S - scale * row_max), convert to Element and store
//      the result (P) back to tmem; release pipeline_s
//   5. update the running row sum
//   6. if final_call, wait on pipeline_s again and store {row sum, row max}
//      to the V stats slot
//
// row_max / row_sum : running statistics of this thread's row, updated in place
// stage             : compile-time-comparable stage selector (_0 or other)
// final_call        : true on the last step for this stage
// cS                : coordinate tensor for the current S tile (used for masking)
// order_s           : ordering barrier shared with the other softmax group
template<bool need_apply_mask, class Stage, class BlkCoord, class CoordTensor, class ProblemShape>
CUTLASS_DEVICE auto
softmax_step(
float& row_max, float& row_sum,
Stage stage, bool final_call,
BlkCoord const& blk_coord, CoordTensor const& cS,
Params const& params, ProblemShape const& problem_shape,
PipelineS& pipeline_s, typename PipelineS::PipelineState& pipeline_s_consumer_state,
PipelineC& pipeline_c, typename PipelineC::PipelineState& pipeline_c_producer_state,
OrderBarrierSoftmax& order_s) {
// Views onto the S accumulator (full tile), the 2-wide stats slot (V) and the
// P region, each based at the TmemAllocation offset of the selected stage.
Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS);
Tensor tStS = partition_fragment_C(typename CollectiveMmaQK::TiledMma{}, select<0,1>(TileShapeQK{}));
tStS.data() = uint32_t(stage == _0{} ? TmemAllocation::S0 : TmemAllocation::S1);
Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{})));
tStS_v.data() = uint32_t(stage == _0{} ? TmemAllocation::V0 : TmemAllocation::V1);
Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{})));
// width of the P tile expressed in float-sized columns: N * sizeof(Element) / sizeof(float)
auto tilePlikeFP32 = size<1>(TileShapeQK{}) / Int<sizeof(float)>{} * Int<sizeof(Element)>{};
Tensor tStS_P = tStS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));
tStS_P.data() = warp_uniform(uint32_t(stage == _0{} ? TmemAllocation::P0 : TmemAllocation::P1));
Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));
// Each thread owns a single row
using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 128 cols of 32b elem
using TMEM_STORE = SM100_TMEM_STORE_32dp32b32x; // 4x32 threads with 128 cols of 8b elem
using TMEM_STORE_V = SM100_TMEM_STORE_32dp32b2x; // 4x32 threads with 2 cols of 32b elem
// thread index within this group of 4 warps
int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp);
auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tStS);
auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx);
Tensor tTMEM_LOADtS = thr_tmem_load.partition_S(tStS);
Tensor tTMEM_LOADcS = thr_tmem_load.partition_D(tScS);
auto tiled_tmem_storev = make_tmem_copy(TMEM_STORE_V{}, tStS_v);
auto thr_tmem_storev = tiled_tmem_storev.get_slice(thread_idx);
Tensor tTMEM_STOREVtS = thr_tmem_storev.partition_D(tStS_v);
Tensor tTMEM_STOREVcS = thr_tmem_storev.partition_S(tScS_v);
auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tStS_P);
auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx);
Tensor tTMEM_STOREtS_x4 = thr_tmem_store.partition_D(tStS_P);
tTMEM_STOREtS_x4.data() = warp_uniform(tTMEM_STOREtS_x4.data().get());
Tensor tTMEM_STOREcS = thr_tmem_store.partition_S(tScS_P);
// wait on tensor core pipe
pipeline_s.consumer_wait(pipeline_s_consumer_state);
// read all of S from tmem into reg mem
Tensor tTMEM_LOADrS = make_tensor<ElementQK>(shape(tTMEM_LOADcS));
copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS);
if constexpr (need_apply_mask) {
Mask{}.apply_mask(tTMEM_LOADrS, tTMEM_LOADcS, problem_shape);
}
// remember the running max from before this tile; needed for rescaling
ElementQK old_row_max = row_max;
{
// compute rowmax
// four independent accumulators, combined after the loop
float row_max_0 = row_max;
float row_max_1 = row_max;
float row_max_2 = row_max;
float row_max_3 = row_max;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(tTMEM_LOADrS); i += 4) {
row_max_0 = ::fmax(row_max_0, tTMEM_LOADrS(i));
row_max_1 = ::fmax(row_max_1, tTMEM_LOADrS(i+1));
row_max_2 = ::fmax(row_max_2, tTMEM_LOADrS(i+2));
row_max_3 = ::fmax(row_max_3, tTMEM_LOADrS(i+3));
}
row_max = ::fmax(row_max_0, row_max_1);
row_max = ::fmax(row_max, row_max_2);
row_max = ::fmax(row_max, row_max_3);
}
// a row max of -inf is replaced by 0 for the arithmetic below
ElementQK row_max_safe = row_max == -INFINITY ? 0 : row_max;
// publish {old max, new max} to the stats slot for the correction group
Tensor tTMEM_STOREVrS = make_tensor<ElementQK>(shape(tTMEM_STOREVcS));
tTMEM_STOREVrS(kIdxOldRowMax) = old_row_max;
tTMEM_STOREVrS(kIdxNewRowMax) = row_max_safe;
copy(tiled_tmem_storev, tTMEM_STOREVrS, tTMEM_STOREVtS);
pipeline_c.producer_commit(pipeline_c_producer_state);
++pipeline_c_producer_state;
// notify correction wg that they are ready (might need addtl ordering between S0 and S1 WG's)
ElementQK scale = params.scale_softmax_log2;
ElementQK row_max_scale = row_max_safe * scale;
// packed operands for the two-wide fma: out = scale * in - scale * row_max
float2 scale_fp32x2 = make_float2(scale, scale);
float2 minus_row_max_scale_fp32x2 = make_float2(-row_max_scale, -row_max_scale);
// register staging for P: 32-bit words, also viewed as pairs of Element
Tensor tTMEM_STORErS_x4 = make_tensor<uint32_t>(shape(tTMEM_STOREcS));
constexpr int kConversionsPerStep = 2;
Tensor tTMEM_STORErS_x4_e = recast<Array<Element, kConversionsPerStep>>(tTMEM_STORErS_x4);
NumericArrayConverter<Element, ElementQK, kConversionsPerStep> convert;
// number of elements before the end of the loop at which order_s.arrive() is issued
const int kReleasePipeCount = 10; // must be multiple of 2
order_s.wait();
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(tTMEM_LOADrS); i += 2) {
float2 in = make_float2(
tTMEM_LOADrS(i + 0),
tTMEM_LOADrS(i + 1)
);
float2 out;
cute::fma(out, scale_fp32x2, in, minus_row_max_scale_fp32x2);
tTMEM_LOADrS(i + 0) = out.x;
tTMEM_LOADrS(i + 1) = out.y;
tTMEM_LOADrS(i+0) = ::exp2f(tTMEM_LOADrS(i+0));
tTMEM_LOADrS(i+1) = ::exp2f(tTMEM_LOADrS(i+1));
// convert the two exponentials to Element and pack them into the staging registers
Array<ElementQK, kConversionsPerStep> in_conv;
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < kConversionsPerStep; j++) {
in_conv[j] = tTMEM_LOADrS(i + j);
}
tTMEM_STORErS_x4_e[i / kConversionsPerStep] = convert(in_conv);
// signal the ordering barrier shortly before the loop finishes
if (i == size(tTMEM_LOADrS) - kReleasePipeCount) {
order_s.arrive();
}
// this prevents register spills in fp16
// (when P spans two store chunks, the first chunk is written out early)
if constexpr (size<2>(tTMEM_STORErS_x4) == _2{}) {
if (i == size(tTMEM_LOADrS) - 6) {
copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, 0), tTMEM_STOREtS_x4(_, _, 0));
}
}
}
// tmem_store(reg_S8) -> op_P
// (stores the last chunk; the asserts restrict P to at most two chunks)
CUTE_STATIC_ASSERT_V(size<2>(tTMEM_STORErS_x4) <= _2{});
CUTE_STATIC_ASSERT_V(size<1>(tTMEM_STORErS_x4) == _1{});
copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1), tTMEM_STOREtS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1));
cutlass::arch::fence_view_async_tmem_store();
// notify tensor core warp that P is ready
pipeline_s.consumer_release(pipeline_s_consumer_state);
++pipeline_s_consumer_state;
pipeline_c.producer_acquire(pipeline_c_producer_state);
// Rescale the previous row sum to the new max. The factor 0.5 compensates
// for the rescaled sum being seeded into BOTH lanes of the float2
// accumulator below, whose lanes are added together at the end.
ElementQK acc_scale = 0.5f * ::exp2f(scale * (old_row_max - row_max_safe));
row_sum *= acc_scale;
// row_sum = sum(reg_S)
// four independent two-wide accumulators, combined after the loop
float2 local_row_sum_f32x2 = make_float2(row_sum, row_sum);
float2 local_row_sum_1 = make_float2(0, 0);
float2 local_row_sum_2 = make_float2(0, 0);
float2 local_row_sum_3 = make_float2(0, 0);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(tTMEM_LOADrS); i += 8) {
// row_sum += tTMEM_LOADrS(i);
float2 in = make_float2(tTMEM_LOADrS(i), tTMEM_LOADrS(i+1));
cute::add(local_row_sum_f32x2, local_row_sum_f32x2, in);
in = make_float2(tTMEM_LOADrS(i+2), tTMEM_LOADrS(i+2+1));
cute::add(local_row_sum_1, local_row_sum_1, in);
in = make_float2(tTMEM_LOADrS(i+4), tTMEM_LOADrS(i+4+1));
cute::add(local_row_sum_2, local_row_sum_2, in);
in = make_float2(tTMEM_LOADrS(i+6), tTMEM_LOADrS(i+6+1));
cute::add(local_row_sum_3, local_row_sum_3, in);
}
cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_1);
cute::add(local_row_sum_2, local_row_sum_2, local_row_sum_3);
cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_2);
float local_row_sum = local_row_sum_f32x2.x + local_row_sum_f32x2.y;
row_sum = local_row_sum;
if (final_call) {
// re-acquire the S part in the final step
pipeline_s.consumer_wait(pipeline_s_consumer_state);
// overwrite the stats slot with the final values; note that the raw
// row_max (not row_max_safe) is stored here
Tensor tTMEM_STOREVrS = make_tensor<ElementQK>(shape(tTMEM_STOREVcS));
tTMEM_STOREVrS(kIdxFinalRowMax) = row_max;
tTMEM_STOREVrS(kIdxFinalRowSum) = row_sum;
copy(tiled_tmem_storev, tTMEM_STOREVrS, tTMEM_STOREVtS);
}
}
// Softmax-group driver for one stage (0 or 1) of one block coordinate.
//
// Runs softmax_step once per K/V tile: first over the tiles reported by
// Mask{}.get_unmasked_trip_count (no masking applied), then over the tiles
// reported by Mask{}.get_masked_trip_count (masking applied). The running row
// max / row sum live in local variables here and are updated by each step.
// The last step overall is flagged as final_call so it publishes the final
// statistics.
template<class Stage, class BlkCoord, class ProblemShape>
CUTLASS_DEVICE auto
softmax(
Stage stage,
BlkCoord const& blk_coord,
Params const& params, ProblemShape const& problem_shape,
PipelineS& pipeline_s, typename PipelineS::PipelineState& pipeline_s_consumer_state,
PipelineC& pipeline_c, typename PipelineC::PipelineState& pipeline_c_producer_state,
OrderBarrierSoftmax& order_s) {
int mask_tile_count = Mask{}.get_unmasked_trip_count(blk_coord, TileShape{}, problem_shape);
// running statistics, start values for an empty row
ElementQK row_max = -INFINITY;
ElementQK row_sum = 0;
// Coordinate tensor of the S tile, shifted to this block's / stage's logical position.
Tensor cS_base = make_identity_tensor(select<0,1>(TileShapeQK{}));
auto logical_offset = make_coord(
get<0>(blk_coord) * get<0>(TileShape{}) + (stage % get<0>(ThreadShape{})) * get<0>(TileShapeQK{}),
0 + (stage % get<1>(ThreadShape{})) * get<1>(TileShapeQK{})
);
Tensor cS = domain_offset(logical_offset, cS_base);
// initial acquire; each softmax_step commits and then re-acquires pipeline_c
pipeline_c.producer_acquire(pipeline_c_producer_state);
// Unmasked iterations
CUTLASS_PRAGMA_NO_UNROLL
for (; mask_tile_count > 0; mask_tile_count -= 1) {
// final_call only if this is the last unmasked tile AND no masked tiles follow
softmax_step<false /* need_apply_mask */>(
row_max, row_sum, stage,
(mask_tile_count == 1) &&
(Mask{}.get_masked_trip_count(blk_coord, TileShape{}, problem_shape) == 0),
blk_coord, cS, params, problem_shape,
pipeline_s, pipeline_s_consumer_state,
pipeline_c, pipeline_c_producer_state,
order_s
);
// advance the coordinate tensor along mode 1 to the next tile of this stage
cS.data() = cS.data() + E<1>{} * get<1>(ThreadShape{}) * get<1>(TileShapeQK{});
}
// Masked iterations
mask_tile_count = Mask{}.get_masked_trip_count(blk_coord, TileShape{}, problem_shape);
CUTLASS_PRAGMA_NO_UNROLL
for (; mask_tile_count > 0; mask_tile_count -= 1) {
softmax_step<true /* need_apply_mask */>(
row_max, row_sum, stage, mask_tile_count == 1,
blk_coord, cS, params, problem_shape,
pipeline_s, pipeline_s_consumer_state,
pipeline_c, pipeline_c_producer_state,
order_s
);
cS.data() = cS.data() + E<1>{} * get<1>(ThreadShape{}) * get<1>(TileShapeQK{});
}
// commit the final statistics written by the last softmax_step
pipeline_c.producer_commit(pipeline_c_producer_state);
++pipeline_c_producer_state;
pipeline_c.producer_acquire(pipeline_c_producer_state);
// empty step to sync against pipe s
// (releases the extra pipeline_s wait taken in the final softmax_step)
pipeline_s.consumer_release(pipeline_s_consumer_state);
++pipeline_s_consumer_state;
}
// Final pass over one O accumulator: loads it from tmem in column chunks,
// multiplies by `scale`, converts to the output element type and writes the
// result to the smem output tile of the given stage.
//
// scale : factor applied to every element (skipped when ONLY_SOFTMAX is defined)
// stage : _0 selects tmem O0, _1 selects tmem O1; also selects sO_01(_,_,stage)
// sO_01 : smem output tensor with the stage as its last mode
//
// Ends with fence_view_async_shared(); the caller is responsible for the
// surrounding pipeline acquire / commit.
template<class Stage, class TensorO>
CUTLASS_DEVICE auto
correction_epilogue(
float scale,
Stage stage,
TensorO const& sO_01) {
using ElementOut = typename TensorO::value_type;
// thread index within this group of 4 warps
int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp);
Tensor sO = sO_01(_,_,stage);
// As opposed to the softmax, we do not have enough registers here
// to load all of the values (for tile kv = 128), so we loop
// good values would be either 32 or 64
const int kCorrectionTileSize = 32 / sizeof(ElementOut);
// the load atom is chosen to match the chunk width computed above
using TMEM_LOAD = std::conditional_t<kCorrectionTileSize == 32, SM100_TMEM_LOAD_32dp32b32x, SM100_TMEM_LOAD_32dp32b16x>; // 4x32 threads with 64 cols of 32b elem
typename CollectiveMmaPV::TiledMma mma;
Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{}));
Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{}));
Tensor tOcO = mma.get_slice(0).partition_C(cO);
Tensor tOsO = mma.get_slice(0).partition_C(sO);
// split the tmem, coordinate and smem views into (128 x kCorrectionTileSize) chunks
Tensor tOtO_i = logical_divide(tOtO, make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
Tensor tOcO_i = logical_divide(tOcO, make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
Tensor tOsO_i = logical_divide(tOsO, make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
// select the tmem accumulator belonging to this stage
if constexpr (decltype(stage == _0{})::value) {
tOtO_i.data() = tOtO_i.data().get() + uint32_t(TmemAllocation::O0);
}
else {
static_assert(decltype(stage == _1{})::value, "stage is either 0 or 1");
tOtO_i.data() = tOtO_i.data().get() + uint32_t(TmemAllocation::O1);
}
auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i(make_coord(_, _), _0{}));
auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx);
Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i(make_coord(_, _), _));
Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i(make_coord(_, _), _));
Tensor tTMEM_LOADsO = thr_tmem_load.partition_D(tOsO_i(make_coord(_, _), _));
float2 scale_f32x2 = make_float2(scale, scale);
// loop:
// TMEM_LOAD, FMUL2 scale, TMEM_STORE
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < get<2>(TileShape{}) / kCorrectionTileSize; i++) {
Tensor tTMEM_LOADtO_i = tTMEM_LOADtO(_, _0{}, _0{}, i);
Tensor tTMEM_LOADsO_i = tTMEM_LOADsO(_, _0{}, _0{}, i);
Tensor tTMrO = make_tensor<ElementPV>(shape(tTMEM_LOADcO(_, _0{}, _0{}, i)));
copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO);
#ifndef ONLY_SOFTMAX
// scale two elements at a time
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < size(tTMrO); j += 2) {
float2 in = make_float2(tTMrO(j), tTMrO(j+1));
float2 out;
cute::mul(out, scale_f32x2, in);
tTMrO(j) = out.x;
tTMrO(j+1) = out.y;
}
#endif
// convert in groups of N elements, N chosen so each group fills 4 bytes of output
constexpr int N = 4 / sizeof(ElementOut);
NumericArrayConverter<ElementOut, ElementPV, N> convert;
Tensor tSMrO = make_tensor_like<ElementOut>(tTMrO);
Tensor tCs = recast<decltype(convert)::source_type>(tTMrO);
Tensor tCd = recast<decltype(convert)::result_type>(tSMrO);
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < size(tCs); j++) {
tCd(j) = convert.convert(tCs(j));
}
// write the converted chunk to smem as 32-bit words
Tensor tSMsO_i = recast<uint32_t>(tTMEM_LOADsO_i);
Tensor tSMrO_i = recast<uint32_t>(tSMrO);
copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, tSMrO_i, tSMsO_i);
}
cutlass::arch::fence_view_async_shared();
}
// Rescales one O accumulator in place in tmem: every element is multiplied
// by `scale`.
//
// scale  : multiplicative factor applied to all elements
// tmem_O : tmem offset of the accumulator (callers pass TmemAllocation::O0 / O1)
//
// The work is done in chunks of kCorrectionTileSize columns, with the load of
// chunk i+1 issued before chunk i is scaled and stored back.
CUTLASS_DEVICE auto
correction_rescale(
float scale,
uint32_t tmem_O) {
// thread index within this group of 4 warps
int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp);
// As opposed to the softmax, we do not have enough registers here
// to load all of the values (for tile kv = 128), so we loop
// good values would be either 32 or 64
const int kCorrectionTileSize = 16;
using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b16x; // 4x32 threads with 64 cols of 32b elem
using TMEM_STORE = SM100_TMEM_STORE_32dp32b16x; // 4x32 threads with 64 cols of 32b elem
typename CollectiveMmaPV::TiledMma mma;
Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{}));
Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{}));
Tensor tOcO = mma.get_slice(0).partition_C(cO);
// views of a single (128 x kCorrectionTileSize) chunk
Tensor tOtO_i = tOtO.compose(make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
Tensor tOcO_i = tOcO.compose(make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
tOtO_i.data() = tOtO_i.data().get() + tmem_O;
auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i);
auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx);
auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tOtO_i);
auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx);
Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i);
Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i);
Tensor tTMEM_STOREtO = thr_tmem_store.partition_D(tOtO_i);
Tensor tTMEM_STOREcO = thr_tmem_store.partition_S(tOcO_i);
// load and store must see the same per-thread shape so registers can be reused
static_assert(shape(tTMEM_STOREcO) == shape(tTMEM_LOADcO));
float2 scale_f32x2 = make_float2(scale, scale);
// register buffer holding 128 / kCorrectionTileSize chunks
// NOTE(review): the chunk count here is hard-coded from 128, while the loop
// below runs get<2>(TileShape{}) / kCorrectionTileSize times -- these agree
// only if get<2>(TileShape{}) == 128; confirm for other tile shapes.
Tensor tTMrO = make_tensor<ElementPV>(make_shape(shape(tTMEM_LOADcO), Int<128 / kCorrectionTileSize>{}));
// loads chunk i from tmem (column offset i * kCorrectionTileSize) into tTMrO(_, i)
auto copy_in = [&](int i) {
Tensor tTMEM_LOADtO_i = tTMEM_LOADtO;
tTMEM_LOADtO_i.data() = tTMEM_LOADtO_i.data().get() + uint32_t(i * kCorrectionTileSize);
Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO)));
copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO_i);
};
// stores tTMrO(_, i) back to chunk i in tmem
auto copy_out = [&](int i) {
Tensor tTMEM_STOREtO_i = tTMEM_STOREtO;
tTMEM_STOREtO_i.data() = tTMEM_STOREtO_i.data().get() + uint32_t(i * kCorrectionTileSize);
Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO)));
copy(tiled_tmem_store, tTMrO_i, tTMEM_STOREtO_i);
};
// sequence: LLMSLMSLMSS
// loop:
// TMEM_LOAD, FMUL2 scale, TMEM_STORE
copy_in(0);
int count = get<2>(TileShape{}) / kCorrectionTileSize;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < count; i++) {
// issue the next chunk's load before working on the current one
if (i != count - 1) {
copy_in(i+1);
}
Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO)));
// scale two elements at a time
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < size(tTMrO_i); j += 2) {
float2 in = make_float2(tTMrO_i(j), tTMrO_i(j+1));
float2 out;
cute::mul(out, scale_f32x2, in);
tTMrO_i(j) = out.x;
tTMrO_i(j+1) = out.y;
}
copy_out(i);
}
}
// Correction-group main loop.
//
// For every K/V tile after the first, reads {old max, new max} published by
// each softmax stage and rescales the matching O accumulator by
// exp2(scale_softmax_log2 * (old_max - new_max)). After the last tile it
// reads the final {row sum, row max} of each stage, writes the normalised
// output to the epilogue smem buffer via correction_epilogue, and, if
// epilogue.params.ptr_LSE is non-null, stores the per-row log-sum-exp.
//
// Consumes pipeline_s0_c / pipeline_s1_c (from softmax) and pipeline_o (from
// mma); produces into pipeline_epi. All pipeline states are advanced in place.
// The order of wait / release calls is the synchronisation protocol with the
// other warp roles; do not reorder.
template<
class BlkCoord, class ProblemShape, class ParamsProblemShape,
class TensorStorageEpi, class CollectiveEpilogue
>
CUTLASS_DEVICE auto
correction(
BlkCoord const& blk_coord,
Params const& params, ProblemShape const& problem_shape,
ParamsProblemShape const& params_problem_shape,
TensorStorageEpi& shared_storage_epi,
PipelineC& pipeline_s0_c, typename PipelineC::PipelineState& pipeline_s0_c_consumer_state,
PipelineC& pipeline_s1_c, typename PipelineC::PipelineState& pipeline_s1_c_consumer_state,
PipelineO& pipeline_o, typename PipelineO::PipelineState& pipeline_o_consumer_state,
PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state,
CollectiveEpilogue& epilogue) {
int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape);
// thread index within this group of 4 warps
int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp);
// Views of the 2-wide stats slots (V0 / V1) written by softmax_step.
Tensor tStS = partition_fragment_C(typename CollectiveMmaQK::TiledMma{}, select<0,1>(TileShapeQK{}));
Tensor cS = make_identity_tensor(select<0,1>(TileShapeQK{}));
Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS);
Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{})));
Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{})));
using TMEM_LOAD_V = SM100_TMEM_LOAD_32dp32b2x; // 4x32 threads with 2 cols of 32b elem
auto tiled_tmem_loadv = make_tmem_copy(TMEM_LOAD_V{}, tStS_v);
auto thr_tmem_loadv = tiled_tmem_loadv.get_slice(thread_idx);
Tensor tTMEM_LOADVtS = thr_tmem_loadv.partition_S(tStS_v);
Tensor tTMEM_LOADVcS = thr_tmem_loadv.partition_D(tScS_v);
Tensor tTMEM_LOADVtS0 = tTMEM_LOADVtS;
tTMEM_LOADVtS0.data() = tTMEM_LOADVtS0.data().get() + uint32_t(TmemAllocation::V0);
Tensor tTMEM_LOADVtS1 = tTMEM_LOADVtS;
tTMEM_LOADVtS1.data() = tTMEM_LOADVtS1.data().get() + uint32_t(TmemAllocation::V1);
// ignore first signal from softmax as no correction is required
pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state);
pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state);
++pipeline_s0_c_consumer_state;
// the first stage-1 signal is waited on here but only released inside / after the loop
pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state);
// handle the last iteration differently (i.e. tmem_load/stsm for epi)
mask_tile_count -= 1;
CUTLASS_PRAGMA_NO_UNROLL
for (; mask_tile_count > 0; mask_tile_count -= 1) {
// ---- stage 0: rescale O0 ----
pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state);
Tensor tTMEM_LOADVrS = make_tensor<ElementQK>(shape(tTMEM_LOADVcS));
// read row_wise new global max
copy(tiled_tmem_loadv, tTMEM_LOADVtS0, tTMEM_LOADVrS);
// e^(scale * (old_max - new_max))
float scale = ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax)));
pipeline_o.consumer_wait(pipeline_o_consumer_state);
correction_rescale(scale, uint32_t(TmemAllocation::O0));
// releases the stage-1 signal that was waited on earlier
pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state);
++pipeline_s1_c_consumer_state;
cutlass::arch::fence_view_async_tmem_store();
pipeline_o.consumer_release(pipeline_o_consumer_state);
++pipeline_o_consumer_state;
// ---- stage 1: rescale O1 ----
pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state);
copy(tiled_tmem_loadv, tTMEM_LOADVtS1, tTMEM_LOADVrS);
scale = ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax)));
pipeline_o.consumer_wait(pipeline_o_consumer_state);
correction_rescale(scale, uint32_t(TmemAllocation::O1));
// releases the stage-0 signal waited on at the top of this iteration
pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state);
++pipeline_s0_c_consumer_state;
cutlass::arch::fence_view_async_tmem_store();
pipeline_o.consumer_release(pipeline_o_consumer_state);
++pipeline_o_consumer_state;
}
// release the outstanding stage-1 signal
pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state);
++pipeline_s1_c_consumer_state;
// do the final correction to O1
// better to somehow special-case it in the loop above
// doesn't matter for non-persistent code, but if it were
// persistent we do not want to release O too early
pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state);
// read from V0
// read row_sum and final row_max here
Tensor tTMEM_LOADVrS = make_tensor<ElementQK>(shape(tTMEM_LOADVcS));
copy(tiled_tmem_loadv, tTMEM_LOADVtS0, tTMEM_LOADVrS);
pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state);
++pipeline_s0_c_consumer_state;
pipeline_o.consumer_wait(pipeline_o_consumer_state);
pipeline_epi.producer_acquire(pipeline_epi_producer_state);
// store to epi smem
// loop:
// TMEM_LOAD
// FMUL2 scale = 1 / global_sum * out_quant_scale
// F2FP
// store to smem
Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{});
Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), select<0,3>(problem_shape), epilogue.params.dLSE);
// ---- stage 0 output: O0 * (scale_output / row_sum) -> sO(_,_,0) ----
correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _0{}, sO);
if (epilogue.params.ptr_LSE != nullptr) {
// row handled by this thread, relative to the start of the sequence
int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord);
// for variable-length problems, rows are offset by the cumulative length
// entry selected by get<2,1>(blk_coord)
int row_offset = 0;
if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)];
}
// lse = log(row_sum) + scale_softmax * row_max
ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax);
// skip rows beyond the end of the problem
if (row_idx < get<0>(problem_shape)) {
gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse;
}
}
cutlass::arch::fence_view_async_tmem_load();
pipeline_o.consumer_release(pipeline_o_consumer_state);
++pipeline_o_consumer_state;
pipeline_epi.producer_commit(pipeline_epi_producer_state);
++pipeline_epi_producer_state;
// ---- stage 1 output: O1 * (scale_output / row_sum) -> sO(_,_,1) ----
pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state);
// load from V1
copy(tiled_tmem_loadv, tTMEM_LOADVtS1, tTMEM_LOADVrS);
pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state);
++pipeline_s1_c_consumer_state;
pipeline_o.consumer_wait(pipeline_o_consumer_state);
pipeline_epi.producer_acquire(pipeline_epi_producer_state);
correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _1{}, sO);
if (epilogue.params.ptr_LSE != nullptr) {
// stage 1 rows start get<0>(TileShapeQK{}) rows after the stage 0 rows
int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord) + get<0>(TileShapeQK{});
ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax);
int row_offset = 0;
if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)];
}
if (row_idx < get<0>(problem_shape)) {
gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse;
}
}
cutlass::arch::fence_view_async_tmem_load();
pipeline_o.consumer_release(pipeline_o_consumer_state);
++pipeline_o_consumer_state;
pipeline_epi.producer_commit(pipeline_epi_producer_state);
++pipeline_epi_producer_state;
}
// Correction-group path for a block that has no K/V tiles to process.
//
// Writes an all-zero output tile to both stages of the epilogue smem buffer
// and, if epilogue.params.ptr_LSE is non-null, stores -INFINITY as the
// log-sum-exp of every in-range row. No tmem is touched.
//
// blk_coord            : block coordinate; get<0> selects the row tile,
//                        get<2> the LSE column, get<2,1> the variable-length
//                        sequence entry
// params               : unused here, kept for signature parity with correction()
// problem_shape        : get<0> bounds the rows that receive an LSE value
// params_problem_shape : supplies cumulative_length for variable-length problems
// shared_storage_epi   : epilogue smem; smem_o is the tile that gets zeroed
// pipeline_epi / state : producer side of the correction -> epilogue pipeline;
//                        two stages are produced and the state is advanced twice
// epilogue             : provides params.ptr_LSE / params.dLSE
//
// Each stage follows the same order as the non-empty path:
//   producer_acquire -> write smem -> fence_view_async_shared -> producer_commit
// so that a buffer is never written before it has been acquired, and the
// writes are fenced before the consumer is signalled.
template<
  class BlkCoord, class ProblemShape, class ParamsProblemShape,
  class TensorStorageEpi, class CollectiveEpilogue
>
CUTLASS_DEVICE auto
correction_empty(
    BlkCoord const& blk_coord,
    Params const& params, ProblemShape const& problem_shape,
    ParamsProblemShape const& params_problem_shape,
    TensorStorageEpi& shared_storage_epi,
    PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state,
    CollectiveEpilogue& epilogue) {

  // ---- stage 0: acquire before anything is written to smem ----
  pipeline_epi.producer_acquire(pipeline_epi_producer_state);

  Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{});
  Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), select<0,3>(problem_shape), epilogue.params.dLSE);
  // with no contributions, the log-sum-exp of a row is log(0) = -inf
  float lse = -INFINITY;
  // thread index within this group of 4 warps; one row per thread
  int thread_idx = threadIdx.x % (4 * NumThreadsPerWarp);

  // Copy that lets each thread write one row of the smem output tile in
  // 32-bit words.
  using ElementOut = typename CollectiveEpilogue::ElementOut;
  auto tiled_copy = make_cotiled_copy(
      Copy_Atom<UniversalCopy<uint32_t>, ElementOut>{},
      make_ordered_layout(make_shape(_128{}, Int<sizeof(uint32_t) / sizeof(ElementOut)>{}), Step<_1, _0>{}),
      sO.layout());

  auto thr_copy = tiled_copy.get_slice(thread_idx);
  auto tOgO = thr_copy.partition_D(sO);
  // zero-filled register fragment, reused as the source for both stages
  auto tOrO = make_tensor<ElementOut>(shape(tOgO(_,_,_,_0{})));
  clear(tOrO);

  // Stores `lse` for one row if LSE output was requested and the row is in range.
  auto store_lse = [&](int row_idx) {
    if (epilogue.params.ptr_LSE != nullptr) {
      // for variable-length problems, rows are offset by the cumulative
      // length entry selected by get<2,1>(blk_coord)
      int row_offset = 0;
      if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
        row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)];
      }
      if (row_idx < get<0>(problem_shape)) {
        gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse;
      }
    }
  };

  int const row_idx_stage0 = thread_idx + get<0>(TileShape{}) * get<0>(blk_coord);

  copy(tiled_copy, tOrO, tOgO(_,_,_,_0{}));
  store_lse(row_idx_stage0);

  // fence the smem writes before the consumer is signalled
  cutlass::arch::fence_view_async_shared();
  pipeline_epi.producer_commit(pipeline_epi_producer_state);
  ++pipeline_epi_producer_state;

  // ---- stage 1: same protocol; rows start get<0>(TileShapeQK{}) further down ----
  pipeline_epi.producer_acquire(pipeline_epi_producer_state);

  copy(tiled_copy, tOrO, tOgO(_,_,_,_1{}));
  store_lse(row_idx_stage0 + get<0>(TileShapeQK{}));

  cutlass::arch::fence_view_async_shared();
  pipeline_epi.producer_commit(pipeline_epi_producer_state);
  ++pipeline_epi_producer_state;
}
};
} // namespace cutlass::fmha::collective
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/arch/memory_sm80.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cute/tensor.hpp"
#include "cute/layout.hpp"
#include "collective/fmha_common.hpp"
#include "collective/fmha_fusion.hpp"
namespace cutlass::fmha::collective {
using namespace cute;
// TMA load collective: builds the TMA descriptors for Q, K and V on the host
// and, on the device, streams Q/K/V tiles from global to shared memory through
// two pipelines (one for Q, one shared by K and V).
//
// Variable-length handling (when the Q and/or KV mode of the problem shape is
// a variable-length type with a non-null cumulative_length):
//   host   - the batch stride is replaced by the row stride and the base
//            pointer is moved back by max_length rows;
//   device - the tensor is offset by (max_length - length) in the row mode and
//            by (cumulative_length[batch] + length) in the batch mode.
// The two offsets sum to max_length + cumulative_length[batch] rows, so after
// the host-side rewind the tile origin lands on row cumulative_length[batch].
// NOTE(review): presumably this places the end of each sequence on the TMA
// tensor bound so that out-of-range rows are clipped by TMA - confirm.
template<
  class Element,
  class StrideQ,
  class StrideK,
  class StrideV,
  class CollectiveMmaQK,
  class CollectiveMmaPV,
  class SmemLayoutQ,
  class SmemLayoutK,
  class SmemLayoutV,
  class TensorStorage,
  class PipelineQ,
  class PipelineKV,
  class Mask,
  class TileShape
>
struct Sm100FmhaLoadTmaWarpspecialized {

  using TileShapeQK = typename CollectiveMmaQK::TileShape;
  using TileShapePV = typename CollectiveMmaPV::TileShape;

  // Host-side arguments: global pointers and strides of Q, K and V.
  struct Arguments {
    const Element* ptr_Q;
    StrideQ dQ;
    const Element* ptr_K;
    StrideK dK;
    const Element* ptr_V;
    StrideV dV;
  };

  // Q is operand A of the QK GEMM, K is operand B of the QK GEMM,
  // V is operand B of the PV GEMM.
  using TMA_Q = typename CollectiveMmaQK::Params::TMA_A;
  using TMA_K = typename CollectiveMmaQK::Params::TMA_B;
  using TMA_V = typename CollectiveMmaPV::Params::TMA_B;

  // Device-side parameters: one TMA load object per operand.
  struct Params {
    TMA_Q tma_load_q;
    TMA_K tma_load_k;
    TMA_V tma_load_v;
  };

  // Converts Arguments into Params by delegating descriptor creation to the
  // two MMA collectives. `workspace` is not used.
  template<class ProblemShape>
  static Params to_underlying_arguments(
      ProblemShape const& problem_shape,
      Arguments const& args,
      void* workspace) {

    auto ptr_Q = args.ptr_Q;
    auto ptr_K = args.ptr_K;
    auto ptr_V = args.ptr_V;
    auto dQ = args.dQ;
    auto dK = args.dK;
    auto dV = args.dV;
    auto problem_shape_qk = problem_shape;

    if constexpr (is_variable_length_v<tuple_element_t<0, ProblemShape>>) {
      auto cumulative_length_q = get<0>(problem_shape).cumulative_length;
      if (cumulative_length_q != nullptr) {
        int max_length_q = get<0>(problem_shape).max_length;
        // for variable sequence length, the batch is in units of row_stride
        get<2,1>(dQ) = get<0>(dQ);
        get<3,1>(problem_shape_qk) = std::max(get<3,1>(problem_shape_qk), max_length_q * (1 + get<3,1>(problem_shape)));
        // offset ptr by the amount we add back in later; the product is formed
        // in 64 bits so that max_length * row_stride cannot wrap a 32-bit int
        ptr_Q -= static_cast<int64_t>(max_length_q) * get<0>(dQ);
      }
    }

    if constexpr (is_variable_length_v<tuple_element_t<1, ProblemShape>>) {
      auto cumulative_length_kv = get<1>(problem_shape).cumulative_length;
      if (cumulative_length_kv != nullptr) {
        int max_length_kv = get<1>(problem_shape).max_length;
        // for variable sequence length, the batch is in units of row_stride
        get<2,1>(dK) = get<0>(dK);
        get<2,1>(dV) = get<0>(dV);
        get<3,1>(problem_shape_qk) = std::max(get<3,1>(problem_shape_qk), max_length_kv * (1 + get<3,1>(problem_shape)));
        // offset ptr by the amount we add back in later (64-bit product, see above)
        ptr_K -= static_cast<int64_t>(max_length_kv) * get<0>(dK);
        ptr_V -= static_cast<int64_t>(max_length_kv) * get<0>(dV);
      }
    }

    auto params_qk = CollectiveMmaQK::to_underlying_arguments(
        problem_shape_qk,
        typename CollectiveMmaQK::Arguments {
          ptr_Q, dQ,
          ptr_K, dK,
        }, /*workspace=*/ nullptr);

    // The PV GEMM sees modes 1 and 2 of the QK problem swapped.
    auto problem_shape_pv = select<0,2,1,3>(problem_shape_qk);
    auto params_pv = CollectiveMmaPV::to_underlying_arguments(
        problem_shape_pv,
        typename CollectiveMmaPV::Arguments {
          ptr_K, dK,  // never used, dummy
          ptr_V, select<1,0,2>(dV),
        }, /*workspace=*/ nullptr);

    return Params{
      params_qk.tma_load_a,
      params_qk.tma_load_b,
      params_pv.tma_load_b
    };
  }

  // Prefetches the three TMA descriptors.
  CUTLASS_DEVICE
  static void prefetch_tma_descriptors(Params const& params) {
    cute::prefetch_tma_descriptor(params.tma_load_q.get_tma_descriptor());
    cute::prefetch_tma_descriptor(params.tma_load_k.get_tma_descriptor());
    cute::prefetch_tma_descriptor(params.tma_load_v.get_tma_descriptor());
  }

  // One producer step: acquire the next pipeline stage, let the elected lane
  // issue the TMA copy of tile `tile_index` into that stage, then advance the
  // producer state. Every calling lane takes part in the acquire.
  template<class Pipeline, class PipelineState, class TmaLoad, class GmemTiles, class SmemStages>
  CUTLASS_DEVICE static void
  issue_tma_load(
      Pipeline& pipeline, PipelineState& producer_state,
      uint32_t lane_predicate, TmaLoad const& tma_load,
      GmemTiles& gmem_tiles, int tile_index,
      SmemStages& smem_stages) {
    pipeline.producer_acquire(producer_state);
    if (lane_predicate) {
      auto tma_barrier = pipeline.producer_get_barrier(producer_state);
      copy(tma_load.with(*tma_barrier, 0), gmem_tiles(_, tile_index), smem_stages(_, producer_state.index()));
    }
    ++producer_state;
  }

  // Issues all loads for one block coordinate in the order
  //   Q1, K1, Q2, V1, K2, V2, K3, V3, ...
  // Q tiles go through pipeline_q, K and V tiles alternate on pipeline_kv.
  // Both producer states are advanced in place.
  template<class BlkCoord, class ProblemShape, class ParamsProblemShape>
  CUTLASS_DEVICE void
  load(
      BlkCoord const& blk_coord_in, ProblemShape const& problem_shape,
      Params const& params, ParamsProblemShape const& params_problem_shape,
      TensorStorage& storage,
      PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_producer_state,
      PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_producer_state) {

    BlkCoord blk_coord_q = blk_coord_in;
    BlkCoord blk_coord_kv = blk_coord_in;

    int mask_tile_count = Mask{}.get_trip_count(blk_coord_in, TileShape{}, problem_shape);

    using X = Underscore;

    // two pipes: Q and KV
    // from Memory (prod) to TensorCore (cons)

    // compute gQ, sQ
    // we load 2*get<0>(blk_coord), and 2*get<0>(blk_coord) + 1
    ThrMMA mma_qk = typename CollectiveMmaQK::TiledMma{}.get_slice(0);
    Tensor mQ_qdl_p = params.tma_load_q.get_tma_tensor(select<0,2,3>(problem_shape));

    int q_offs_0 = 0;
    int q_offs_2_1 = 0;

    if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
      auto cumulative_length_q = get<0>(params_problem_shape).cumulative_length;
      if (cumulative_length_q != nullptr) {
        int max_length_q = get<0>(params_problem_shape).max_length;
        q_offs_0 = max_length_q - get<0>(problem_shape);
        q_offs_2_1 = cumulative_length_q[get<2,1>(blk_coord_q)] + get<0>(problem_shape);
        // the batch is now addressed through the offset, not the coordinate
        get<2,1>(blk_coord_q) = 0;
      }
    }

    Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, q_offs_2_1)), mQ_qdl_p);

    Tensor gQ_qdl = local_tile(mQ_qdl, TileShapeQK{}, make_coord(_, _, _), Step<_1, X, _1>{});
    Tensor tSgQ_qdl = mma_qk.partition_A(gQ_qdl);
    Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{});
    auto [tQgQ_qdl, tQsQ] = tma_partition(
      params.tma_load_q, _0{}, make_layout(_1{}),
      group_modes<0,3>(sQ), group_modes<0,3>(tSgQ_qdl)
    );
    Tensor tQgQ = tQgQ_qdl(_, _, _0{}, get<2>(blk_coord_q));

    // compute gK, sK
    Tensor mK_kdl_p = params.tma_load_k.get_tma_tensor(select<1,2,3>(problem_shape));

    int kv_offs_0 = 0;
    int kv_offs_2_1 = 0;

    if constexpr (is_variable_length_v<tuple_element_t<1, ParamsProblemShape>>) {
      auto cumulative_length = get<1>(params_problem_shape).cumulative_length;
      if (cumulative_length != nullptr) {
        int max_length = get<1>(params_problem_shape).max_length;
        kv_offs_0 = max_length - get<1>(problem_shape);
        kv_offs_2_1 = cumulative_length[get<2,1>(blk_coord_kv)] + get<1>(problem_shape);
        get<2,1>(blk_coord_kv) = 0;
      }
    }

    Tensor mK_kdl = domain_offset(make_coord(kv_offs_0, _0{}, make_coord(_0{}, kv_offs_2_1)), mK_kdl_p);

    Tensor gK_kdl = local_tile(mK_kdl, TileShapeQK{}, make_coord(_, _, _), Step<X, _1, _1>{});
    Tensor tSgK_kdl = mma_qk.partition_B(gK_kdl);
    Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{});
    auto [tKgK_kdl, tKsK] = tma_partition(
      params.tma_load_k, _0{}, make_layout(_1{}),
      group_modes<0,3>(sK), group_modes<0,3>(tSgK_kdl)
    );
    Tensor tKgK = tKgK_kdl(_, _, _0{}, get<2>(blk_coord_kv));

    // compute gV, sV (V shares the KV offsets, applied to its mode 1)
    ThrMMA mma_pv = typename CollectiveMmaPV::TiledMma{}.get_slice(0);
    Tensor mV_dkl_p = params.tma_load_v.get_tma_tensor(select<2,1,3>(problem_shape));

    Tensor mV_dkl = domain_offset(make_coord(_0{}, kv_offs_0, make_coord(_0{}, kv_offs_2_1)), mV_dkl_p);

    Tensor gV_dkl = local_tile(mV_dkl, TileShapePV{}, make_coord(_, _, _), Step<X, _1, _1>{});
    Tensor tOgV_dkl = mma_pv.partition_B(gV_dkl);
    Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{});
    auto [tVgV_dkl, tVsV] = tma_partition(
      params.tma_load_v, _0{}, make_layout(_1{}),
      group_modes<0,3>(sV), group_modes<0,3>(tOgV_dkl)
    );
    auto tVgV = tVgV_dkl(_, _0{}, _, get<2>(blk_coord_kv));

    // blk_coord is decomposed in terms of TileShape, not TileShapeQK.
    // As such, it needs to be transformed as
    // (a,b,c): a -> 2*a (Q0) 2*a+1 (Q1)
    //          b -> 2*a (Ki i even) 2*a+1 (Ki i odd)

    // A single elected lane issues the TMA copies.
    uint32_t lane_predicate = cute::elect_one_sync();

    int q0_index = 2 * get<0>(blk_coord_q);
    int q1_index = 2 * get<0>(blk_coord_q) + 1;
    int k_index = 0;

    // Q1
    issue_tma_load(pipeline_q, pipeline_q_producer_state, lane_predicate,
                   params.tma_load_q, tQgQ, q0_index, tQsQ);
    // K1
    issue_tma_load(pipeline_kv, pipeline_kv_producer_state, lane_predicate,
                   params.tma_load_k, tKgK, k_index, tKsK);
    // Q2
    issue_tma_load(pipeline_q, pipeline_q_producer_state, lane_predicate,
                   params.tma_load_q, tQgQ, q1_index, tQsQ);
    // V1
    issue_tma_load(pipeline_kv, pipeline_kv_producer_state, lane_predicate,
                   params.tma_load_v, tVgV, k_index, tVsV);
    k_index += 1;

    // loop: the first KV tile has already been issued above
    mask_tile_count -= 1;

    for (; mask_tile_count > 0; mask_tile_count -= 1) {
      // Ki
      issue_tma_load(pipeline_kv, pipeline_kv_producer_state, lane_predicate,
                     params.tma_load_k, tKgK, k_index, tKsK);
      // Vi
      issue_tma_load(pipeline_kv, pipeline_kv_producer_state, lane_predicate,
                     params.tma_load_v, tVgV, k_index, tVsV);
      k_index += 1;
    }
  }
};
} // namespace cutlass::fmha::collective
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/arch/memory_sm80.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cute/arch/simd_sm100.hpp"
#include "cute/tensor.hpp"
#include "cute/layout.hpp"
#include "collective/fmha_common.hpp"
#include "collective/fmha_fusion.hpp"
#include "collective/sm100_fmha_mla_load_tma_warpspecialized.hpp"
#include "common/pipeline_mla.hpp"
namespace cutlass::fmha::collective {
using namespace cute;
template<
class Element_,
class ElementQK_,
class ElementPV_,
class ComposedTileShape_,
class StrideQ_,
class StrideK_,
class StrideV_,
class Mask_,
// shape here is QG K H
// and refers to the two softmax warps
// (2, 1, 1) means that they are stacked (best for large Q since it loads the least K/V)
// (1, 2, 1) means they sit side by side (best for small Q / large K)
class ThreadShape = Shape<_2, _1, _1>,
class OrderLoadEpilogue = cute::false_type
>
struct Sm100MlaFwdMainloopTmaWarpspecialized {
// Re-export the template parameters under their unsuffixed names.
using Element = Element_;
using ElementQK = ElementQK_;
using ElementPV = ElementPV_;
using ComposedTileShape = ComposedTileShape_;
using StrideQ = StrideQ_;
using StrideK = StrideK_;
using StrideV = StrideV_;
using Mask = Mask_;
// Shared-memory stage counts. K and V each get one stage and share a
// pipeline of StageCountKV = 2 stages.
static constexpr int StageCountQ = 2;
static constexpr int StageCountK = 1;
static constexpr int StageCountV = 1;
static constexpr int StageCountKV = StageCountK + StageCountV;
// Support StageCountKV > 2 in the future.
static_assert(StageCountK == 1 && StageCountV == 1, "Only support StageCountK = StageCountV = 1!");
static_assert(std::is_same_v<ThreadShape, Shape<_2, _1, _1>>, "Only support ThreadShape = Shape<_2, _1, _1>");
using ClusterShape = Shape<_1, _1, _1>;
// Number of elements in 128 bits; passed as the operand alignment to the builders below.
static const int Alignment = 128 / sizeof_bits_v<Element>;
// Mode 2 of ComposedTileShape has two sub-modes: <2,0> (latent) and <2,1> (rope).
// QK uses their sum as head dimension, PV only the latent part.
static constexpr auto HeadDimLatent = size<2, 0>(ComposedTileShape{});
static constexpr auto HeadDimRope = size<2, 1>(ComposedTileShape{});
static constexpr auto HeadDimQK = HeadDimLatent + HeadDimRope;
static constexpr auto HeadDimPV = HeadDimLatent;
// Per-GEMM tile shapes: the composed tile with mode 2 replaced by the
// respective head dimension, divided by ThreadShape; the PV tile additionally
// has modes 1 and 2 swapped.
using TileShapeQK = decltype(shape_div(replace<2>(ComposedTileShape{}, HeadDimQK), ThreadShape{}));
using TileShapePV = decltype(select<0,2,1>(shape_div(replace<2>(ComposedTileShape{}, HeadDimPV), ThreadShape{})));
// Undivided tile shape (mode 2 = HeadDimLatent); this is the shape handed to Mask::get_trip_count.
using TileShape = decltype(replace<2>(ComposedTileShape{}, HeadDimLatent));
// QK GEMM collective: A = Q, B = K, accumulating in ElementQK.
using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
Element, StrideQ, Alignment,
Element, StrideK, Alignment,
ElementQK,
TileShapeQK, ClusterShape, cutlass::gemm::collective::StageCount<3> /* we change it later anyways*/,
cutlass::gemm::KernelTmaWarpSpecialized1SmSm100>::CollectiveOp;
// PV GEMM collective: B = V (with strides permuted by <1,0,2>), accumulating in ElementPV.
using CollectiveMmaPV = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
// the stride for A does not matter since we do not load from smem at all
Element, StrideK, Alignment,
Element, decltype(select<1,0,2>(StrideV{})), Alignment,
ElementPV,
TileShapePV, ClusterShape, cutlass::gemm::collective::StageCount<3> /* we change it later anyways*/,
cutlass::gemm::KernelTmaWarpSpecialized1SmSm100>::CollectiveOp;
// Shared-memory layouts re-staged with this mainloop's own stage counts
// (the StageCount<3> given to the builders is a placeholder).
using SmemLayoutQ = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutA{}, Int<StageCountQ>{}));
using SmemLayoutK = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutB{}, Int<StageCountK>{}));
using SmemLayoutV = decltype(unstageSmemLayout(typename CollectiveMmaPV::SmemLayoutB{}, Int<StageCountV>{}));
// Layout built from TileShapePV with mode 2 set to 1.
using SmemStorageOneStageO = decltype(make_layout(replace<2>(TileShapePV{}, _1{})));
// Since the shared memory is not sufficient if we use separate Q, K, V, and O shared memory,
// we reuse shared memory for V and O to address this problem,
// and a barrier has been added to coordinate access to shared memory.
static constexpr bool IsOrderLoadEpilogue = std::is_same_v<OrderLoadEpilogue, cute::true_type>;
static const int NumWarpsEpilogue = 1;
static const int NumWarpsLoad = 1;
// Shared-memory storage used when IsOrderLoadEpilogue is true: in addition to
// Q/K/V it carries an extra V-sized buffer (smem_o), and the V buffer doubles
// as the second output stage.
struct TensorStorageQKVO {
cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_o; // use as O0
cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v; // use as V0 and O1
};
// Shared-memory storage used otherwise: one buffer each for Q, K and V.
struct TensorStorageQKV {
cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
};
// Storage type actually used by the mainloop, selected by IsOrderLoadEpilogue.
using TensorStorage = std::conditional_t<IsOrderLoadEpilogue, TensorStorageQKVO, TensorStorageQKV>;
// Offsets added to the base address of the TMEM tensors (see mma()).
// NOTE(review): presumably expressed in TMEM columns - confirm.
// Resulting values: S0 = V0 = 0, P0 = 32, S1 = V1 = 128, P1 = 160,
// O0 = 256, O1 = 384, kEnd = 512. V and P alias parts of their S region.
enum class TmemAllocation : uint32_t {
kSizeS = 128,
kSizeO = 128,
kSizeP = 32,
S0 = 0,
S1 = S0 + kSizeS,
V0 = S0, // stats storage from softmax to correction
V1 = S1,
P0 = S0 + kSizeP,
P1 = S1 + kSizeP,
O0 = S1 + kSizeS,
O1 = O0 + kSizeO,
kEnd = O1 + kSizeO
};
// indices for V0 / V1
// Two slots are shared by two naming schemes: slot 0 is either the old row
// max or the final row sum, slot 1 either the new row max or the final row max.
// NOTE(review): presumably the "Final" names apply to the last softmax step only - confirm.
enum : int {
kIdxOldRowMax = 0,
kIdxNewRowMax = 1,
kIdxFinalRowSum = 0,
kIdxFinalRowMax = 1
};
// from load to mma warp, protects q in smem
using PipelineQ = cutlass::PipelineTmaUmmaAsync<
StageCountQ,
typename CollectiveMmaQK::AtomThrShapeMNK
>;
// from load to mma warp, protects k/v in smem
// (a single pipeline with StageCountKV stages is shared by K and V)
using PipelineKV = cutlass::PipelineTmaAsyncMla<
StageCountKV,
typename CollectiveMmaQK::AtomThrShapeMNK
>;
// from mma to softmax0/1 warp, protects S in tmem
// (not sure yet about the reverse direction)
// there is one pipe per softmax warp, and the mma warp alternates between them
using PipelineS = cutlass::PipelineUmmaAsync<1>;
// from softmax0/1/ to correction wg
using PipelineC = cutlass::PipelineAsync<1>;
// from mma to correction
using PipelineO = cutlass::PipelineUmmaAsync<2>;
// from corr to epilogue
using PipelineE = cutlass::PipelineAsync<2>;
// Ordering barrier between the two softmax groups (1 stage, 2 groups).
using OrderBarrierSoftmax = cutlass::OrderedSequenceBarrier<
/*stages*/ 1, /*groups*/ 2>;
// Byte counts derived from the first three modes of each smem layout.
// NOTE(review): presumably the size of one pipeline stage, i.e. the TMA
// transaction size per load - confirm.
static constexpr int TransactionBytesLoadQ = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutQ{})) * cute::sizeof_bits_v<Element>);
static constexpr int TransactionBytesLoadK = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v<Element>);
static constexpr int TransactionBytesLoadV = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v<Element>);
// Load collective that owns the TMA descriptors and performs the global -> shared copies.
using Load = Sm100MlaFwdLoadTmaWarpspecialized<
Element, StrideQ, StrideK, StrideV,
CollectiveMmaQK, CollectiveMmaPV,
SmemLayoutQ, SmemLayoutK, SmemLayoutV,
TensorStorage, PipelineQ, PipelineKV, Mask, TileShape, OrderLoadEpilogue
>;
// Host-side arguments of the mainloop: the load collective's arguments plus
// the scaling factors that to_underlying_arguments() folds into Params.
struct Arguments {
typename Load::Arguments load;
// if zero, defaults to 1/sqrt(D)
// (D is the sum of the two sub-modes of mode 2 of the problem shape)
float scale_softmax = 0.0f;
// scaling factors to dequantize QKV
float scale_q = 1.0f;
float scale_k = 1.0f;
float scale_v = 1.0f;
// scaling factor to quantize O
float inv_scale_o = 1.0f;
};
// Device-side parameters produced by to_underlying_arguments().
struct Params {
// parameters of the load collective (TMA descriptors)
typename Load::Params load;
// scale_q * scale_k * softmax scale
float scale_softmax;
// scale_softmax multiplied by log2(e)
float scale_softmax_log2;
// scale_v * inv_scale_o
float scale_output;
};
// Reports whether the mainloop supports the given problem.
// No check is performed: every problem shape / argument set is accepted.
template<class ProblemShape>
static bool can_implement(ProblemShape const& problem_shape, Arguments const& args) {
return true;
}
// Builds the device-side Params from the host-side Arguments.
//
// A softmax scale of zero selects the default 1/sqrt(D), with D the sum of
// the two sub-modes of mode 2 of the problem shape. The Q and K dequantization
// factors are folded into both softmax scales, and the V dequantization /
// O quantization factors into the output scale. `workspace` is forwarded to
// the load collective.
template<class ProblemShape>
static Params to_underlying_arguments(
    ProblemShape const& problem_shape,
    Arguments const& args,
    void* workspace) {

  float const softmax_scale = (args.scale_softmax == 0.0f)
      ? 1.0f / (float) std::sqrt(get<2, 0>(problem_shape) + get<2, 1>(problem_shape))
      : args.scale_softmax;

  // log2(e), evaluated in double precision and then narrowed
  float const log2_of_e = static_cast<float>(std::log2(std::exp(1.0)));

  // combined dequantization factor of the QK product
  float const qk_dequant = args.scale_q * args.scale_k;

  return Params{
    Load::to_underlying_arguments(problem_shape, args.load, workspace),
    qk_dequant * softmax_scale,
    qk_dequant * log2_of_e * softmax_scale,
    args.scale_v * args.inv_scale_o
  };
}
// Prefetches the TMA descriptors by forwarding to the load collective.
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const& params) {
Load::prefetch_tma_descriptors(params.load);
}
// Load entry point of the mainloop: delegates to the Load collective, handing
// it the load-specific part of Params together with the shared storage and
// the producer sides of the Q and K/V pipelines. Both producer states are
// passed by reference and advanced by the callee.
template<class BlkCoord, class ProblemShape, class ParamsProblemShape>
CUTLASS_DEVICE void
load(
    BlkCoord const& blk_coord, ProblemShape const& problem_shape,
    Params const& params, ParamsProblemShape const& params_problem_shape,
    TensorStorage& storage,
    PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_producer_state,
    PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_producer_state) {

  Load loader;
  loader.load(
      blk_coord, problem_shape,
      params.load, params_problem_shape,
      storage,
      pipeline_q, pipeline_q_producer_state,
      pipeline_kv, pipeline_kv_producer_state);
}
// Tensor-core issue loop: waits for Q and K/V stages in shared memory and
// issues the Q*K and P*V GEMMs, alternating between the two softmax slots
// (S0/P0/O0 and S1/P1/O1 in TMEM).
//
// Comments below use 1-based names: "S1"/"P1"/"O1" are the first slot
// (tStS0 / tOrP0 / tOtO0), "S2"/"P2"/"O2" the second (tStS1 / tOrP1 / tOtO1).
//
// Issue order:
//   Q1*K1, Q2*K1, P1*V1,
//   then for every further KV tile i: Q1*Ki, P2*V(i-1), Q2*Ki, P1*Vi,
//   and finally P2*V(last).
//
// Parameters:
//   blk_coord, problem_shape - only used to obtain the KV trip count from Mask
//   params                   - not referenced in this function
//   storage                  - shared memory holding smem_q / smem_k / smem_v
//   pipeline_q, pipeline_kv  - consumer side; the wait states are advanced in
//                              place, releases use local copies of the entry states
//   pipeline_s0, pipeline_s1 - producer side guarding the two S accumulators
//   pipeline_corr            - producer side guarding the O accumulators
//
// The `if constexpr (get<1>(ThreadShape{}) > 1)` branches are compiled out
// because a static_assert in this struct fixes ThreadShape to Shape<_2, _1, _1>.
template<class BlkCoord, class ProblemShape>
CUTLASS_DEVICE auto
mma(
BlkCoord const& blk_coord,
Params const& params, ProblemShape const& problem_shape,
TensorStorage& storage,
PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_consumer_state,
PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_consumer_state,
PipelineS& pipeline_s0, typename PipelineS::PipelineState& pipeline_s0_producer_state,
PipelineS& pipeline_s1, typename PipelineS::PipelineState& pipeline_s1_producer_state,
PipelineO& pipeline_corr, typename PipelineO::PipelineState& pipeline_corr_producer_state) {
// Releases trail the waits, so they are tracked with separate state copies.
auto pipeline_q_release_state = pipeline_q_consumer_state;
auto pipeline_kv_release_state = pipeline_kv_consumer_state;
int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape);
typename CollectiveMmaQK::TiledMma mma_qk;
ThrMMA thr_mma_qk = mma_qk.get_slice(0);
typename CollectiveMmaPV::TiledMma mma_pv;
// NOTE(review): the "_ts" variant presumably takes operand A (P) from TMEM
// instead of shared memory - confirm against to_tiled_mma_sm100_ts.
TiledMMA mma_pv_ts = to_tiled_mma_sm100_ts(mma_pv);
ThrMMA thr_mma_pv = mma_pv_ts.get_slice(0);
Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{});
Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{});
Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{});
Tensor tSrQ = thr_mma_qk.make_fragment_A(sQ);
Tensor tSrK = thr_mma_qk.make_fragment_B(sK);
Tensor tOrV = thr_mma_pv.make_fragment_B(sV);
// tmem layout is
// S0 S1 O0 O1
// sequential in memory, where S overlaps with P and V
Tensor tStS = partition_fragment_C(mma_qk, select<0,1>(TileShapeQK{}));
Tensor tOtO = partition_fragment_C(mma_pv_ts, select<0,1>(TileShapePV{}));
// Place the two S and the two O accumulators at their TmemAllocation offsets.
Tensor tStS0 = tStS;
tStS0.data() = tStS.data().get() + uint32_t(TmemAllocation::S0);
Tensor tStS1 = tStS;
tStS1.data() = tStS.data().get() + uint32_t(TmemAllocation::S1);
Tensor tOtO0 = tOtO;
tOtO0.data() = tOtO.data().get() + uint32_t(TmemAllocation::O0);
Tensor tOtO1 = tOtO;
tOtO1.data() = tOtO.data().get() + uint32_t(TmemAllocation::O1);
// sP is built on a null pointer: it only provides the layout from which the
// A fragment of the PV GEMM is derived; its address is then set to P0 / P1.
Tensor sP = make_tensor(make_smem_ptr((Element*)nullptr), typename CollectiveMmaPV::SmemLayoutA{});
Tensor tOrP = thr_mma_pv.make_fragment_A(sP)(_, _, _, _0{}); // slice out staging
Tensor tOrP0 = tOrP;
tOrP0.data() = tOrP0.data().get() + uint32_t(TmemAllocation::P0);
Tensor tOrP1 = tOrP;
tOrP1.data() = tOrP1.data().get() + uint32_t(TmemAllocation::P1);
// k_index / v_index hold stage indices of the shared K/V pipeline. They are
// divided by 2 when indexing sK / sV.
// NOTE(review): presumably because K and V occupy alternating stages of a
// 2-stage pipeline while each buffer has a single stage - confirm.
int k_index = 0;
int v_index = 0;
int q_index = 0;
// wait for Q1
q_index = pipeline_q_consumer_state.index();
pipeline_q.consumer_wait(pipeline_q_consumer_state);
++pipeline_q_consumer_state;
Tensor tSrQ0 = tSrQ(_,_,_,q_index);
// wait for K1
k_index = pipeline_kv_consumer_state.index();
pipeline_kv.consumer_wait(pipeline_kv_consumer_state);
++pipeline_kv_consumer_state;
// gemm Q1 * K1 -> S1 (accumulator tStS0)
pipeline_s0.producer_acquire(pipeline_s0_producer_state);
gemm_zero_acc(mma_qk, tSrQ0, tSrK(_,_,_,k_index / 2), tStS0);
pipeline_s0.producer_commit(pipeline_s0_producer_state);
++pipeline_s0_producer_state;
// release K1 (compiled out for ThreadShape <2,1,1>; K1 is reused by Q2 below)
if constexpr (get<1>(ThreadShape{}) > 1) {
pipeline_kv.consumer_release(pipeline_kv_release_state);
++pipeline_kv_release_state;
}
// wait for Q2
if constexpr (get<0>(ThreadShape{}) > 1 || get<2>(ThreadShape{}) > 1) {
q_index = pipeline_q_consumer_state.index();
pipeline_q.consumer_wait(pipeline_q_consumer_state);
++pipeline_q_consumer_state;
}
Tensor tSrQ1 = tSrQ(_,_,_,q_index);
if constexpr (get<1>(ThreadShape{}) > 1) {
k_index = pipeline_kv_consumer_state.index();
pipeline_kv.consumer_wait(pipeline_kv_consumer_state);
++pipeline_kv_consumer_state;
}
pipeline_s1.producer_acquire(pipeline_s1_producer_state);
// gemm Q2 * K1 -> S2 (accumulator tStS1)
gemm_zero_acc(mma_qk, tSrQ1, tSrK(_,_,_,k_index / 2), tStS1);
pipeline_s1.producer_commit(pipeline_s1_producer_state);
++pipeline_s1_producer_state;
// release K1
pipeline_kv.consumer_release(pipeline_kv_release_state);
++pipeline_kv_release_state;
// wait for V1
v_index = pipeline_kv_consumer_state.index();
pipeline_kv.consumer_wait(pipeline_kv_consumer_state);
++pipeline_kv_consumer_state;
// this acquire returns the ownership of all of S0 to the mma warp
// including the P0 part
// acquire corr first to take it out of the critical
// path since softmax takes longer
pipeline_corr.producer_acquire(pipeline_corr_producer_state);
pipeline_s0.producer_acquire(pipeline_s0_producer_state);
// gemm P1 * V1 -> O1 (accumulator tOtO0)
gemm_zero_acc(mma_pv_ts, tOrP0, tOrV(_,_,_,v_index / 2), tOtO0);
pipeline_corr.producer_commit(pipeline_corr_producer_state);
++pipeline_corr_producer_state;
if constexpr (get<1>(ThreadShape{}) > 1) {
pipeline_kv.consumer_release(pipeline_kv_release_state);
++pipeline_kv_release_state;
}
// NOTE(review): resets the accumulate flag before the remaining PV GEMMs,
// which all go through gemm_reset_zero_acc; the exact interplay depends on
// that helper, which is defined elsewhere - confirm.
mma_pv_ts.accumulate_ = UMMA::ScaleOut::Zero;
// loop: the first KV tile has been handled above
mask_tile_count -= 1;
for (; mask_tile_count > 0; mask_tile_count -= 1) {
// wait for Ki
k_index = (pipeline_kv_consumer_state.index());
pipeline_kv.consumer_wait(pipeline_kv_consumer_state);
++pipeline_kv_consumer_state;
// gemm Q1 * Ki -> S1
// (no acquire here: S0 is still owned through the acquire made before the
// previous P1 * V GEMM)
gemm_zero_acc(mma_qk, tSrQ0, tSrK(_,_,_,k_index / 2), tStS0);
pipeline_s0.producer_commit(pipeline_s0_producer_state);
++pipeline_s0_producer_state;
if constexpr (get<1>(ThreadShape{}) > 1) {
pipeline_kv.consumer_release(pipeline_kv_release_state);
++pipeline_kv_release_state;
}
// gemm P2 * V(i-1) -> O2
// (v_index still refers to the V stage waited for in the previous step)
if constexpr (get<1>(ThreadShape{}) > 1) {
v_index = pipeline_kv_consumer_state.index();
pipeline_kv.consumer_wait(pipeline_kv_consumer_state);
++pipeline_kv_consumer_state;
}
pipeline_corr.producer_acquire(pipeline_corr_producer_state);
pipeline_s1.producer_acquire(pipeline_s1_producer_state);
gemm_reset_zero_acc(mma_pv_ts, tOrP1, tOrV(_,_,_,v_index / 2), tOtO1);
pipeline_corr.producer_commit(pipeline_corr_producer_state);
++pipeline_corr_producer_state;
// release V(i-1)
pipeline_kv.consumer_release(pipeline_kv_release_state);
++pipeline_kv_release_state;
if constexpr (get<1>(ThreadShape{}) > 1) {
k_index = (pipeline_kv_consumer_state.index());
pipeline_kv.consumer_wait(pipeline_kv_consumer_state);
++pipeline_kv_consumer_state;
}
// gemm Q2 * Ki -> S2
gemm_zero_acc(mma_qk, tSrQ1, tSrK(_,_,_,k_index / 2), tStS1);
pipeline_s1.producer_commit(pipeline_s1_producer_state);
++pipeline_s1_producer_state;
// release Ki
pipeline_kv.consumer_release(pipeline_kv_release_state);
++pipeline_kv_release_state;
// wait for Vi
v_index = (pipeline_kv_consumer_state.index());
pipeline_kv.consumer_wait(pipeline_kv_consumer_state);
++pipeline_kv_consumer_state;
// gemm P1 * Vi -> O1
pipeline_corr.producer_acquire(pipeline_corr_producer_state);
pipeline_s0.producer_acquire(pipeline_s0_producer_state);
gemm_reset_zero_acc(mma_pv_ts, tOrP0, tOrV(_,_,_,v_index / 2), tOtO0);
pipeline_corr.producer_commit(pipeline_corr_producer_state);
++pipeline_corr_producer_state;
if constexpr (get<1>(ThreadShape{}) > 1) {
pipeline_kv.consumer_release(pipeline_kv_release_state);
++pipeline_kv_release_state;
}
}
// release Q1
pipeline_q.consumer_release(pipeline_q_release_state);
++pipeline_q_release_state;
// release Q2
if constexpr (get<0>(ThreadShape{}) > 1) {
pipeline_q.consumer_release(pipeline_q_release_state);
++pipeline_q_release_state;
}
// wait for Vi
if constexpr (get<1>(ThreadShape{}) > 1) {
v_index = pipeline_kv_consumer_state.index();
pipeline_kv.consumer_wait(pipeline_kv_consumer_state);
++pipeline_kv_consumer_state;
}
// gemm P2 * Vi -> O2 (last KV tile)
pipeline_corr.producer_acquire(pipeline_corr_producer_state);
pipeline_s1.producer_acquire(pipeline_s1_producer_state);
gemm_reset_zero_acc(mma_pv_ts, tOrP1, tOrV(_,_,_,v_index / 2), tOtO1);
pipeline_corr.producer_commit(pipeline_corr_producer_state);
++pipeline_corr_producer_state;
// release Vi
pipeline_kv.consumer_release(pipeline_kv_release_state);
++pipeline_kv_release_state;
// The acquires on pipeline_s0 / pipeline_s1 made before the last two P * V
// GEMMs have no matching Q * K GEMM any more; commit them here so that every
// acquire is paired with a commit.
pipeline_s0.producer_commit(pipeline_s0_producer_state);
++pipeline_s0_producer_state;
pipeline_s1.producer_commit(pipeline_s1_producer_state);
++pipeline_s1_producer_state;
// T0 S00 B1, T0 S10 B1, T0 S00 B2, T0 S01 B1, T0 S10 B2, T0 S11 B1, T0 S01 B2, T1 S00 B1, T0 S11 B2, ...
// Q1 * K1 , Q2 * K1 , S11 * V1 , Q1 * K2 , S21 * V1 , Q2 * K2 , S12 * V2 , Q1 * K3 , S22 * V2 , ...
}
// Performs one online-softmax step on the S tile belonging to softmax stage
// `stage` (compared against _0{} below to pick the S/V/P tensor-memory slots).
//
// Per call, each thread:
//   1. waits on pipeline_s, loads its S values from tensor memory into
//      registers and, when need_mask && need_apply_mask, applies Mask::apply_mask;
//   2. folds the tile into the running row_max, writes (old max, new max) to the
//      V slot in tensor memory and commits pipeline_c;
//   3. computes exp2(scale * s - scale * row_max) in place, converts the values
//      to Element and stores them to the P slot, then releases pipeline_s;
//   4. rescales row_sum by exp2(scale * (old max - new max)) and accumulates the
//      exponentiated values into it.
// When final_call is true the function additionally waits on pipeline_s once
// more and writes the final row_max / row_sum into the V slot.
//
// row_max / row_sum are in/out running statistics carried across calls.
// blk_coord is accepted but not referenced in this function body.
template<bool need_mask, class Stage, class BlkCoord, class CoordTensor, class ProblemShape>
CUTLASS_DEVICE auto
softmax_step(
bool need_apply_mask,
float& row_max, float& row_sum,
Stage stage, bool final_call,
BlkCoord const& blk_coord, CoordTensor const& cS,
Params const& params, ProblemShape const& problem_shape,
PipelineS& pipeline_s, typename PipelineS::PipelineState& pipeline_s_consumer_state,
PipelineC& pipeline_c, typename PipelineC::PipelineState& pipeline_c_producer_state,
OrderBarrierSoftmax& order_s) {
// Three views of the same accumulator fragment, each re-based onto a different
// TmemAllocation slot: the full S tile, a 128x2 "V" view used for the row
// statistics, and a "P" view sized for the converted Element values.
Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS);
Tensor tStS = partition_fragment_C(typename CollectiveMmaQK::TiledMma{}, select<0,1>(TileShapeQK{}));
tStS.data() = uint32_t(stage == _0{} ? TmemAllocation::S0 : TmemAllocation::S1);
Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{})));
tStS_v.data() = uint32_t(stage == _0{} ? TmemAllocation::V0 : TmemAllocation::V1);
Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{})));
// Number of 32-bit columns that the Element-typed P tile occupies.
auto tilePlikeFP32 = size<1>(TileShapeQK{}) / Int<sizeof(float)>{} * Int<sizeof(Element)>{};
Tensor tStS_P = tStS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));
tStS_P.data() = warp_uniform(uint32_t(stage == _0{} ? TmemAllocation::P0 : TmemAllocation::P1));
Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));
// Each thread owns a single row
using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 128 cols of 32b elem
using TMEM_STORE = SM100_TMEM_STORE_32dp32b32x; // 4x32 threads with 128 cols of 8b elem
using TMEM_STORE_V = SM100_TMEM_STORE_32dp32b2x; // 4x32 threads with 2 cols of 32b elem
// Thread index within the group of 4 warps (128 threads) running this function.
int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp);
// Per-thread partitions for: S load, row-statistics store, P store.
auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tStS);
auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx);
Tensor tTMEM_LOADtS = thr_tmem_load.partition_S(tStS);
Tensor tTMEM_LOADcS = thr_tmem_load.partition_D(tScS);
auto tiled_tmem_storev = make_tmem_copy(TMEM_STORE_V{}, tStS_v);
auto thr_tmem_storev = tiled_tmem_storev.get_slice(thread_idx);
Tensor tTMEM_STOREVtS = thr_tmem_storev.partition_D(tStS_v);
Tensor tTMEM_STOREVcS = thr_tmem_storev.partition_S(tScS_v);
auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tStS_P);
auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx);
Tensor tTMEM_STOREtS_x4 = thr_tmem_store.partition_D(tStS_P);
tTMEM_STOREtS_x4.data() = warp_uniform(tTMEM_STOREtS_x4.data().get());
Tensor tTMEM_STOREcS = thr_tmem_store.partition_S(tScS_P);
// wait on tensor core pipe
pipeline_s.consumer_wait(pipeline_s_consumer_state);
// read all of S from tmem into reg mem
Tensor tTMEM_LOADrS = make_tensor<ElementQK>(shape(tTMEM_LOADcS));
copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS);
// Masking is compiled in only when need_mask is set, and applied only when the
// caller requests it for this tile.
if constexpr (need_mask) {
if(need_apply_mask) {
Mask{}.apply_mask(tTMEM_LOADrS, tTMEM_LOADcS, problem_shape);
}
}
// Remember the running max from before this tile; it is published below and
// used again to rescale row_sum.
ElementQK old_row_max = row_max;
{
// compute rowmax
// Four independent partial maxima, each seeded with the running max and
// combined after the loop.
float row_max_0 = row_max;
float row_max_1 = row_max;
float row_max_2 = row_max;
float row_max_3 = row_max;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(tTMEM_LOADrS); i += 4) {
row_max_0 = ::fmax(row_max_0, tTMEM_LOADrS(i));
row_max_1 = ::fmax(row_max_1, tTMEM_LOADrS(i+1));
row_max_2 = ::fmax(row_max_2, tTMEM_LOADrS(i+2));
row_max_3 = ::fmax(row_max_3, tTMEM_LOADrS(i+3));
}
row_max = ::fmax(row_max_0, row_max_1);
row_max = ::fmax(row_max, row_max_2);
row_max = ::fmax(row_max, row_max_3);
}
// Substitute 0 for a -INFINITY max so the subtraction below never forms
// (-inf) - (-inf).
ElementQK row_max_safe = row_max == -INFINITY ? 0 : row_max;
// Publish (old max, new max) to the V slot, then commit pipeline_c.
Tensor tTMEM_STOREVrS = make_tensor<ElementQK>(shape(tTMEM_STOREVcS));
tTMEM_STOREVrS(kIdxOldRowMax) = old_row_max;
tTMEM_STOREVrS(kIdxNewRowMax) = row_max_safe;
copy(tiled_tmem_storev, tTMEM_STOREVrS, tTMEM_STOREVtS);
pipeline_c.producer_commit(pipeline_c_producer_state);
++pipeline_c_producer_state;
// notify correction wg that they are ready (might need addtl ordering between S0 and S1 WG's)
ElementQK scale = params.scale_softmax_log2;
ElementQK row_max_scale = row_max_safe * scale;
// Packed operands for the two-wide fused multiply-add below.
float2 scale_fp32x2 = make_float2(scale, scale);
float2 minus_row_max_scale_fp32x2 = make_float2(-row_max_scale, -row_max_scale);
// Register staging for P: 32-bit words, also viewed as pairs of Element.
Tensor tTMEM_STORErS_x4 = make_tensor<uint32_t>(shape(tTMEM_STOREcS));
constexpr int kConversionsPerStep = 2;
Tensor tTMEM_STORErS_x4_e = recast<Array<Element, kConversionsPerStep>>(tTMEM_STORErS_x4);
NumericArrayConverter<Element, ElementQK, kConversionsPerStep> convert;
constexpr int kReleasePipeCount = 10; // must be multiple of 2
// order_s.wait() is paired with the order_s.arrive() issued inside the loop,
// kReleasePipeCount elements before the end.
order_s.wait();
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(tTMEM_LOADrS); i += 2) {
float2 in = make_float2(
tTMEM_LOADrS(i + 0),
tTMEM_LOADrS(i + 1)
);
float2 out;
// out = scale * in - scale * row_max_safe
cute::fma(out, scale_fp32x2, in, minus_row_max_scale_fp32x2);
tTMEM_LOADrS(i + 0) = out.x;
tTMEM_LOADrS(i + 1) = out.y;
// The exponentiated values overwrite S in registers; the row-sum loop further
// down reads them from the same tensor.
tTMEM_LOADrS(i+0) = ::exp2f(tTMEM_LOADrS(i+0));
tTMEM_LOADrS(i+1) = ::exp2f(tTMEM_LOADrS(i+1));
Array<ElementQK, kConversionsPerStep> in_conv;
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < kConversionsPerStep; j++) {
in_conv[j] = tTMEM_LOADrS(i + j);
}
tTMEM_STORErS_x4_e[i / kConversionsPerStep] = convert(in_conv);
if (i == size(tTMEM_LOADrS) - kReleasePipeCount) {
order_s.arrive();
}
// this prevents register spills in fp16
// When the P staging has two slices, the first one is stored to tensor memory
// early, from inside the loop.
if constexpr (size<2>(tTMEM_STORErS_x4) == _2{}) {
if (i == size(tTMEM_LOADrS) - 6) {
copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, 0), tTMEM_STOREtS_x4(_, _, 0));
}
}
}
// tmem_store(reg_S8) -> op_P
// Store the last (or only) P slice, then fence the tensor-memory stores before
// releasing pipeline_s.
CUTE_STATIC_ASSERT_V(size<2>(tTMEM_STORErS_x4) <= _2{});
CUTE_STATIC_ASSERT_V(size<1>(tTMEM_STORErS_x4) == _1{});
copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1), tTMEM_STOREtS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1));
cutlass::arch::fence_view_async_tmem_store();
// notify tensor core warp that P is ready
pipeline_s.consumer_release(pipeline_s_consumer_state);
++pipeline_s_consumer_state;
pipeline_c.producer_acquire(pipeline_c_producer_state);
// Rescale the running sum to the new max. The extra 0.5 compensates for
// row_sum being placed in BOTH lanes of local_row_sum_f32x2 below, whose lanes
// are added together at the end.
ElementQK acc_scale = 0.5f * ::exp2f(scale * (old_row_max - row_max_safe));
row_sum *= acc_scale;
// row_sum = sum(reg_S)
// Four independent float2 accumulators; only the first is seeded with row_sum.
float2 local_row_sum_f32x2 = make_float2(row_sum, row_sum);
float2 local_row_sum_1 = make_float2(0, 0);
float2 local_row_sum_2 = make_float2(0, 0);
float2 local_row_sum_3 = make_float2(0, 0);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(tTMEM_LOADrS); i += 8) {
// row_sum += tTMEM_LOADrS(i);
float2 in = make_float2(tTMEM_LOADrS(i), tTMEM_LOADrS(i+1));
cute::add(local_row_sum_f32x2, local_row_sum_f32x2, in);
in = make_float2(tTMEM_LOADrS(i+2), tTMEM_LOADrS(i+2+1));
cute::add(local_row_sum_1, local_row_sum_1, in);
in = make_float2(tTMEM_LOADrS(i+4), tTMEM_LOADrS(i+4+1));
cute::add(local_row_sum_2, local_row_sum_2, in);
in = make_float2(tTMEM_LOADrS(i+6), tTMEM_LOADrS(i+6+1));
cute::add(local_row_sum_3, local_row_sum_3, in);
}
// Combine the four accumulators, then the two lanes.
cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_1);
cute::add(local_row_sum_2, local_row_sum_2, local_row_sum_3);
cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_2);
float local_row_sum = local_row_sum_f32x2.x + local_row_sum_f32x2.y;
row_sum = local_row_sum;
if (final_call) {
// re-acquire the S part in the final step
pipeline_s.consumer_wait(pipeline_s_consumer_state);
// Publish the final statistics. Note that the raw row_max is written here
// (it may be -INFINITY), not row_max_safe.
Tensor tTMEM_STOREVrS = make_tensor<ElementQK>(shape(tTMEM_STOREVcS));
tTMEM_STOREVrS(kIdxFinalRowMax) = row_max;
tTMEM_STOREVrS(kIdxFinalRowSum) = row_sum;
copy(tiled_tmem_storev, tTMEM_STOREVrS, tTMEM_STOREVtS);
}
}
// Drives the online softmax for one softmax stage over all S tiles of a block.
//
// Calls softmax_step once per trip reported by Mask::get_trip_count, carrying the
// running row maximum and row sum across steps. Masking is requested for a step
// once the remaining-trip counter is at or below Mask::get_masked_trip_count,
// and the step with one trip remaining is flagged as the final call. After the
// loop it commits and re-acquires pipeline_c and performs one extra release on
// pipeline_s.
template<class Stage, class BlkCoord, class ProblemShape>
CUTLASS_DEVICE auto
softmax(
    Stage stage,
    BlkCoord const& blk_coord,
    Params const& params, ProblemShape const& problem_shape,
    PipelineS& pipeline_s, typename PipelineS::PipelineState& pipeline_s_consumer_state,
    PipelineC& pipeline_c, typename PipelineC::PipelineState& pipeline_c_producer_state,
    OrderBarrierSoftmax& order_s) {

  const int masked_steps = Mask{}.get_masked_trip_count(blk_coord, TileShape{}, problem_shape);
  const int all_steps = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape);

  // Running statistics, updated in place by every softmax_step call.
  ElementQK running_max = -INFINITY;
  ElementQK running_sum = 0;

  // Coordinate tensor for the S tile, shifted to this block's / stage's origin.
  // The leading `0 +` on the second coordinate is kept from the original; it
  // presumably keeps that component a runtime int so the loop can advance it.
  Tensor tile_coords = make_identity_tensor(select<0,1>(TileShapeQK{}));
  auto tile_origin = make_coord(
    get<0>(blk_coord) * get<0>(TileShape{}) + (stage % get<0>(ThreadShape{})) * get<0>(TileShapeQK{}),
    0 + (stage % get<1>(ThreadShape{})) * get<1>(TileShapeQK{})
  );
  Tensor cS = domain_offset(tile_origin, tile_coords);

  // Per-step advance of the coordinate tensor along mode 1.
  auto const kv_advance = E<1>{} * get<1>(ThreadShape{}) * get<1>(TileShapeQK{});

  pipeline_c.producer_acquire(pipeline_c_producer_state);

  constexpr bool kMaskEnabled = !std::is_same_v<Mask, NoMask>;

  CUTLASS_PRAGMA_NO_UNROLL
  for (int steps_left = all_steps; steps_left > 0; --steps_left) {
    bool const mask_this_step = steps_left <= masked_steps;
    bool const is_last_step = steps_left == 1;

    softmax_step<kMaskEnabled>(
      mask_this_step,
      running_max, running_sum, stage,
      is_last_step,
      blk_coord, cS, params, problem_shape,
      pipeline_s, pipeline_s_consumer_state,
      pipeline_c, pipeline_c_producer_state,
      order_s
    );

    cS.data() = cS.data() + kv_advance;
  }

  pipeline_c.producer_commit(pipeline_c_producer_state);
  ++pipeline_c_producer_state;
  pipeline_c.producer_acquire(pipeline_c_producer_state);

  // empty step to sync against pipe s
  pipeline_s.consumer_release(pipeline_s_consumer_state);
  ++pipeline_s_consumer_state;
}
// Final output pass for one O accumulator (`stage` must be _0 or _1).
//
// Loads the O accumulator from tensor memory in column tiles of
// kCorrectionTileSize, multiplies every value by `scale` (skipped when
// ONLY_SOFTMAX is defined), converts from ElementPV to the element type of
// sO_01, and writes the result into the `stage` slice of sO_01. Ends with
// fence_view_async_shared().
//
//   scale  - factor applied to every accumulator value
//   stage  - compile-time selector of the O0 / O1 tensor-memory allocation and
//            of the slice of sO_01 that is written
//   sO_01  - output tensor whose last mode is indexed by `stage`
template<class Stage, class TensorO>
CUTLASS_DEVICE auto
correction_epilogue(
float scale,
Stage stage,
TensorO const& sO_01) {
using ElementOut = typename TensorO::value_type;
// Thread index within the group of 4 warps (128 threads).
int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp);
Tensor sO = sO_01(_,_,stage);
// As opposed to the softmax, we do not have enough registers here
// to load all of the values (for tile kv = 128), so we loop
// good values would be either 32 or 64
// As written, the tile is 32 / sizeof(ElementOut) columns wide, and the 32x
// load atom is selected only when that value equals 32.
constexpr int kCorrectionTileSize = 32 / sizeof(ElementOut);
using TMEM_LOAD = std::conditional_t<kCorrectionTileSize == 32, SM100_TMEM_LOAD_32dp32b32x, SM100_TMEM_LOAD_32dp32b16x>; // 4x32 threads with 64 cols of 32b elem
typename CollectiveMmaPV::TiledMma mma;
// Accumulator fragment (tOtO), its coordinates (tOcO) and the matching
// partition of the output tensor (tOsO), each split into column tiles.
Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{}));
Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{}));
Tensor tOcO = mma.get_slice(0).partition_C(cO);
Tensor tOsO = mma.get_slice(0).partition_C(sO);
Tensor tOtO_i = logical_divide(tOtO, make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
Tensor tOcO_i = logical_divide(tOcO, make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
Tensor tOsO_i = logical_divide(tOsO, make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
// Re-base the accumulator view onto the O0 or O1 tensor-memory allocation.
if constexpr (decltype(stage == _0{})::value) {
tOtO_i.data() = tOtO_i.data().get() + uint32_t(TmemAllocation::O0);
}
else {
static_assert(decltype(stage == _1{})::value, "stage is either 0 or 1");
tOtO_i.data() = tOtO_i.data().get() + uint32_t(TmemAllocation::O1);
}
auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i(make_coord(_, _), _0{}));
auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx);
Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i(make_coord(_, _), _));
Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i(make_coord(_, _), _));
Tensor tTMEM_LOADsO = thr_tmem_load.partition_D(tOsO_i(make_coord(_, _), _));
// Packed scale for the two-wide multiply below.
float2 scale_f32x2 = make_float2(scale, scale);
// loop:
// TMEM_LOAD, FMUL2 scale, TMEM_STORE
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < get<2>(TileShape{}) / kCorrectionTileSize; i++) {
Tensor tTMEM_LOADtO_i = tTMEM_LOADtO(_, _0{}, _0{}, i);
Tensor tTMEM_LOADsO_i = tTMEM_LOADsO(_, _0{}, _0{}, i);
// Register staging for one column tile of the accumulator.
Tensor tTMrO = make_tensor<ElementPV>(shape(tTMEM_LOADcO(_, _0{}, _0{}, i)));
copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO);
#ifndef ONLY_SOFTMAX
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < size(tTMrO); j += 2) {
float2 in = make_float2(tTMrO(j), tTMrO(j+1));
float2 out;
cute::mul(out, scale_f32x2, in);
tTMrO(j) = out.x;
tTMrO(j+1) = out.y;
}
#endif
// Convert in groups of N elements, where N output elements fill 4 bytes.
constexpr int N = 4 / sizeof(ElementOut);
NumericArrayConverter<ElementOut, ElementPV, N> convert;
Tensor tSMrO = make_tensor_like<ElementOut>(tTMrO);
Tensor tCs = recast<decltype(convert)::source_type>(tTMrO);
Tensor tCd = recast<decltype(convert)::result_type>(tSMrO);
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < size(tCs); j++) {
tCd(j) = convert.convert(tCs(j));
}
// Copy the converted values out as 32-bit words, letting the copy vectorize
// under an assumed 128-bit alignment.
Tensor tSMsO_i = recast<uint32_t>(tTMEM_LOADsO_i);
Tensor tSMrO_i = recast<uint32_t>(tSMrO);
copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, tSMrO_i, tSMsO_i);
}
// Fence the writes to the output tensor before the caller signals the epilogue
// pipeline.
cutlass::arch::fence_view_async_shared();
}
// Rescales an O accumulator in place in tensor memory.
//
// Every value of the accumulator located at tensor-memory offset `tmem_O` is
// multiplied by `scale`. The accumulator is processed in column tiles of
// kCorrectionTileSize: each tile is loaded into registers, multiplied and stored
// back, with the load of tile i+1 issued before tile i is multiplied.
//
//   scale   - factor applied to every accumulator value
//   tmem_O  - tensor-memory offset of the accumulator (callers pass
//             TmemAllocation::O0 / O1)
CUTLASS_DEVICE auto
correction_rescale(
float scale,
uint32_t tmem_O) {
// Thread index within the group of 4 warps (128 threads).
int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp);
// As opposed to the softmax, we do not have enough registers here
// to load all of the values (for tile kv = 128), so we loop
// good values would be either 32 or 64
constexpr int kCorrectionTileSize = 16;
using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b16x; // 4x32 threads with 64 cols of 32b elem
using TMEM_STORE = SM100_TMEM_STORE_32dp32b16x; // 4x32 threads with 64 cols of 32b elem
typename CollectiveMmaPV::TiledMma mma;
// View of the first 128 x kCorrectionTileSize column tile of the accumulator,
// re-based onto tmem_O; later tiles are reached by offsetting the data pointer.
Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{}));
Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{}));
Tensor tOcO = mma.get_slice(0).partition_C(cO);
Tensor tOtO_i = tOtO.compose(make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
Tensor tOcO_i = tOcO.compose(make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
tOtO_i.data() = tOtO_i.data().get() + tmem_O;
auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i);
auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx);
auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tOtO_i);
auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx);
Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i);
Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i);
Tensor tTMEM_STOREtO = thr_tmem_store.partition_D(tOtO_i);
Tensor tTMEM_STOREcO = thr_tmem_store.partition_S(tOcO_i);
// The load and store partitions must have the same per-thread shape because
// both use the same register staging tensor.
static_assert(shape(tTMEM_STOREcO) == shape(tTMEM_LOADcO));
float2 scale_f32x2 = make_float2(scale, scale);
// Register staging with one slice per column tile.
// NOTE(review): this holds 128 / kCorrectionTileSize slices while the loop
// below runs get<2>(TileShape{}) / kCorrectionTileSize iterations — this
// assumes get<2>(TileShape{}) <= 128; confirm.
Tensor tTMrO = make_tensor<ElementPV>(make_shape(shape(tTMEM_LOADcO), Int<128 / kCorrectionTileSize>{}));
// Loads column tile i from tensor memory into register slice i.
auto copy_in = [&](int i) {
Tensor tTMEM_LOADtO_i = tTMEM_LOADtO;
tTMEM_LOADtO_i.data() = tTMEM_LOADtO_i.data().get() + uint32_t(i * kCorrectionTileSize);
Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO)));
copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO_i);
};
// Stores register slice i back to column tile i in tensor memory.
auto copy_out = [&](int i) {
Tensor tTMEM_STOREtO_i = tTMEM_STOREtO;
tTMEM_STOREtO_i.data() = tTMEM_STOREtO_i.data().get() + uint32_t(i * kCorrectionTileSize);
Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO)));
copy(tiled_tmem_store, tTMrO_i, tTMEM_STOREtO_i);
};
// sequence: LLMSLMSLMSS
// loop:
// TMEM_LOAD, FMUL2 scale, TMEM_STORE
// Prime the pipeline with the first tile.
copy_in(0);
constexpr int count = get<2>(TileShape{}) / kCorrectionTileSize;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < count; i++) {
// Issue the next tile's load before multiplying the current one.
if (i != count - 1) {
copy_in(i+1);
}
Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO)));
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < size(tTMrO_i); j += 2) {
float2 in = make_float2(tTMrO_i(j), tTMrO_i(j+1));
float2 out;
cute::mul(out, scale_f32x2, in);
tTMrO_i(j) = out.x;
tTMrO_i(j+1) = out.y;
}
copy_out(i);
}
}
// Correction routine: rescales the two O accumulators while the softmax stages
// update their row maxima, then produces the final output tiles and the
// optional log-sum-exp (LSE) values.
//
// Flow, as implemented below:
//   1. The first signal of each softmax stage is consumed without rescaling.
//   2. For each remaining trip, the (old max, new max) pair is read from the V0
//      and then the V1 tensor-memory slot, and O0 / O1 are rescaled by
//      exp2(scale_softmax_log2 * (old_max - new_max)).
//   3. The final row statistics are read from V0 and then V1; each O accumulator
//      is written out through correction_epilogue with
//      scale_output / final_row_sum, and, when epilogue.params.ptr_LSE is
//      non-null, LSE = log(row_sum) + scale_softmax * row_max is stored for
//      in-range rows.
//
// blk_coord / problem_shape select the tile and bound the rows written;
// params_problem_shape supplies the cumulative row offset in the
// variable-length case.
template<
class BlkCoord, class ProblemShape, class ParamsProblemShape,
class TensorStorageEpi, class CollectiveEpilogue
>
CUTLASS_DEVICE auto
correction(
BlkCoord const& blk_coord,
Params const& params, ProblemShape const& problem_shape,
ParamsProblemShape const& params_problem_shape,
TensorStorageEpi& shared_storage_epi,
PipelineC& pipeline_s0_c, typename PipelineC::PipelineState& pipeline_s0_c_consumer_state,
PipelineC& pipeline_s1_c, typename PipelineC::PipelineState& pipeline_s1_c_consumer_state,
PipelineO& pipeline_o, typename PipelineO::PipelineState& pipeline_o_consumer_state,
PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state,
CollectiveEpilogue& epilogue) {
int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape);
// Thread index within the group of 4 warps (128 threads).
int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp);
// 128x2 view of the S fragment used to read the per-row statistics, with one
// re-based copy per softmax stage (V0 / V1).
Tensor tStS = partition_fragment_C(typename CollectiveMmaQK::TiledMma{}, select<0,1>(TileShapeQK{}));
Tensor cS = make_identity_tensor(select<0,1>(TileShapeQK{}));
Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS);
Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{})));
Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{})));
using TMEM_LOAD_V = SM100_TMEM_LOAD_32dp32b2x; // 4x32 threads with 2 cols of 32b elem
auto tiled_tmem_loadv = make_tmem_copy(TMEM_LOAD_V{}, tStS_v);
auto thr_tmem_loadv = tiled_tmem_loadv.get_slice(thread_idx);
Tensor tTMEM_LOADVtS = thr_tmem_loadv.partition_S(tStS_v);
Tensor tTMEM_LOADVcS = thr_tmem_loadv.partition_D(tScS_v);
Tensor tTMEM_LOADVtS0 = tTMEM_LOADVtS;
tTMEM_LOADVtS0.data() = tTMEM_LOADVtS0.data().get() + uint32_t(TmemAllocation::V0);
Tensor tTMEM_LOADVtS1 = tTMEM_LOADVtS;
tTMEM_LOADVtS1.data() = tTMEM_LOADVtS1.data().get() + uint32_t(TmemAllocation::V1);
// ignore first signal from softmax as no correction is required
pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state);
pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state);
++pipeline_s0_c_consumer_state;
// The first stage-1 signal is waited on here but released inside the loop (or
// right after it when the loop does not run).
pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state);
// handle the last iteration differently (i.e. tmem_load/stsm for epi)
mask_tile_count -= 1;
CUTLASS_PRAGMA_NO_UNROLL
for (; mask_tile_count > 0; mask_tile_count -= 1) {
pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state);
Tensor tTMEM_LOADVrS = make_tensor<ElementQK>(shape(tTMEM_LOADVcS));
// read row_wise new global max
copy(tiled_tmem_loadv, tTMEM_LOADVtS0, tTMEM_LOADVrS);
// e^(scale * (old_max - new_max)
float scale = ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax)));
pipeline_o.consumer_wait(pipeline_o_consumer_state);
correction_rescale(scale, uint32_t(TmemAllocation::O0));
// Note the crossed order: the pending stage-1 signal is released after O0 has
// been rescaled, and the stage-0 signal after O1 has been rescaled.
pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state);
++pipeline_s1_c_consumer_state;
cutlass::arch::fence_view_async_tmem_store();
pipeline_o.consumer_release(pipeline_o_consumer_state);
++pipeline_o_consumer_state;
pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state);
copy(tiled_tmem_loadv, tTMEM_LOADVtS1, tTMEM_LOADVrS);
scale = ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax)));
pipeline_o.consumer_wait(pipeline_o_consumer_state);
correction_rescale(scale, uint32_t(TmemAllocation::O1));
pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state);
++pipeline_s0_c_consumer_state;
cutlass::arch::fence_view_async_tmem_store();
pipeline_o.consumer_release(pipeline_o_consumer_state);
++pipeline_o_consumer_state;
}
// Release the stage-1 signal that is still pending from before / inside the loop.
pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state);
++pipeline_s1_c_consumer_state;
// do the final correction to O1
// better to somehow special-case it in the loop above
// doesn't matter for non-persistent code, but if it were
// persistent we do not want to release O too early
pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state);
// read from V0
// read row_sum and final row_max here
Tensor tTMEM_LOADVrS = make_tensor<ElementQK>(shape(tTMEM_LOADVcS));
copy(tiled_tmem_loadv, tTMEM_LOADVtS0, tTMEM_LOADVrS);
pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state);
++pipeline_s0_c_consumer_state;
pipeline_o.consumer_wait(pipeline_o_consumer_state);
pipeline_epi.producer_acquire(pipeline_epi_producer_state);
// store to epi smem
// loop:
// TMEM_LOAD
// FMUL2 scale = 1 / global_sum * out_quant_scale
// F2FP
// store to smem
Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{});
Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), select<0,3>(problem_shape), epilogue.params.dLSE);
// Stage 0 output: normalize by the final row sum and apply scale_output.
correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _0{}, sO);
if (epilogue.params.ptr_LSE != nullptr) {
// Row index of this thread within the problem: coordinate within the tile
// plus the block's row origin.
int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord);
int row_offset = 0;
// Variable-length case: add the cumulative row offset read from
// params_problem_shape for this block's get<2,1> coordinate.
if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)];
}
// LSE = log(row_sum) + scale_softmax * row_max
ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax);
// Rows at or beyond the problem's row extent are not written.
if (row_idx < get<0>(problem_shape)) {
gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse;
}
}
cutlass::arch::fence_view_async_tmem_load();
pipeline_o.consumer_release(pipeline_o_consumer_state);
++pipeline_o_consumer_state;
pipeline_epi.producer_commit(pipeline_epi_producer_state);
++pipeline_epi_producer_state;
pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state);
// load from V1
copy(tiled_tmem_loadv, tTMEM_LOADVtS1, tTMEM_LOADVrS);
pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state);
++pipeline_s1_c_consumer_state;
pipeline_o.consumer_wait(pipeline_o_consumer_state);
pipeline_epi.producer_acquire(pipeline_epi_producer_state);
// Stage 1 output: same as stage 0, with rows shifted by get<0>(TileShapeQK{}).
correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _1{}, sO);
if (epilogue.params.ptr_LSE != nullptr) {
int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord) + get<0>(TileShapeQK{});
ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax);
int row_offset = 0;
if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)];
}
if (row_idx < get<0>(problem_shape)) {
gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse;
}
}
cutlass::arch::fence_view_async_tmem_load();
pipeline_o.consumer_release(pipeline_o_consumer_state);
++pipeline_o_consumer_state;
pipeline_epi.producer_commit(pipeline_epi_producer_state);
++pipeline_epi_producer_state;
}
// Correction path for a block that has no tiles to process.
//
// Writes zeros into both stages of the epilogue output buffer (smem_o) and,
// when epilogue.params.ptr_LSE is non-null, stores -INFINITY as the LSE of
// every in-range row covered by this block. Each of the two stages is handed
// to the epilogue through pipeline_epi with the sequence
//   producer_acquire -> write smem -> fence_view_async_shared -> producer_commit
// so that the buffer is only written while it is owned by this producer and
// the writes are fenced before the stage is committed.
//
// `params` is accepted for signature parity with `correction` and is unused.
template<
  class BlkCoord, class ProblemShape, class ParamsProblemShape,
  class TensorStorageEpi, class CollectiveEpilogue
>
CUTLASS_DEVICE auto
correction_empty(
    BlkCoord const& blk_coord,
    Params const& params, ProblemShape const& problem_shape,
    ParamsProblemShape const& params_problem_shape,
    TensorStorageEpi& shared_storage_epi,
    PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state,
    CollectiveEpilogue& epilogue) {

  pipeline_epi.producer_acquire(pipeline_epi_producer_state);

  Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{});
  Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), select<0,3>(problem_shape), epilogue.params.dLSE);
  float lse = -INFINITY;
  // Thread index within the group of 4 warps (128 threads).
  int thread_idx = threadIdx.x % (4 * NumThreadsPerWarp);

  // Copy that writes one 32-bit word of ElementOut values per thread, tiled over sO.
  using ElementOut = typename CollectiveEpilogue::ElementOut;
  auto tiled_copy = make_cotiled_copy(
    Copy_Atom<UniversalCopy<uint32_t>, ElementOut>{},
    make_ordered_layout(make_shape(_128{}, Int<sizeof(uint32_t) / sizeof(ElementOut)>{}), Step<_1, _0>{}),
    sO.layout());
  auto thr_copy = tiled_copy.get_slice(thread_idx);
  auto tOgO = thr_copy.partition_D(sO);

  // Zero-filled register fragment reused as the source for both stages.
  auto tOrO = make_tensor<ElementOut>(shape(tOgO(_,_,_,_0{})));
  clear(tOrO);

  // Stores `lse` for row `row_idx` of this block if LSE output is enabled and
  // the row lies inside the problem; in the variable-length case the row is
  // shifted by the cumulative offset read from params_problem_shape.
  auto store_lse = [&](int row_idx) {
    if (epilogue.params.ptr_LSE == nullptr) {
      return;
    }
    int row_offset = 0;
    if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
      row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)];
    }
    if (row_idx < get<0>(problem_shape)) {
      gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse;
    }
  };

  // Stage 0 (already acquired above).
  copy(tiled_copy, tOrO, tOgO(_,_,_,_0{}));
  store_lse(thread_idx + get<0>(TileShape{}) * get<0>(blk_coord));
  // Fence the smem writes before committing the stage, as the second stage and
  // correction_epilogue do.
  cutlass::arch::fence_view_async_shared();
  pipeline_epi.producer_commit(pipeline_epi_producer_state);
  ++pipeline_epi_producer_state;

  // Stage 1: acquire first, then write, so the buffer is not modified before
  // this producer owns it.
  pipeline_epi.producer_acquire(pipeline_epi_producer_state);
  copy(tiled_copy, tOrO, tOgO(_,_,_,_1{}));
  store_lse(thread_idx + get<0>(TileShape{}) * get<0>(blk_coord) + get<0>(TileShapeQK{}));
  cutlass::arch::fence_view_async_shared();
  pipeline_epi.producer_commit(pipeline_epi_producer_state);
  ++pipeline_epi_producer_state;
}
};
} // namespace cutlass::fmha::collective
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/arch/memory_sm80.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cute/tensor.hpp"
#include "cute/layout.hpp"
#include "collective/fmha_common.hpp"
#include "collective/fmha_fusion.hpp"
namespace cutlass::fmha::collective {
using namespace cute;
template<
class Element,
class StrideQ,
class StrideK,
class StrideV,
class CollectiveMmaQK,
class CollectiveMmaPV,
class SmemLayoutQ,
class SmemLayoutK,
class SmemLayoutV,
class TensorStorage,
class PipelineQ,
class PipelineKV,
class Mask,
class TileShape,
class OrderLoadEpilogue = cute::false_type
>
struct Sm100MlaFwdLoadTmaWarpspecialized {
using TileShapeQK = typename CollectiveMmaQK::TileShape;
using TileShapePV = typename CollectiveMmaPV::TileShape;
static constexpr int TransactionBytesLoadK = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v<Element>);
static constexpr int TransactionBytesLoadV = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v<Element>);
static const int NumWarpsEpilogue = 1;
static const int NumWarpsLoad = 1;
// User-facing arguments of the load collective: base pointers to the Q, K and V
// operands and their strides. Consumed by to_underlying_arguments, which
// forwards them to the CollectiveMmaQK / CollectiveMmaPV argument structs.
struct Arguments {
// Q operand base pointer and stride.
const Element* ptr_Q;
StrideQ dQ;
// K operand base pointer and stride.
const Element* ptr_K;
StrideK dK;
// V operand base pointer and stride.
const Element* ptr_V;
StrideV dV;
};
using TMA_Q = typename CollectiveMmaQK::Params::TMA_A;
using TMA_K = typename CollectiveMmaQK::Params::TMA_B;
using TMA_V = typename CollectiveMmaPV::Params::TMA_B;
// Kernel-side parameters: the TMA load objects for Q, K and V. They are taken
// from the A / B operands of CollectiveMmaQK and the B operand of
// CollectiveMmaPV (see the TMA_Q / TMA_K / TMA_V aliases above).
struct Params {
TMA_Q tma_load_q;
TMA_K tma_load_k;
TMA_V tma_load_v;
};
// Converts user Arguments into the TMA load Params for Q, K and V.
//
// The QK problem shape is derived from `problem_shape` by replacing mode 2 with
// the sum of its two sub-modes. For variable-length Q and/or KV with a non-null
// cumulative_length, the batch stride is set to the row stride, the batch
// extent is enlarged to cover max_length * (1 + batch), and the base pointer is
// moved back by max_length rows (that offset is added back later, per the
// comments below). The resulting shapes, pointers and strides are handed to
// CollectiveMmaQK / CollectiveMmaPV::to_underlying_arguments, and their TMA
// objects are returned.
//
// `workspace` is not used by this function.
template<class ProblemShape>
static Params to_underlying_arguments(
    ProblemShape const& problem_shape,
    Arguments const& args,
    void* workspace) {

  // Mutable local copies: the variable-length branches patch them in place.
  auto stride_q = args.dQ;
  auto stride_k = args.dK;
  auto stride_v = args.dV;
  auto base_q = args.ptr_Q;
  auto base_k = args.ptr_K;
  auto base_v = args.ptr_V;

  auto shape_qk = replace<2>(problem_shape, get<2, 0>(problem_shape) + get<2, 1>(problem_shape));

  if constexpr (is_variable_length_v<tuple_element_t<0, ProblemShape>>) {
    auto cu_len_q = get<0>(problem_shape).cumulative_length;
    if (cu_len_q != nullptr) {
      int max_len_q = get<0>(problem_shape).max_length;
      // offset ptr by the amount we add back in later
      base_q -= max_len_q * get<0>(stride_q);
      // for variable sequence length, the batch is in units of row_stride
      get<2,1>(stride_q) = get<0>(stride_q);
      auto const padded_extent = max_len_q * (1 + get<3,1>(problem_shape));
      get<3,1>(shape_qk) = std::max(get<3,1>(shape_qk), padded_extent);
    }
  }

  if constexpr (is_variable_length_v<tuple_element_t<1, ProblemShape>>) {
    auto cu_len_kv = get<1>(problem_shape).cumulative_length;
    if (cu_len_kv != nullptr) {
      int max_len_kv = get<1>(problem_shape).max_length;
      // offset ptr by the amount we add back in later
      base_k -= max_len_kv * get<0>(stride_k);
      base_v -= max_len_kv * get<0>(stride_v);
      // for variable sequence length, the batch is in units of row_stride
      get<2,1>(stride_k) = get<0>(stride_k);
      get<2,1>(stride_v) = get<0>(stride_v);
      auto const padded_extent = max_len_kv * (1 + get<3,1>(problem_shape));
      get<3,1>(shape_qk) = std::max(get<3,1>(shape_qk), padded_extent);
    }
  }

  // PV problem shape: QK shape with modes 1 and 2 swapped, then mode 1 replaced
  // by the first sub-mode of the original mode 2.
  auto shape_pv = replace<1>(select<0,2,1,3>(shape_qk), get<2, 0>(problem_shape));

  typename CollectiveMmaQK::Arguments mma_args_qk {
    base_q, stride_q,
    base_k, stride_k,
  };
  auto params_qk = CollectiveMmaQK::to_underlying_arguments(
      shape_qk, mma_args_qk, /*workspace=*/ nullptr);

  typename CollectiveMmaPV::Arguments mma_args_pv {
    base_k, stride_k,  // never used, dummy
    base_v, select<1,0,2>(stride_v),
  };
  auto params_pv = CollectiveMmaPV::to_underlying_arguments(
      shape_pv, mma_args_pv, /*workspace=*/ nullptr);

  return Params{
    params_qk.tma_load_a,
    params_qk.tma_load_b,
    params_pv.tma_load_b
  };
}
// Issues a descriptor prefetch for each of the three TMA load objects held in
// `params`, in the order Q, K, V.
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const& params) {
  auto const& load_q = params.tma_load_q;
  cute::prefetch_tma_descriptor(load_q.get_tma_descriptor());

  auto const& load_k = params.tma_load_k;
  cute::prefetch_tma_descriptor(load_k.get_tma_descriptor());

  auto const& load_v = params.tma_load_v;
  cute::prefetch_tma_descriptor(load_v.get_tma_descriptor());
}
template<class BlkCoord, class ProblemShape, class ParamsProblemShape>
CUTLASS_DEVICE void
load(
BlkCoord const& blk_coord_in, ProblemShape const& problem_shape,
Params const& params, ParamsProblemShape const& params_problem_shape,
TensorStorage& storage,
PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_producer_state,
PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_producer_state) {
BlkCoord blk_coord_q = blk_coord_in;
BlkCoord blk_coord_kv = blk_coord_in;
auto problem_shape_qk = replace<2>(problem_shape, get<2, 0>(problem_shape) + get<2, 1>(problem_shape));
auto problem_shape_v = replace<2>(problem_shape, get<2, 0>(problem_shape));
int mask_tile_count = Mask{}.get_trip_count(blk_coord_in, TileShape{}, problem_shape);
using X = Underscore;
// this one is only executed by one thread, no need to elect_one
// Q1, K1, Q2, V1, K2, V2, K3, V3, ...
// two pipes: Q and KV
// from Memory (prod) to TensorCore (cons)
// compute gQ, sQ
// we load 2*get<0>(blk_coord), and 2*get<0>(blk_coord) + 1
ThrMMA mma_qk = typename CollectiveMmaQK::TiledMma{}.get_slice(0);
Tensor mQ_qdl_p = params.tma_load_q.get_tma_tensor(select<0,2,3>(problem_shape_qk));
int q_offs_0 = 0;
int q_offs_2_1 = 0;
if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
auto cumulative_length_q = get<0>(params_problem_shape).cumulative_length;
if (cumulative_length_q != nullptr) {
int max_length_q = get<0>(params_problem_shape).max_length;
q_offs_0 = max_length_q - get<0>(problem_shape);
q_offs_2_1 = cumulative_length_q[get<2,1>(blk_coord_q)] + get<0>(problem_shape);
get<2,1>(blk_coord_q) = 0;
}
}
Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, q_offs_2_1)), mQ_qdl_p);
Tensor gQ_qdl = local_tile(mQ_qdl, TileShapeQK{}, make_coord(_, _, _), Step<_1, X, _1>{});
Tensor tSgQ_qdl = mma_qk.partition_A(gQ_qdl);
Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{});
auto [tQgQ_qdl, tQsQ] = tma_partition(
params.tma_load_q, _0{}, make_layout(_1{}),
group_modes<0,3>(sQ), group_modes<0,3>(tSgQ_qdl)
);
Tensor tQgQ = tQgQ_qdl(_, _, _0{}, get<2>(blk_coord_q));
// compute gK, sK
Tensor mK_kdl_p = params.tma_load_k.get_tma_tensor(select<1,2,3>(problem_shape_qk));
int kv_offs_0 = 0;
int kv_offs_2_1 = 0;
if constexpr (is_variable_length_v<tuple_element_t<1, ParamsProblemShape>>) {
auto cumulative_length = get<1>(params_problem_shape).cumulative_length;
if (cumulative_length != nullptr) {
int max_length = get<1>(params_problem_shape).max_length;
kv_offs_0 = max_length - get<1>(problem_shape);
kv_offs_2_1 = cumulative_length[get<2,1>(blk_coord_kv)] + get<1>(problem_shape);
get<2,1>(blk_coord_kv) = 0;
}
}
Tensor mK_kdl = domain_offset(make_coord(kv_offs_0, _0{}, make_coord(_0{}, kv_offs_2_1)), mK_kdl_p);
Tensor gK_kdl = local_tile(mK_kdl, TileShapeQK{}, make_coord(_, _, _), Step<X, _1, _1>{});
Tensor tSgK_kdl = mma_qk.partition_B(gK_kdl);
Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{});
auto [tKgK_kdl, tKsK] = tma_partition(
params.tma_load_k, _0{}, make_layout(_1{}),
group_modes<0,3>(sK), group_modes<0,3>(tSgK_kdl)
);
Tensor tKgK = tKgK_kdl(_, _, _0{}, get<2>(blk_coord_kv));
// compute gV, sV
ThrMMA mma_pv = typename CollectiveMmaPV::TiledMma{}.get_slice(0);
Tensor mV_dkl_p = params.tma_load_v.get_tma_tensor(select<2,1,3>(problem_shape_v));
Tensor mV_dkl = domain_offset(make_coord(_0{}, kv_offs_0, make_coord(_0{}, kv_offs_2_1)), mV_dkl_p);
Tensor gV_dkl = local_tile(mV_dkl, TileShapePV{}, make_coord(_, _, _), Step<X, _1, _1>{});
Tensor tOgV_dkl = mma_pv.partition_B(gV_dkl);
Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{});
auto [tVgV_dkl, tVsV] = tma_partition(
params.tma_load_v, _0{}, make_layout(_1{}),
group_modes<0,3>(sV), group_modes<0,3>(tOgV_dkl)
);
auto tVgV = tVgV_dkl(_, _0{}, _, get<2>(blk_coord_kv));
// blk_coord is decomposed in terms of TileShape, not TileShapeQK
// As such, it needs to be transformed as
// (a,b,c): a -> 2*a (Q0) 2*a+1 (Q1)
// b -> 2*a (Ki i even) 2*a+1 (Ki i odd)
uint32_t lane_predicate = cute::elect_one_sync();
// Q1
int q0_index = 2 * get<0>(blk_coord_q);
int q1_index = 2 * get<0>(blk_coord_q) + 1;
pipeline_q.producer_acquire(pipeline_q_producer_state);
if (lane_predicate) {
auto tma_barrier = pipeline_q.producer_get_barrier(pipeline_q_producer_state);
copy(params.tma_load_q.with(*tma_barrier, 0), tQgQ(_, q0_index), tQsQ(_, pipeline_q_producer_state.index()));
}
++pipeline_q_producer_state;
// K1
int k_index = 0;
pipeline_kv.producer_acquire(pipeline_kv_producer_state);
if (lane_predicate) {
auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state);
copy(params.tma_load_k.with(*tma_barrier, 0), tKgK(_, k_index), tKsK(_, pipeline_kv_producer_state.index() / 2));
}
++pipeline_kv_producer_state;
// Q2
pipeline_q.producer_acquire(pipeline_q_producer_state);
if (lane_predicate) {
auto tma_barrier = pipeline_q.producer_get_barrier(pipeline_q_producer_state);
copy(params.tma_load_q.with(*tma_barrier, 0), tQgQ(_, q1_index), tQsQ(_, pipeline_q_producer_state.index()));
}
++pipeline_q_producer_state;
if constexpr (cute::is_same_v<OrderLoadEpilogue, cute::true_type>) {
cutlass::arch::NamedBarrier::sync((NumWarpsLoad + NumWarpsEpilogue) * NumThreadsPerWarp,
cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
}
// V1
pipeline_kv.producer_acquire_bytes(pipeline_kv_producer_state, TransactionBytesLoadV);
if (lane_predicate) {
auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state);
copy(params.tma_load_v.with(*tma_barrier, 0), tVgV(_, k_index), tVsV(_, pipeline_kv_producer_state.index() / 2));
}
++pipeline_kv_producer_state;
k_index += 1;
// loop:
mask_tile_count -= 1;
for (; mask_tile_count > 0; mask_tile_count -= 1) {
// Ki
pipeline_kv.producer_acquire(pipeline_kv_producer_state);
if (lane_predicate) {
auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state);
copy(params.tma_load_k.with(*tma_barrier, 0), tKgK(_, k_index), tKsK(_, pipeline_kv_producer_state.index() / 2));
// prefetch vi
cute::prefetch(params.tma_load_v, tVgV(_, k_index));
}
++pipeline_kv_producer_state;
// Vi
pipeline_kv.producer_acquire_bytes(pipeline_kv_producer_state, TransactionBytesLoadV);
if (lane_predicate) {
auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state);
copy(params.tma_load_v.with(*tma_barrier, 0), tVgV(_, k_index), tVsV(_, pipeline_kv_producer_state.index() / 2));
// prefetch ki+1
if(mask_tile_count > 1) {
cute::prefetch(params.tma_load_k, tKgK(_, k_index + 1));
}
}
++pipeline_kv_producer_state;
k_index += 1;
}
}
};
} // namespace cutlass::fmha::collective
/***************************************************************************************************
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cute/layout.hpp"
#include "cute/tensor.hpp"
#include "cute/util/print.hpp"
namespace example {
using namespace cute;
/// Empty tag type used to disable gather/scatter for a GEMM argument.
/// It can be constructed from any argument list; every argument is ignored.
struct NoGather {
  template <class... Args>
  NoGather(Args...) {}
};
/// Gather functor backed by an index table: position i maps to indices_[i].
template <class Index>
struct IndexedGather {
  /// Table read on every call; the constructor's default argument leaves it null.
  Index const* indices_;

  CUTE_HOST_DEVICE constexpr
  IndexedGather(Index const* indices = {}) : indices_(indices) {}

  /// Returns the table entry stored at position `pos`.
  template <class Pos>
  CUTE_HOST_DEVICE constexpr
  Index
  operator()(Pos pos) const {
    return indices_[pos];
  }

  /// Prints the fixed tag "Indexed"; the table contents are not printed.
  CUTE_HOST_DEVICE friend
  void
  print(IndexedGather const&) {
    cute::print("Indexed");
  }
};
/// Gather functor that multiplies its argument by a fixed stride.
/// Example: StridedGather<_2> gathers every other row/column.
template <class Stride>
struct StridedGather {
  /// Factor applied to every index.
  Stride stride_;

  CUTE_HOST_DEVICE constexpr
  StridedGather(Stride stride = {}) : stride_(stride) {}

  /// Returns `idx * stride_`.
  template <class Idx>
  CUTE_HOST_DEVICE constexpr
  auto
  operator()(Idx idx) const {
    return idx * stride_;
  }

  /// Prints the functor as "Strided{<stride>}".
  CUTE_HOST_DEVICE friend
  void
  print(StridedGather const& gather) {
    cute::print("Strided{");
    print(gather.stride_);
    cute::print("}");
  }
};
/// Custom stride object that applies a function followed by a stride
// Multiplying an index i with this object (in either operand order) yields
// func_(i) * stride_, as implemented by the two operator* overloads below.
template <class Func, class Stride>
struct CustomStride
{
CUTE_HOST_DEVICE constexpr
CustomStride(Func const &func, Stride const &stride): func_(func), stride_(stride) {}
// index * stride: map the index through func_, then scale by stride_
template <class I>
CUTE_HOST_DEVICE constexpr friend
auto
operator*(I i, CustomStride const &s) { return s.func_(i) * s.stride_; }
// stride * index: same result as the overload above with the operands swapped
template <class I>
CUTE_HOST_DEVICE constexpr friend
auto
operator*(CustomStride const &s, I i) { return s.func_(i) * s.stride_; }
// Prints the object as "Custom{<func>,<stride>}" using the unqualified print
// overloads found for func_ and stride_
CUTE_HOST_DEVICE friend
void
print(CustomStride const & s) {
cute::print("Custom{");
print(s.func_);
cute::print(",");
print(s.stride_);
cute::print("}");
}
// Divides only the wrapped stride by div (through safe_div) and keeps func_
// unchanged; the result is a CustomStride over the divided stride type
template<class Div>
CUTE_HOST_DEVICE constexpr friend
auto
safe_div(CustomStride const &s, Div const &div)
{
return CustomStride<Func, decltype(safe_div(s.stride_, div))>(s.func_, safe_div(s.stride_, div));
}
// Circumvent the requirement on make_layout that shape and stride are integral
// Builds Layout<Shape, CustomStride> directly from the given shape and stride
template <class Shape>
CUTE_HOST_DEVICE constexpr friend
auto
make_layout(Shape const &shape, CustomStride const &stride)
{
return Layout<Shape, CustomStride>(shape, stride);
}
// Function applied to the index before the multiplication
Func func_;
// Stride that multiplies the function's result
Stride stride_;
};
template<class Stride, class Func>
CUTLASS_HOST_DEVICE
auto
make_custom_stride_layout(Stride const &stride, Func&& func)
{
  // Locate the first stride mode that is not the compile-time constant 1.
  auto first_non_unit = find_if(stride, [](auto mode){ return not is_constant<1, decltype(mode)>{}; });
  constexpr int GatherMode = decltype(first_non_unit)::value;

  // Wrap that mode's stride so `func` is applied to the index before the stride.
  auto gather_stride = CustomStride{static_cast<Func&&>(func), get<GatherMode>(stride)};

  // The shape is a placeholder of all ones with the same profile as `stride`.
  auto unit_shape = repeat_like(stride, _1{});
  return make_layout(unit_shape, replace<GatherMode>(stride, gather_stride));
}
/// Builds a tensor over `iter`, optionally routed through a gather function.
/// When `Func` (after removing cv/ref qualifiers) is NoGather the result is the
/// plain tensor make_tensor(iter, shape, stride); otherwise the tensor uses a
/// composed layout: identity layout of `shape` -> zero offset -> gather layout.
template<class Iterator, class Shape, class Stride, class Func>
CUTLASS_HOST_DEVICE
auto
make_gather_tensor(Iterator iter, Shape const &shape, Stride const &stride, Func &&func)
{
  constexpr bool kGatherDisabled =
      cutlass::platform::is_same<remove_cvref_t<Func>, NoGather>::value;

  if constexpr (kGatherDisabled) {
    return make_tensor(iter, shape, stride);
  } else {
    // Inner layout: maps a coordinate of `shape` to itself.
    auto matrix_layout = make_identity_layout(shape);
    // Offset: all zeros, one entry per mode of `shape`.
    auto offset = as_arithmetic_tuple(repeat_like(shape, _0{}));
    // Outer layout: applies `func` through a CustomStride.
    auto gather_layout = make_custom_stride_layout(stride, static_cast<Func&&>(func));
    return make_tensor(iter, ComposedLayout{gather_layout, offset, matrix_layout});
  }
}
} // namespace example
namespace cute
{
// Upcast by N restricted to basis mode I.
//  - Tuple shapes: recurse element-wise through transform_layout.
//  - Scaled-basis strides: when the stride's mode() equals I, both shape and
//    stride are ceil-divided by N; any other basis mode is returned unchanged.
//  - Everything else: defer to the two-argument upcast<N>(shape, stride).
template<int N, int I, class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
upcast(Shape const& shape, Stride const& stride)
{
if constexpr (is_tuple<Shape>::value) {
return transform_layout(shape, stride, [](auto const& s, auto const& d) { return upcast<N,I>(s,d); });
} else if constexpr (is_scaled_basis<Stride>::value) {
if constexpr (Stride::mode() == I) {
return make_layout(ceil_div(shape, Int<N>{}), ceil_div(stride, Int<N>{}));
} else {
return make_layout(shape, stride);
}
} else {
return upcast<N>(shape, stride);
}
// Every branch above returns; this marker only follows them
CUTE_GCC_UNREACHABLE;
}
// Upcast by N for a composed layout (outer layout, offset, inner layout).
// The outer layout is upcast as a whole; the offset and the inner layout are
// only adjusted along mode I, the first outer mode whose stride is the
// constant 1. The three pieces are then recombined with composition().
template <int N, class OuterShape, class OuterStride, class Offset, class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
upcast(ComposedLayout<Layout<OuterShape,OuterStride>,Offset,Layout<Shape,Stride>> const& layout)
{
// Find index of the stride-1 mode - that is the only one that requires updating inner shape and offset
auto idx = find_if(layout.layout_a().stride(), [](auto x){ return is_constant<1, decltype(x)>{}; });
constexpr int I = decltype(idx)::value;
// Upcast the outer layout (works as expected)
auto outer = upcast<N>(layout.layout_a());
// Upcast the accumulated offset along stride-1 mode
auto offset = as_arithmetic_tuple(replace<I>(layout.offset(), upcast<N>(get<I>(layout.offset()))));
// Upcast the inner layout's shape along stride-1 mode
auto inner = upcast<N,I>(layout.layout_b().shape(), layout.layout_b().stride());
return composition(outer, offset, inner);
}
} // namespace cute
/***************************************************************************************************
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cuda_runtime.h"
#include <iostream>
/**
 * Panic wrapper for unwinding CUTLASS errors
 *
 * Evaluates `status` exactly once. If it is not cutlass::Status::kSuccess, the
 * status string and the current line number are written to std::cerr and the
 * process exits with EXIT_FAILURE.
 *
 * The temporary uses a macro-private name so that an argument that happens to
 * be a variable called `error` is not shadowed (the previous spelling expanded
 * CUTLASS_CHECK(error) into a self-initialization). The brace-block form is
 * kept so existing call sites written without a trailing semicolon still build.
 */
#define CUTLASS_CHECK(status)                                                                    \
  {                                                                                              \
    cutlass::Status cutlass_check_status_ = (status);                                            \
    if (cutlass_check_status_ != cutlass::Status::kSuccess) {                                    \
      std::cerr << "Got cutlass error: " << cutlassGetStatusString(cutlass_check_status_)        \
                << " at: " << __LINE__                                                           \
                << std::endl;                                                                    \
      exit(EXIT_FAILURE);                                                                        \
    }                                                                                            \
  }
/**
 * Panic wrapper for unwinding CUDA runtime errors
 *
 * Evaluates `status` exactly once. If it is not cudaSuccess, the CUDA error
 * string and the current line number are written to std::cerr and the process
 * exits with EXIT_FAILURE.
 *
 * The temporary uses a macro-private name so that an argument that happens to
 * be a variable called `error` is not shadowed (the previous spelling expanded
 * CUDA_CHECK(error) into a self-initialization). The brace-block form is kept
 * so existing call sites written without a trailing semicolon still build.
 */
#define CUDA_CHECK(status)                                                    \
  {                                                                           \
    cudaError_t cuda_check_status_ = (status);                                \
    if (cuda_check_status_ != cudaSuccess) {                                  \
      std::cerr << "Got bad cuda status: "                                    \
                << cudaGetErrorString(cuda_check_status_)                     \
                << " at line: " << __LINE__ << std::endl;                     \
      exit(EXIT_FAILURE);                                                     \
    }                                                                         \
  }
// Always-on assertion: when `cond` is false, writes the stringified condition
// with file and line to std::cerr and calls std::abort(). `cond` is evaluated
// once. The do/while(0) wrapper lets the macro be used as a single statement.
#define FLASH_MLA_ASSERT(cond)                                              \
  do {                                                                      \
    if (!(cond)) {                                                          \
      std::cerr << "FLASH_MLA_ASSERT: " << #cond << " failed at ";          \
      std::cerr << __FILE__ << ":" << __LINE__ << std::endl;                \
      std::abort();                                                         \
    }                                                                       \
  } while (0)
\ No newline at end of file
#pragma once
/// Mask selection; the numeric values are fixed explicitly.
enum class MaskMode {
  /// No mask
  kNone = 0U,
  /// Causal mask
  kCausal = 1U,
  /// Custom mask
  kCustom = 2U,
};
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*!
\file
\brief Support the producer to acquire specific bytes of data.
*/
#pragma once
#include "cutlass/pipeline/sm100_pipeline.hpp"
namespace cutlass {
using namespace cute;
// Pipeline wrapper around PipelineTmaUmmaAsync that additionally lets the
// producer state, per acquire, how many transaction bytes the full barrier
// should expect (see producer_acquire_bytes). Barrier and mask setup is done by
// this class's own init_barriers/init_masks; plain acquire/wait/get_barrier
// calls are forwarded to the wrapped Impl.
template <
int Stages_,
class ClusterShape = Shape<int,int,_1>,
class AtomThrShape_MNK_ = Shape<_1,_1,_1>
>
class PipelineTmaAsyncMla {
public:
static constexpr uint32_t Stages = Stages_;
using AtomThrShape_MNK = AtomThrShape_MNK_;
private:
// Underlying pipeline that supplies the barrier types, state and shared storage
using Impl = PipelineTmaUmmaAsync<Stages_, ClusterShape, AtomThrShape_MNK_>;
public:
using FullBarrier = typename Impl::FullBarrier;
using EmptyBarrier = typename Impl::EmptyBarrier;
using ProducerBarrierType = typename Impl::ProducerBarrierType;
using ConsumerBarrierType = typename Impl::ConsumerBarrierType;
using PipelineState = typename Impl::PipelineState;
using SharedStorage = typename Impl::SharedStorage;
using ThreadCategory = typename Impl::ThreadCategory;
using Params = typename Impl::Params;
using McastDirection = McastDirection;
// Helper function to initialize barriers
// Only the warp whose index equals params.initializing_warp initializes the
// full/empty barrier arrays; every calling thread then executes
// fence_barrier_init(). The consumer arrival count is
// size<0>(cluster)/size<0>(atom) + size<1>(cluster)/size<1>(atom) - 1.
static
CUTLASS_DEVICE
void
init_barriers(SharedStorage& storage, Params params, ClusterShape cluster_shape) {
int warp_idx = canonical_warp_idx_sync();
if (warp_idx == params.initializing_warp) {
// Barrier FULL and EMPTY init
constexpr int producer_arv_cnt = 1;
auto atom_thr_shape = AtomThrShape_MNK{};
uint32_t const multicast_consumer_arrival_count = (cute::size<0>(cluster_shape) / cute::size<0>(atom_thr_shape)) +
(cute::size<1>(cluster_shape) / cute::size<1>(atom_thr_shape)) - 1;
cutlass::arch::detail::initialize_barrier_array_pair_aligned<decltype(storage.full_barrier_), decltype(storage.empty_barrier_), Stages>(
storage.full_barrier_, storage.empty_barrier_, producer_arv_cnt, multicast_consumer_arrival_count);
}
cutlass::arch::fence_barrier_init();
}
// Same as above, but the consumer arrival count follows a single multicast
// direction: mode 1 of the cluster for kRow, mode 0 for any other direction.
static
CUTLASS_DEVICE
void
init_barriers(SharedStorage& storage, Params params, ClusterShape cluster_shape, McastDirection mcast_direction) {
auto atom_thr_shape = AtomThrShape_MNK{};
int warp_idx = canonical_warp_idx_sync();
if (warp_idx == params.initializing_warp) {
// Barrier FULL and EMPTY init
constexpr int producer_arv_cnt = 1;
uint32_t const multicast_consumer_arrival_count = (mcast_direction == McastDirection::kRow) ?
cute::size<1>(cluster_shape) / cute::size<1>(atom_thr_shape) : // Mcast with row ctas
cute::size<0>(cluster_shape) / cute::size<0>(atom_thr_shape); // Mcast with col ctas
cutlass::arch::detail::initialize_barrier_array_pair_aligned<decltype(storage.full_barrier_), decltype(storage.empty_barrier_), Stages>(
storage.full_barrier_, storage.empty_barrier_, producer_arv_cnt, multicast_consumer_arrival_count);
}
cutlass::arch::fence_barrier_init();
}
// Computes block_id_mask_ with the kRowCol multicast pattern. The mask is only
// written when this thread's role is Consumer; otherwise it keeps its default.
CUTLASS_DEVICE
void init_masks(ClusterShape cluster_shape, dim3 block_id_in_cluster = cute::block_id_in_cluster()) {
// Calculate consumer mask
if (params_.role == ThreadCategory::Consumer) {
// NOTE(review): cluster_layout is not used after this line
auto cluster_layout = make_layout(cluster_shape);
block_id_mask_ = detail::calculate_multicast_mask<McastDirection::kRowCol>(cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster);
}
}
// Computes block_id_mask_ for a single multicast direction (kRow, otherwise
// kCol). Unlike the overload above there is no role check here.
CUTLASS_DEVICE
void init_masks(ClusterShape cluster_shape, McastDirection mcast_direction) {
// Calculate consumer mask
dim3 block_id_in_cluster = cute::block_id_in_cluster();
// NOTE(review): cluster_layout is not used after this line
auto cluster_layout = make_layout(cluster_shape);
if (mcast_direction == McastDirection::kRow) {
block_id_mask_ = detail::calculate_multicast_mask<McastDirection::kRow>(cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster);
}
else {
block_id_mask_ = detail::calculate_multicast_mask<McastDirection::kCol>(cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster);
}
}
public:
// Constructs the pipeline over `storage`. impl_ is built with false_type for
// both of its trailing tag arguments (presumably switching off Impl's own
// barrier/mask initialization - confirm against PipelineTmaUmmaAsync); the
// InitBarriers / InitMasks tags then select whether this class's init_barriers
// and init_masks run. Both tags must be cute::true_type or cute::false_type.
template<typename InitBarriers = cute::true_type, typename InitMasks = cute::true_type>
CUTLASS_DEVICE
PipelineTmaAsyncMla(SharedStorage& storage, Params params, ClusterShape cluster_shape, InitBarriers = {}, InitMasks = {})
: impl_(storage, params, cluster_shape, cute::false_type{}, cute::false_type{})
, params_(params)
, empty_barrier_ptr_(&storage.empty_barrier_[0])
, full_barrier_ptr_(&storage.full_barrier_[0]) {
static_assert(cute::is_same_v<InitBarriers, cute::true_type> || cute::is_same_v<InitBarriers, cute::false_type>);
if constexpr (cute::is_same_v<InitBarriers, cute::true_type>) {
init_barriers(storage, params_, cluster_shape);
}
static_assert(cute::is_same_v<InitMasks, cute::true_type> || cute::is_same_v<InitMasks, cute::false_type>);
if constexpr (cute::is_same_v<InitMasks, cute::true_type>) {
init_masks(cluster_shape);
}
}
// Same as the constructor above, but barrier counts and masks are set up for
// the given multicast direction.
template<typename InitBarriers = cute::true_type, typename InitMasks = cute::true_type>
CUTLASS_DEVICE
PipelineTmaAsyncMla(SharedStorage& storage, Params params, ClusterShape cluster_shape, McastDirection mcast_direction, InitBarriers = {}, InitMasks = {})
: impl_(storage, params, cluster_shape, cute::false_type{}, cute::false_type{})
, params_(params)
, empty_barrier_ptr_(&storage.empty_barrier_[0])
, full_barrier_ptr_(&storage.full_barrier_[0]) {
static_assert(cute::is_same_v<InitBarriers, cute::true_type> || cute::is_same_v<InitBarriers, cute::false_type>);
if constexpr (cute::is_same_v<InitBarriers, cute::true_type>) {
init_barriers(storage, params_, cluster_shape, mcast_direction);
}
static_assert(cute::is_same_v<InitMasks, cute::true_type> || cute::is_same_v<InitMasks, cute::false_type>);
if constexpr (cute::is_same_v<InitMasks, cute::true_type>) {
init_masks(cluster_shape, mcast_direction);
}
}
// Forwards to Impl::producer_acquire
CUTLASS_DEVICE
void producer_acquire(PipelineState state, ProducerToken barrier_token = {BarrierStatus::WaitAgain}) {
impl_.producer_acquire(state, barrier_token);
}
// Producer acquire with an explicit byte count: waits on the stage's empty
// barrier (skipped when the token is WaitDone), then the leader arrives on the
// stage's full barrier announcing `bytes` as the expected transaction size.
CUTLASS_DEVICE
void producer_acquire_bytes(uint32_t stage, uint32_t bytes, uint32_t phase, ProducerToken barrier_token) {
detail::pipeline_check_is_producer(params_.role);
if (barrier_token != BarrierStatus::WaitDone) {
empty_barrier_ptr_[stage].wait(phase);
}
if (params_.is_leader) {
full_barrier_ptr_[stage].arrive_and_expect_tx(bytes);
}
// Debug-only sanity checks: trap when called from a non-producer role or when
// the leader is not the first lane of its warp
#ifndef NDEBUG
if (params_.role == ThreadCategory::Consumer || params_.role == ThreadCategory::NonParticipant) {
asm volatile ("brkpt;\n" ::);
}
// Most likely you have elected more than one leader
if (params_.is_leader && (threadIdx.x % 32 != 0)) {
asm volatile ("brkpt;\n" ::);
}
#endif
}
// Convenience overload taking stage index and phase from a PipelineState
CUTLASS_DEVICE
void producer_acquire_bytes(PipelineState state, uint32_t bytes, ProducerToken barrier_token = {BarrierStatus::WaitAgain}) {
producer_acquire_bytes(state.index(), bytes, state.phase(), barrier_token);
}
// Forwards to Impl::producer_get_barrier
CUTLASS_DEVICE
ProducerBarrierType* producer_get_barrier(PipelineState state) {
return impl_.producer_get_barrier(state);
}
// Forwards to Impl::consumer_wait
CUTLASS_DEVICE
void consumer_wait(PipelineState state, ConsumerToken barrier_token = {BarrierStatus::WaitAgain}) {
impl_.consumer_wait(state, barrier_token);
}
// Releases the stage named by `state`; always signals (skip = false)
CUTLASS_DEVICE
void consumer_release(PipelineState state) {
consumer_release(state.index(), false);
}
private:
// Wrapped pipeline used for the forwarded operations
Impl impl_;
// Copy of the construction parameters (role, leader flag, ...)
Params params_;
// First element of the empty / full barrier arrays inside the shared storage
EmptyBarrier *empty_barrier_ptr_;
FullBarrier *full_barrier_ptr_;
// Multicast mask written by init_masks; 0 until then
uint16_t block_id_mask_ = 0;
// True when the MMA atom's thread shape has more than one element
static constexpr bool is_2sm_mma = size(AtomThrShape_MNK{}) > 1;
// Consumer signalling Producer of completion
// Ensures all blocks in the Same Row and Column get notified.
// Nothing is signalled when `skip` is non-zero. Otherwise the arrive variant
// depends on is_2sm_mma and on whether ClusterShape is statically of size 1.
CUTLASS_DEVICE
void consumer_release(uint32_t stage, uint32_t skip) {
detail::pipeline_check_is_consumer(params_.role);
uint64_t* smem_ptr = reinterpret_cast<uint64_t*>(&empty_barrier_ptr_[stage]);
if constexpr (is_2sm_mma) { // Mma cluster shape is 2x1
if (!skip) {
cutlass::arch::umma_arrive_multicast_2x1SM(smem_ptr, block_id_mask_);
}
}
else {
if (!skip) {
if constexpr (cute::is_static_v<ClusterShape> and size(ClusterShape{}) == 1) {
cutlass::arch::umma_arrive(smem_ptr);
}
else {
cutlass::arch::umma_arrive_multicast(smem_ptr, block_id_mask_);
}
}
}
}
};
}
/***************************************************************************************************
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp>
#include <cute/numeric/integral_constant.hpp>
#include <cuda_runtime.h>
namespace cutlass::fmha {
/// Integer carried together with its base-2 logarithm so that division and
/// modulo by it can be done with a shift and a mask (see the free operators
/// that take a Pow2). The value is expected to be a power of two; this is not
/// checked.
struct Pow2 {
  int n;       ///< The value itself.
  int log2_n;  ///< __ffs(n) - 1; only assigned when compiled for the device.

  /// Device-only constructor. log2_n is derived from the lowest set bit of
  /// `n`, which equals log2(n) when `n` is a power of two.
  explicit CUTE_DEVICE Pow2(int n) : n(n) {
#ifdef __CUDA_ARCH__
    log2_n = __ffs(n) - 1;
#endif
  }

  /// Multiplication with an arbitrary value: the result has the operand's type.
  template<class T>
  CUTE_HOST_DEVICE T operator *(T const& b) const {
    return n * b;
  }

  /// Multiplication with a compile-time integer. The product of two powers of
  /// two is again a power of two, so it stays a Pow2 in that case; any other N
  /// yields a plain int.
  template<int N>
  CUTE_HOST_DEVICE auto operator *(Int<N> const&) const {
    // `==` binds tighter than `&`, so the bit test needs its own parentheses.
    // N > 0 keeps zero from passing the test. The else branch is required so
    // that only one return type is deduced per instantiation.
    if constexpr (N > 0 && (N & (N - 1)) == 0) {
      return Pow2{n * N};
    } else {
      return n * N;
    }
  }
};
/// Division by a Pow2, implemented as a right shift by its stored log2_n.
template<class T>
CUTE_HOST_DEVICE auto operator/(T const& a, Pow2 const& b) {
  int const shift = b.log2_n;
  return a >> shift;
}
/// Remainder by a Pow2, implemented as a bit mask with (n - 1).
template<class T>
CUTE_HOST_DEVICE auto operator%(T const& a, Pow2 const& b) {
  int const low_bits = b.n - 1;
  return a & low_bits;
}
/// Compares a value against the integer held by a Pow2.
template<class T>
CUTE_HOST_DEVICE bool operator<(T const& a, Pow2 const& b) {
  int const bound = b.n;
  return a < bound;
}
/// Prints a Pow2 as "2^<log2_n>".
CUTE_HOST_DEVICE void print(Pow2 const& a) {
  int const exponent = a.log2_n;
  printf("2^%d", exponent);
}
} // end namespace cutlass::fmha
namespace cute {
// Specialization that makes cute::is_integral report true for Pow2.
template <>
struct is_integral<cutlass::fmha::Pow2> : true_type {};
} // end namespace cute
#pragma once
#include "cutlass/numeric_types.h"
#include "helper.h"
/// Type map from CUDA scalar types to their CUTLASS counterparts.
/// Primary template: a type without a specialization maps to itself.
template <class T>
struct cutlass_dtype {
  using type = T;
};

/// half -> cutlass::half_t
template <>
struct cutlass_dtype<half> {
  using type = cutlass::half_t;
};

/// nv_bfloat16 -> cutlass::bfloat16_t
template <>
struct cutlass_dtype<nv_bfloat16> {
  using type = cutlass::bfloat16_t;
};

/// __nv_fp8_e4m3 -> cutlass::float_e4m3_t
template <>
struct cutlass_dtype<__nv_fp8_e4m3> {
  using type = cutlass::float_e4m3_t;
};

/// __nv_fp8_e5m2 -> cutlass::float_e5m2_t
template <>
struct cutlass_dtype<__nv_fp8_e5m2> {
  using type = cutlass::float_e5m2_t;
};

/// Shorthand for `typename cutlass_dtype<T>::type`.
template <class T>
using cutlass_dtype_t = typename cutlass_dtype<T>::type;
/// Owning, non-copyable wrapper around a cudaMalloc'ed buffer of `T`.
/// The buffer holds `size_ + offset_` elements; get() points `offset_`
/// elements past the start of the raw allocation.
template<typename T>
struct DeviceAllocation {
  T* ptr_ = nullptr;    // start of the raw allocation (nullptr when empty)
  size_t offset_ = 0;   // elements between ptr_ and get()
  size_t size_ = 0;     // elements reported by size()

  DeviceAllocation(DeviceAllocation const&) = delete;
  DeviceAllocation& operator=(DeviceAllocation const&) = delete;

  DeviceAllocation() = default;
  DeviceAllocation(size_t size) { reset(size); }
  ~DeviceAllocation() { reset(); }

  /// Releases any current buffer and allocates room for `size + offset`
  /// elements. If cudaMalloc fails the object is left empty (asserts in debug
  /// builds) instead of recording a size for a buffer that does not exist.
  void reset(size_t size, size_t offset=0) {
    reset();
    auto ret = cudaMalloc(&ptr_, sizeof(T) * (size + offset));
    assert(ret == cudaSuccess);
    if (ret != cudaSuccess) {
      // Reached only when assert is compiled out (NDEBUG).
      ptr_ = nullptr;
      return;
    }
    size_ = size;
    offset_ = offset;
  }

  /// Pointer to the first usable element (raw pointer advanced by offset_).
  T* get() {
    return ptr_ + offset_;
  }

  const T* get() const {
    return ptr_ + offset_;
  }

  /// Frees the buffer, if any, and returns the object to the empty state.
  /// Clearing ptr_ makes repeated calls (including the one made by the
  /// destructor after an explicit reset()) safe instead of a double free.
  void reset() {
    if (ptr_ != nullptr) {
      auto ret = cudaFree(ptr_);
      assert(ret == cudaSuccess);
      (void)ret;  // unused when assert is compiled out
    }
    ptr_ = nullptr;
    offset_ = 0;
    size_ = 0;
  }

  size_t size() const { return size_; }

  /// Size in bytes of the whole allocation, offset included.
  size_t get_storage_size() const { return (size_ + offset_) * sizeof(T); }

  /// Copies `sz` elements from `ptr` to the start of the raw allocation.
  // NOTE(review): the destination is ptr_, not get(), so with a non-zero
  // offset_ the data lands before the region get() exposes - confirm intended.
  void copy_from_host(const T* ptr, size_t sz) {
    auto ret = cudaMemcpy(ptr_, ptr, sz * sizeof(T), cudaMemcpyDefault);
    assert(ret == cudaSuccess);
    (void)ret;  // unused when assert is compiled out
  }

  /// Same copy as copy_from_host; cudaMemcpyDefault is used for the direction
  /// in both.
  void copy_from_device(const T* ptr, size_t sz) {
    auto ret = cudaMemcpy(ptr_, ptr, sz * sizeof(T), cudaMemcpyDefault);
    assert(ret == cudaSuccess);
    (void)ret;  // unused when assert is compiled out
  }
};
\ No newline at end of file
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*!
\file
\brief A universal device layer for cutlass 3.x-style kernels.
*/
#pragma once
// common
#include "cutlass/cutlass.h"
#include "cutlass/device_kernel.h"
#if !defined(__CUDACC_RTC__)
#include "cutlass/cluster_launch.hpp"
#include "cutlass/trace.h"
#endif // !defined(__CUDACC_RTC__)
////////////////////////////////////////////////////////////////////////////////
namespace cutlass::fmha::device {
////////////////////////////////////////////////////////////////////////////////
////////////////////////////// CUTLASS 3.x API /////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
template <class Kernel_>
class FMHA {
public:
using Kernel = Kernel_;
static int const kThreadCount = Kernel::MaxThreadsPerBlock;
/// Argument structure: User API
using Arguments = typename Kernel::Arguments;
/// Argument structure: Kernel API
using Params = typename Kernel::Params;
private:
/// Kernel API parameters object
Params params_;
bool is_initialized(bool set = false) {
  // The flag is a function-local static, so it is shared by every instance of
  // this class template instantiation. Once set it is never cleared.
  static bool initialized = false;
  initialized = initialized || set;
  return initialized;
}
public:
/// Access the Params structure
Params const& params() const {
  // Read-only view of the parameters last stored by initialize() or update().
  return this->params_;
}
/// Determines whether the FMHA kernel can execute the given problem.
static Status
can_implement(Arguments const& args) {
  // Translate the kernel's yes/no answer into a Status code.
  return Kernel::can_implement(args) ? Status::kSuccess : Status::kInvalid;
}
/// Gets the workspace size
static size_t
get_workspace_size(Arguments const& args) {
  // The kernel is the only contributor to the workspace requirement.
  return static_cast<size_t>(0) + Kernel::get_workspace_size(args);
}
/// Computes the grid shape
static dim3
get_grid_shape(Params const& params) {
  // Delegates entirely to the kernel.
  dim3 const grid = Kernel::get_grid_shape(params);
  return grid;
}
/// Computes the maximum number of active blocks per multiprocessor
static int maximum_active_blocks(int /* smem_capacity */ = -1) {
CUTLASS_TRACE_HOST("FMHA::maximum_active_blocks()");
int max_active_blocks = -1;
int smem_size = Kernel::SharedStorageSize;
// first, account for dynamic smem capacity if needed
cudaError_t result;
if (smem_size >= (48 << 10)) {
CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
result = cudaFuncSetAttribute(
device_kernel<Kernel>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size);
if (cudaSuccess != result) {
result = cudaGetLastError(); // to clear the error bit
CUTLASS_TRACE_HOST(
" cudaFuncSetAttribute() returned error: "
<< cudaGetErrorString(result));
return -1;
}
}
// query occupancy after setting smem size
result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks,
device_kernel<Kernel>,
Kernel::MaxThreadsPerBlock,
smem_size);
if (cudaSuccess != result) {
result = cudaGetLastError(); // to clear the error bit
CUTLASS_TRACE_HOST(
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: "
<< cudaGetErrorString(result));
return -1;
}
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
return max_active_blocks;
}
/// Initializes FMHA state from arguments.
Status
initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
  // Steps: (1) let the kernel initialize its workspace, (2) store the kernel
  // parameters, (3) the first time only, raise the kernel's dynamic shared
  // memory limit when SharedStorageSize is at least 48 KiB.
  CUTLASS_TRACE_HOST("FMHA::initialize() - workspace "
    << workspace << ", stream: " << (stream ? "non-null" : "null"));

  Status const workspace_status = Kernel::initialize_workspace(args, workspace, stream);
  if (workspace_status != Status::kSuccess) {
    return workspace_status;
  }

  // The parameters are refreshed on every call, even when the one-time setup
  // below has already been done.
  params_ = Kernel::to_underlying_arguments(args, workspace);

  if (!is_initialized()) {
    // account for dynamic smem capacity if needed
    int smem_size = Kernel::SharedStorageSize;
    if (smem_size >= (48 << 10)) {
      CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
      cudaError_t attribute_result = cudaFuncSetAttribute(
          device_kernel<Kernel>,
          cudaFuncAttributeMaxDynamicSharedMemorySize,
          smem_size);
      if (attribute_result != cudaSuccess) {
        attribute_result = cudaGetLastError(); // to clear the error bit
        CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(attribute_result));
        return Status::kErrorInternal;
      }
    }
    // Remember the setup only after it succeeded, so a failure is retried.
    is_initialized(true);
  }

  return Status::kSuccess;
}
/// Rebuilds the stored params from `args` (kept for the 3.0 update API; the
/// rebuild is not guaranteed to be lightweight).
/// Returns Status::kErrorWorkspaceNull if a workspace is required but `workspace` is null.
Status
update(Arguments const& args, void* workspace = nullptr) {
  CUTLASS_TRACE_HOST("FMHA()::update() - workspace: " << workspace);

  bool const needs_workspace = get_workspace_size(args) > 0;
  if (needs_workspace && workspace == nullptr) {
    return Status::kErrorWorkspaceNull;
  }

  params_ = Kernel::to_underlying_arguments(args, workspace);
  return Status::kSuccess;
}
/// Static launch entry point for callers that create and manage their own params.
/// `params` must have been constructed by calling Kernel::to_underlying_arguments().
/// Returns Status::kSuccess only if both the launch call and cudaGetLastError() report success.
static Status
run(Params& params, cudaStream_t stream = nullptr) {
  CUTLASS_TRACE_HOST("FMHA::run()");

  dim3 const block_dim = Kernel::get_block_shape();
  dim3 const grid_dim = get_grid_shape(params);
  int smem_bytes = Kernel::SharedStorageSize;

  // Stays kSuccess on the plain <<<>>> path, which reports errors only via cudaGetLastError().
  Status launch_status = Status::kSuccess;

  // The extended (cluster) launch API is used only for kernels targeting compute capability >= 90.
  if constexpr (Kernel::ArchTag::kMinComputeCapability >= 90) {
    using ClusterShape = typename Kernel::ClusterShape;
    dim3 cluster_dim(cute::size<0>(ClusterShape{}),
                     cute::size<1>(ClusterShape{}),
                     cute::size<2>(ClusterShape{}));
    void const* kernel_fn = (void const*) device_kernel<Kernel>;
    void* launch_args[] = {&params};
    launch_status = ClusterLauncher::launch(
        grid_dim, cluster_dim, block_dim, smem_bytes, stream, kernel_fn, launch_args);
  }
  else {
    device_kernel<Kernel><<<grid_dim, block_dim, smem_bytes, stream>>>(params);
  }

  cudaError_t const cuda_status = cudaGetLastError();
  if (cudaSuccess != cuda_status || Status::kSuccess != launch_status) {
    CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << cuda_status);
    return Status::kErrorInternal;
  }
  return Status::kSuccess;
}
//
// Non-static launch overloads that first create and set the internal params struct of this kernel handle.
//
/// Initializes the internal params from `args`, then launches on `stream`.
/// If initialize() fails, its status is returned and nothing is launched.
Status
run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
  Status const init_status = initialize(args, workspace, stream);
  if (Status::kSuccess != init_status) {
    return init_status;
  }
  return run(params_, stream);
}
/// Call-operator form of run(args, workspace, stream): initialize, then launch.
Status
operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
  Status const status = run(args, workspace, stream);
  return status;
}
/// Re-launches the kernel with the params already stored in this handle;
/// the internal params struct is not updated.
Status
run(cudaStream_t stream = nullptr) {
  Status const status = run(params_, stream);
  return status;
}
/// Call-operator form of the re-launch overload: launches with the stored params as-is.
Status
operator()(cudaStream_t stream = nullptr) {
  Status const status = run(params_, stream);
  return status;
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::device
////////////////////////////////////////////////////////////////////////////////
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
// common
#include "cutlass/cutlass.h"
#include "cutlass/kernel_hardware_info.hpp"
#include "cute/tensor.hpp"
#include "../device/fmha.hpp"
#include "../kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp"
#include "../kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp"
#include "../kernel/fmha_kernel_bwd_sum_OdO.hpp"
#include "../kernel/fmha_kernel_bwd_convert.hpp"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass::fmha::device {
////////////////////////////////////////////////////////////////////////////////
////////////////////////////// CUTLASS 3.x API /////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
template<
class ProblemShape,
class Element,
class ElementAccumulator,
class TileShape,
bool IsMla,
class Mask
>
class Sm100FmhaBwd {
public:
/// Argument structure: User API
struct Arguments {
// Q K D D_VO HB
ProblemShape problem_shape;
const Element* ptr_Q;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_Q;
const Element* ptr_K;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_K;
const Element* ptr_V;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_V;
const Element* ptr_O;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_O;
const ElementAccumulator* ptr_LSE;
cute::tuple<cute::_1, cute::tuple<int, int>> stride_LSE;
const Element* ptr_dO;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_dO;
Element* ptr_dQ;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_dQ;
Element* ptr_dK;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_dK;
Element* ptr_dV;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_dV;
ElementAccumulator softmax_scale;
cutlass::KernelHardwareInfo hw_info;
};
using OperationSumOdO = cutlass::fmha::device::FMHA<
cutlass::fmha::kernel::FmhaKernelBwdSumOdO<ProblemShape, Element, ElementAccumulator>
>;
using OperationConvert = cutlass::fmha::device::FMHA<
cutlass::fmha::kernel::FmhaKernelBwdConvert<ProblemShape, Element, ElementAccumulator>
>;
using OperationMha= cutlass::fmha::device::FMHA<
cutlass::fmha::kernel::Sm100FmhaBwdKernelTmaWarpSpecialized<
ProblemShape, Element, ElementAccumulator, TileShape, Mask
>
>;
using OperationMla = cutlass::fmha::device::FMHA<
cutlass::fmha::kernel::Sm100FmhaBwdMlaKernelTmaWarpSpecialized<
ProblemShape, Element, ElementAccumulator, TileShape, Mask
>
>;
using Operation = std::conditional_t<IsMla, OperationMla, OperationMha>;
using Kernel = typename Operation::Kernel;
struct Params {
OperationSumOdO op_sum_OdO;
Operation op;
OperationConvert op_convert;
ElementAccumulator* dQ_acc;
size_t dQ_acc_size;
};
private:
Params params_;
static typename OperationSumOdO::Arguments to_sum_OdO_arguments(
Arguments const& args,
ElementAccumulator* sum_odo = nullptr,
ElementAccumulator* scaled_lse = nullptr) {
using namespace cute;
auto [Q_, K, D, D_VO, HB] = args.problem_shape;
auto [H, B] = HB;
D = cutlass::round_up(D, 8); // Alignment
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
auto stride_sum_OdO = make_stride(_1{}, make_stride(Q, Q*H));
auto stride_scaled_lse = make_stride(_1{}, make_stride(Q, Q*H));
auto log2_e = log2f(expf(1.0f));
return typename OperationSumOdO::Arguments {
args.problem_shape,
args.ptr_O, args.stride_O,
args.ptr_dO, args.stride_dO,
sum_odo, stride_sum_OdO,
args.ptr_LSE, args.stride_LSE,
scaled_lse, stride_scaled_lse,
-1.0f, -log2_e
};
}
static typename OperationConvert::Arguments to_convert_arguments(Arguments const& args, ElementAccumulator* src = nullptr) {
using namespace cute;
auto [Q_, K, D, D_VO, HB] = args.problem_shape;
auto [H, B] = HB;
D = cutlass::round_up(D, 8); // Alignment
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
auto stride_src_dQ = make_stride(D, _1{}, make_stride(D*Q, D*Q*H));
return typename OperationConvert::Arguments {
args.problem_shape,
src, stride_src_dQ,
nullptr, stride_src_dQ,
nullptr, stride_src_dQ,
args.ptr_dQ, args.stride_dQ,
nullptr, args.stride_dK,
nullptr, args.stride_dV,
args.softmax_scale
};
}
static typename Operation::Arguments to_bwd_arguments(
Arguments const& args,
ElementAccumulator* sum_OdO = nullptr, cute::tuple<cute::_1, cute::tuple<int, int>> const& stride_sum_OdO = {},
ElementAccumulator* scaled_lse = nullptr, cute::tuple<cute::_1, cute::tuple<int, int>> const& stride_scaled_lse = {},
ElementAccumulator* dQ_acc = nullptr, cute::tuple<int, cute::_1, cute::tuple<int, int>> const& stride_dQ = {}) {
return typename Operation::Arguments{
args.problem_shape,
{ args.ptr_Q, args.stride_Q,
args.ptr_K, args.stride_K,
args.ptr_V, args.stride_V,
args.ptr_dO, args.stride_dO,
scaled_lse, stride_scaled_lse,
sum_OdO, stride_sum_OdO,
dQ_acc, stride_dQ,
args.softmax_scale },
{ args.ptr_dK, args.stride_dK,
args.ptr_dV, args.stride_dV },
args.hw_info
};
}
public:
/// Determines whether the GEMM can execute the given problem.
static Status
can_implement(Arguments const& args) {
Status status = Status::kSuccess;
status = OperationSumOdO::can_implement(to_sum_OdO_arguments(args));
if (status != Status::kSuccess) {
return status;
}
status = OperationConvert::can_implement(to_convert_arguments(args));
if (status != Status::kSuccess) {
return status;
}
status = Operation::can_implement(to_bwd_arguments(args));
if (status != Status::kSuccess) {
return status;
}
return status;
}
/// Gets the workspace size
static size_t
get_workspace_size(Arguments const& args) {
auto [Q_, K, D, D_VO, HB] = args.problem_shape;
auto [H, B] = HB;
D = cutlass::round_up(D, 8); // Alignment
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
size_t workspace_bytes = 0;
// OdO vector
workspace_bytes += B*H*Q * sizeof(ElementAccumulator);
// scaled LSE vector
workspace_bytes += B*H*Q * sizeof(ElementAccumulator);
// FP32 versions of outputs that are churned (start off with Q only)
workspace_bytes += B*H*Q*D * sizeof(ElementAccumulator);
return workspace_bytes;
}
/// Initializes state from arguments.
Status
initialize_split(Arguments const& args, void* workspace_dQ, void* workspace_sum_OdO, void* workspace_scaled_lse, cudaStream_t stream = nullptr) {
CUTLASS_TRACE_HOST("Universal::initialize_split() - workspace_dQ="
<< workspace_dQ << ", workspace_sum_OdO=" << workspace_sum_OdO << "stream: " << (stream ? "non-null" : "null"));
auto [Q_, K, D, D_VO, HB] = args.problem_shape;
auto [H, B] = HB;
D = cutlass::round_up(D, 8); // Alignment
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
ElementAccumulator* sum_OdO = reinterpret_cast<ElementAccumulator*>(workspace_sum_OdO);
ElementAccumulator* scaled_lse = reinterpret_cast<ElementAccumulator*>(workspace_scaled_lse);
ElementAccumulator* dQ_acc = reinterpret_cast<ElementAccumulator*>(workspace_dQ);
params_.dQ_acc = dQ_acc;
params_.dQ_acc_size = B*H*Q*D * sizeof(ElementAccumulator);
auto args_sum_OdO = to_sum_OdO_arguments(args, sum_OdO, scaled_lse);
auto args_convert = to_convert_arguments(args, dQ_acc);
params_.op_sum_OdO.initialize(args_sum_OdO, nullptr, stream);
params_.op_convert.initialize(args_convert, nullptr, stream);
auto args_bwd = to_bwd_arguments(
args, sum_OdO, args_sum_OdO.stride_sum_OdO,
scaled_lse, args_sum_OdO.stride_scaled_lse,
dQ_acc, args_convert.stride_src_dQ
);
params_.op.initialize(args_bwd, nullptr, stream);
return Status::kSuccess;
}
/// Initializes state from arguments.
Status
initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
CUTLASS_TRACE_HOST("Universal::initialize() - workspace "
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
auto [Q_, K, D, D_VO, HB] = args.problem_shape;
auto [H, B] = HB;
D = cutlass::round_up(D, 8); // Alignment
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
char* workspace_chr = reinterpret_cast<char*>(workspace);
ElementAccumulator* sum_OdO = reinterpret_cast<ElementAccumulator*>(workspace_chr);
workspace_chr += B*H*Q * sizeof(ElementAccumulator);
ElementAccumulator* scaled_lse = reinterpret_cast<ElementAccumulator*>(workspace_chr);
workspace_chr += B*H*Q * sizeof(ElementAccumulator);
ElementAccumulator* dQ_acc = reinterpret_cast<ElementAccumulator*>(workspace_chr);
return initialize_split(args, dQ_acc, sum_OdO, scaled_lse, stream);
}
/// Primary run() entry point API that is static allowing users to create and manage their own params.
/// Supplied params struct must be construct by calling Kernel::to_underling_arguments()
static Status
run(Params& params, cudaStream_t stream = nullptr) {
CUTLASS_TRACE_HOST("FmhaDeviceBwd::run()");
Status result = Status::kSuccess;
result = params.op_sum_OdO.run(stream);
if (result != Status::kSuccess) {
return result;
}
auto cuda_result = cudaMemsetAsync(params.dQ_acc, 0, params.dQ_acc_size, stream);
if (cuda_result != cudaSuccess) {
return Status::kErrorInternal;
}
result = params.op.run(stream);
if (result != Status::kSuccess) {
return result;
}
result = params.op_convert.run(stream);
if (result != Status::kSuccess) {
return result;
}
return Status::kSuccess;
}
//
// Non-static launch overloads that first create and set the internal params struct of this kernel handle.
//
/// Launches the kernel after first constructing Params internal state from supplied arguments.
Status
run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
Status status = initialize(args, workspace, stream);
if (Status::kSuccess == status) {
status = run(params_, stream);
}
return status;
}
/// Overload that allows a user to re-launch the same kernel without updating internal params struct.
Status
run(cudaStream_t stream = nullptr) {
return run(params_, stream);
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::fmha::device
////////////////////////////////////////////////////////////////////////////////
#include <Python.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/library.h>
#include <cuda_bf16.h>
#include "common/mask.cuh"
#include "common/utils.hpp"
#include "fmha_cutlass_bwd_sm100.cuh"
template<class Mask, class Varlen, class Element, class ElementOut, class Mla>
void call_run_fmha_bwd([[maybe_unused]] Mask mask, [[maybe_unused]] Varlen is_varlen,
[[maybe_unused]] Element in, [[maybe_unused]] ElementOut out, [[maybe_unused]] Mla mla,
at::Tensor workspace_buffer, at::Tensor d_o, at::Tensor q, at::Tensor k,
at::Tensor v, at::Tensor o, at::Tensor lse,
at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv,
at::Tensor dq, at::Tensor dk, at::Tensor dv,
float softmax_scale, int max_seqlen_q, int total_seqlen_kv) {
static constexpr bool IsVarlen = std::is_same_v<Varlen, true_type>;
static constexpr bool IsMla = std::is_same_v<Mla, true_type>;
using TileShape = std::conditional_t<IsMla, Shape<_64, _128, _192, _128>, Shape<_128, _128, _128, _128>>;
run_fmha_bwd<Element, IsVarlen, IsMla, TileShape, Mask>(workspace_buffer, d_o, q, k, v, o, lse,
cumulative_seqlen_q, cumulative_seqlen_kv,
dq, dk, dv,
softmax_scale, max_seqlen_q, total_seqlen_kv);
}
/// Entry point for the SM100 FMHA backward pass.
/// Selects a kernel instantiation from the runtime mask mode, the varlen flag and
/// the head dimensions, then runs it with the device of `q` made current.
///
/// Supported combinations:
///   - q and o both BFloat16 (anything else hits FLASH_MLA_ASSERT);
///   - (head_dim_qk, head_dim_vo) == (192, 128), dispatched as MLA, or (128, 128).
/// Any other head-dim pair raises an error via TORCH_CHECK. Printing and
/// returning normally would leave dq/dk/dv unwritten without the caller knowing.
void FMHACutlassSM100BwdRun(at::Tensor workspace_buffer, at::Tensor d_o, at::Tensor q, at::Tensor k,
                            at::Tensor v, at::Tensor o, at::Tensor lse,
                            at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv,
                            at::Tensor dq, at::Tensor dk, at::Tensor dv,
                            int mask_mode_code, float softmax_scale, int max_seqlen_q, int max_seqlen_kv, bool is_varlen) {
  const c10::cuda::OptionalCUDAGuard device_guard(q.device());

  int head_dim_qk = q.size(-1);
  int head_dim_vo = v.size(-1);
  MaskMode mask_mode = static_cast<MaskMode>(mask_mode_code);
  auto scalar_type_in = q.scalar_type();
  auto scalar_type_out = o.scalar_type();

  if (scalar_type_in == at::ScalarType::BFloat16 && scalar_type_out == at::ScalarType::BFloat16) {
    using Element = cutlass::bfloat16_t;
    using ElementOut = cutlass::bfloat16_t;

    // Maps the runtime (mask_mode, is_varlen) pair onto tag objects so that `fn`
    // is instantiated once per combination.
    auto apply_config = [&](auto fn) {
      if (mask_mode == MaskMode::kCausal) {
        if (is_varlen) {
          fn(CausalForBackwardMask<false>{}, cute::true_type{}, Element{}, ElementOut{});
        } else {
          fn(CausalForBackwardMask<false>{}, cute::false_type{}, Element{}, ElementOut{});
        }
      }
      else {
        if (is_varlen) {
          fn(ResidualMaskForBackward{}, cute::true_type{}, Element{}, ElementOut{});
        } else {
          fn(ResidualMaskForBackward{}, cute::false_type{}, Element{}, ElementOut{});
        }
      }
    };

    apply_config([&](auto mask, auto varlen, auto in, auto out) {
      if (head_dim_qk == 192 && head_dim_vo == 128) {
        call_run_fmha_bwd(mask, varlen, in, out, true_type{}, workspace_buffer, d_o, q, k, v, o, lse,
                          cumulative_seqlen_q, cumulative_seqlen_kv,
                          dq, dk, dv,
                          softmax_scale, max_seqlen_q, max_seqlen_kv);
      } else if (head_dim_qk == 128 && head_dim_vo == 128) {
        call_run_fmha_bwd(mask, varlen, in, out, false_type{}, workspace_buffer, d_o, q, k, v, o, lse,
                          cumulative_seqlen_q, cumulative_seqlen_kv,
                          dq, dk, dv,
                          softmax_scale, max_seqlen_q, max_seqlen_kv);
      } else {
        // Fail loudly: no kernel ran, so the gradient tensors hold no valid data.
        TORCH_CHECK(false, "No kernel instantiated for head_dim_qk=", head_dim_qk,
                    " head_dim_vo=", head_dim_vo);
      }
    });
  } else {
    FLASH_MLA_ASSERT(false);
  }
}
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <iostream>
#include <random>
#include <regex>
#include <cute/tensor.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/kernel_hardware_info.h>
#include <cutlass/util/command_line.h>
#include <cutlass/util/distribution.h>
#include <cutlass/util/reference/device/tensor_fill.h>
#include "common/utils.hpp"
#include "collective/fmha_fusion.hpp"
#include "device/fmha_device_bwd.hpp"
using namespace cute;
using namespace cutlass::fmha::kernel;
using namespace cutlass::fmha::collective;
using namespace cutlass::fmha;
using namespace cutlass;
/// Adapts PyTorch tensors to cutlass::fmha::device::Sm100FmhaBwd and runs the
/// backward pass on the current CUDA stream.
///
/// Template parameters select the element type, fixed vs. variable sequence
/// length, the MLA vs. MHA kernel, the tile shape and the mask.
template<
  class DType,
  bool kIsVarlen,
  bool kIsMla,
  class TileShape,
  class ActiveMask
>
struct BwdRunner {
  using Element = DType;
  using ElementAccumulator = float;

  // Q K D D_VO (H B)
  using ProblemShape = std::conditional_t<
    kIsVarlen,
    cute::tuple<VariableLength, VariableLength, int, int, cute::tuple<int, int>>,
    cute::tuple<int, int, int, int, cute::tuple<int, int>>
  >;

  using Operation = cutlass::fmha::device::Sm100FmhaBwd<ProblemShape, Element, ElementAccumulator, TileShape, kIsMla, ActiveMask>;

  using TensorStride = Stride<int, _1, Stride<int, int>>;
  using StrideQ = TensorStride;                               // Seq DQK (H B)
  using StrideK = TensorStride;                               // Seq DQK (H B)
  using StrideV = TensorStride;                               // Seq DVO (H B)
  using StrideO = TensorStride;                               // Seq DVO (H B)
  using StrideLSE = Stride<_1, Stride<int, int>>;             // Seq (H B)

  // Backwards specific
  using StrideDQ = TensorStride;
  using StrideDK = TensorStride;                              // Seq DQK (H B)
  using StrideDV = TensorStride;                              // Seq DVO (H B)
  using StrideDO = TensorStride;

  /// Builds the problem shape and strides from the tensors, allocates the
  /// workspace, and initializes and launches the backward operation.
  ///
  /// The tensors are indexed as (sequence, head, dim): size(0) is read as the
  /// total sequence length, size(1) as the head count, size(-1) as the head dim.
  /// The innermost stride of every tensor (stride 0 for lse) must be 1.
  /// The batch size is taken from cumulative_seqlen_q in both modes.
  ///
  /// NOTE(review): workspace_buffer is accepted but not used; a fresh device
  /// allocation is made on every call. Confirm whether callers expect the
  /// supplied buffer to be used instead.
  static void run(at::Tensor workspace_buffer, at::Tensor d_o, at::Tensor q, at::Tensor k,
                  at::Tensor v, at::Tensor o, at::Tensor lse,
                  at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv,
                  at::Tensor dq, at::Tensor dk, at::Tensor dv,
                  float softmax_scale, int max_seqlen_q, int max_seqlen_kv) {
    cutlass::KernelHardwareInfo hw_info;
    // Query the SM count of the device that holds q instead of assuming device 0,
    // which reports the wrong count when the tensors live on another GPU.
    hw_info.device_id = static_cast<int>(q.device().index());
    hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);

    ProblemShape problem_shape;
    cute::tuple<int, int, int, int, cute::tuple<int, int>> tensor_shape;

    int d = q.size(-1);
    int d_vo = v.size(-1);
    int batch_size = cumulative_seqlen_q.size(0) - 1;
    int num_qo_heads = q.size(1);
    int total_seqlen_q = q.size(0);
    int total_seqlen_kv = k.size(0);

    //varlen: q: [Q, H, D]
    //fixedlen: q: [B, H, Q, D]
    if constexpr (kIsVarlen) {
      problem_shape = cute::make_tuple(
          VariableLength{max_seqlen_q, static_cast<int*>(cumulative_seqlen_q.data_ptr()), total_seqlen_q},
          VariableLength{max_seqlen_kv, static_cast<int*>(cumulative_seqlen_kv.data_ptr()), total_seqlen_kv},
          d, d_vo, cute::make_tuple(num_qo_heads, batch_size));
      // Strides are derived as for a single batch of total length, so the batch stride becomes 0.
      tensor_shape = make_shape(total_seqlen_q, total_seqlen_kv, d, d_vo, make_shape(num_qo_heads, 1));
    } else {
      int q_len = total_seqlen_q / batch_size;
      int kv_len = total_seqlen_kv / batch_size;
      problem_shape = cute::make_tuple(q_len, kv_len, d, d_vo, cute::make_tuple(num_qo_heads, batch_size));
      tensor_shape = problem_shape;
    }

    auto [Q, K, D, D_VO, HB] = tensor_shape;
    auto [H, B] = HB;

    int q_stride0 = q.stride(0), q_stride1 = q.stride(1), q_stride2 = q.stride(2);
    int k_stride0 = k.stride(0), k_stride1 = k.stride(1), k_stride2 = k.stride(2);
    int v_stride0 = v.stride(0), v_stride1 = v.stride(1), v_stride2 = v.stride(2);
    int o_stride0 = o.stride(0), o_stride1 = o.stride(1), o_stride2 = o.stride(2);
    int lse_stride0 = lse.stride(0), lse_stride1 = lse.stride(1);
    int dq_stride0 = dq.stride(0), dq_stride1 = dq.stride(1), dq_stride2 = dq.stride(2);
    int dk_stride0 = dk.stride(0), dk_stride1 = dk.stride(1), dk_stride2 = dk.stride(2);
    int dv_stride0 = dv.stride(0), dv_stride1 = dv.stride(1), dv_stride2 = dv.stride(2);
    int do_stride0 = d_o.stride(0), do_stride1 = d_o.stride(1), do_stride2 = d_o.stride(2);

    // The stride types fix the innermost stride to 1 at compile time.
    TORCH_CHECK(q_stride2 == 1);
    TORCH_CHECK(k_stride2 == 1);
    TORCH_CHECK(v_stride2 == 1);
    TORCH_CHECK(o_stride2 == 1);
    TORCH_CHECK(lse_stride0 == 1);
    TORCH_CHECK(dq_stride2 == 1);
    TORCH_CHECK(dk_stride2 == 1);
    TORCH_CHECK(dv_stride2 == 1);
    TORCH_CHECK(do_stride2 == 1);

    // Batch stride is the per-sequence row stride times the sequence length, or 0 when B == 1.
    StrideQ stride_Q = make_stride(q_stride0, _1{}, make_stride(q_stride1, B == 1 ? 0 : q_stride0*Q));
    StrideK stride_K = make_stride(k_stride0, _1{}, make_stride(k_stride1, B == 1 ? 0 : k_stride0*K));
    StrideV stride_V = make_stride(v_stride0, _1{}, make_stride(v_stride1, B == 1 ? 0 : v_stride0*K));
    StrideO stride_O = make_stride(o_stride0, _1{}, make_stride(o_stride1, B == 1 ? 0 : o_stride0*Q));
    StrideLSE stride_LSE = make_stride(_1{}, make_stride(lse_stride1, B == 1 ? 0 : Q));

    StrideDQ stride_dQ = make_stride(dq_stride0, _1{}, make_stride(dq_stride1, B == 1 ? 0 : dq_stride0*Q));
    StrideDK stride_dK = make_stride(dk_stride0, _1{}, make_stride(dk_stride1, B == 1 ? 0 : dk_stride0*K));
    StrideDV stride_dV = make_stride(dv_stride0, _1{}, make_stride(dv_stride1, B == 1 ? 0 : dv_stride0*K));
    StrideDO stride_dO = make_stride(do_stride0, _1{}, make_stride(do_stride1, B == 1 ? 0 : do_stride0*Q));

    typename Operation::Arguments arguments{
      problem_shape,
      (static_cast<Element*>(q.data_ptr())), stride_Q,
      (static_cast<Element*>(k.data_ptr())), stride_K,
      (static_cast<Element*>(v.data_ptr())), stride_V,
      (static_cast<Element*>(o.data_ptr())), stride_O,
      (static_cast<ElementAccumulator*>(lse.data_ptr())), stride_LSE,
      (static_cast<Element*>(d_o.data_ptr())), stride_dO,
      (static_cast<Element*>(dq.data_ptr())), stride_dQ,
      (static_cast<Element*>(dk.data_ptr())), stride_dK,
      (static_cast<Element*>(dv.data_ptr())), stride_dV,
      static_cast<ElementAccumulator>(softmax_scale),
      hw_info
    };

    Operation op;

    size_t workspace_size = Operation::get_workspace_size(arguments);
    DeviceAllocation<uint8_t> workspace(workspace_size);

    // Initialization and launch share one stream so that any work enqueued by
    // initialize() is ordered before the kernels.
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    CUTLASS_CHECK(op.can_implement(arguments));
    CUTLASS_CHECK(op.initialize(arguments, workspace.get(), stream));
    CUTLASS_CHECK(op.run(stream));
    // NOTE(review): `workspace` is released at scope exit while the kernels may
    // still be in flight; this presumably relies on the device free synchronizing. Confirm.
  }
};
/// Free-function front end for BwdRunner: forwards every argument unchanged to
/// BwdRunner<DType, kIsVarlen, kIsMla, TileShape, Mask>::run.
/// The last parameter is named max_seqlen_kv to match what BwdRunner::run expects.
template <typename DType, bool kIsVarlen, bool kIsMla, typename TileShape, typename Mask>
void run_fmha_bwd(at::Tensor workspace_buffer, at::Tensor d_o, at::Tensor q, at::Tensor k,
                  at::Tensor v, at::Tensor o, at::Tensor lse,
                  at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv,
                  at::Tensor dq, at::Tensor dk, at::Tensor dv,
                  float softmax_scale, int max_seqlen_q, int max_seqlen_kv) {
  BwdRunner<DType, kIsVarlen, kIsMla, TileShape, Mask>::run(workspace_buffer, d_o, q, k, v, o, lse,
                                                            cumulative_seqlen_q, cumulative_seqlen_kv,
                                                            dq, dk, dv,
                                                            softmax_scale, max_seqlen_q, max_seqlen_kv);
}
#include "common/mask.cuh"
#include "common/utils.hpp"
#include "fmha_cutlass_fwd_sm100.cuh"
#include <Python.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda_bf16.h>
#include <torch/library.h>
template <class Mask, class Varlen, class Element, class ElementOut, class Mla>
void call_run_fmha_fwd([[maybe_unused]] Mask mask, [[maybe_unused]] Varlen is_varlen,
[[maybe_unused]] Element in, [[maybe_unused]] ElementOut out,
[[maybe_unused]] Mla mla, at::Tensor workspace_buffer, at::Tensor q,
at::Tensor k, at::Tensor v, at::Tensor cumulative_seqlen_q,
at::Tensor cumulative_seqlen_kv, at::Tensor o, at::Tensor lse,
float softmax_scale, int max_seqlen_q, int max_seqlen_kv) {
static constexpr bool IsVarlen = std::is_same_v<Varlen, true_type>;
static constexpr bool IsMla = std::is_same_v<Mla, true_type>;
static constexpr bool IsCausalMask = std::is_same_v<Mask, CausalMask<false>>;
using Option = std::conditional_t<IsCausalMask, Option<Tag::kIsPersistent, false_type>,
Option<Tag::kIsPersistent, true_type>>;
run_fmha_fwd<Element, ElementOut, IsVarlen, IsMla, Mask, Option>(
workspace_buffer, q, k, v, cumulative_seqlen_q, cumulative_seqlen_kv, o, lse,
softmax_scale, max_seqlen_q, max_seqlen_kv);
}
/// Entry point for the SM100 FMHA forward pass.
/// Selects a kernel instantiation from the runtime mask mode, the varlen flag and
/// the head dimensions, then runs it with the device of `q` made current.
///
/// Supported combinations:
///   - q and k must share a dtype; q and o must both be BFloat16 (anything else
///     hits FLASH_MLA_ASSERT);
///   - (head_dim_qk, head_dim_vo) == (192, 128), dispatched as MLA, or (128, 128).
/// Any other head-dim pair raises an error via TORCH_CHECK. Printing and
/// returning normally would leave o and lse unwritten without the caller knowing.
void FMHACutlassSM100FwdRun(at::Tensor workspace_buffer, at::Tensor q, at::Tensor k,
                            at::Tensor v, at::Tensor cumulative_seqlen_q,
                            at::Tensor cumulative_seqlen_kv, at::Tensor o, at::Tensor lse,
                            int mask_mode_code, float sm_scale, int max_seqlen_q,
                            int max_seqlen_kv, bool is_varlen) {
  const c10::cuda::OptionalCUDAGuard device_guard(q.device());
  CHECK(q.scalar_type() == k.scalar_type());
  auto scalar_type_in = q.scalar_type();
  auto scalar_type_out = o.scalar_type();
  int head_dim_qk = q.size(-1);
  int head_dim_vo = v.size(-1);
  MaskMode mask_mode = static_cast<MaskMode>(mask_mode_code);

  if (scalar_type_in == at::ScalarType::BFloat16 &&
      scalar_type_out == at::ScalarType::BFloat16) {
    using Element = cutlass::bfloat16_t;
    using ElementOut = cutlass::bfloat16_t;

    // Maps the runtime (mask_mode, is_varlen) pair onto tag objects so that `fn`
    // is instantiated once per combination.
    auto apply_config = [&](auto fn) {
      if (mask_mode == MaskMode::kCausal) {
        if (is_varlen) {
          fn(CausalMask<false>{}, cute::true_type{}, Element{}, ElementOut{});
        } else {
          fn(CausalMask<false>{}, cute::false_type{}, Element{}, ElementOut{});
        }
      } else {
        if (is_varlen) {
          fn(ResidualMask{}, cute::true_type{}, Element{}, ElementOut{});
        } else {
          fn(ResidualMask{}, cute::false_type{}, Element{}, ElementOut{});
        }
      }
    };

    apply_config([&](auto mask, auto varlen, auto in, auto out) {
      if (head_dim_qk == 192 && head_dim_vo == 128) {
        call_run_fmha_fwd(mask, varlen, in, out, true_type{}, workspace_buffer, q, k, v,
                          cumulative_seqlen_q, cumulative_seqlen_kv, o, lse, sm_scale,
                          max_seqlen_q, max_seqlen_kv);
      } else if (head_dim_qk == 128 && head_dim_vo == 128) {
        call_run_fmha_fwd(mask, varlen, in, out, false_type{}, workspace_buffer, q, k, v,
                          cumulative_seqlen_q, cumulative_seqlen_kv, o, lse, sm_scale,
                          max_seqlen_q, max_seqlen_kv);
      } else {
        // Fail loudly: no kernel ran, so the output tensors hold no valid data.
        TORCH_CHECK(false, "No kernel instantiated for head_dim_qk=", head_dim_qk,
                    " head_dim_vo=", head_dim_vo);
      }
    });
  } else {
    FLASH_MLA_ASSERT(false);
  }
}
#pragma once
#include "collective/fmha_fusion.hpp"
#include "collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp"
#include "collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp"
#include "collective/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/kernel_hardware_info.h"
#include "device/fmha.hpp"
#include "kernel/fmha_causal_tile_scheduler.hpp"
#include "kernel/fmha_options.hpp"
#include "kernel/fmha_tile_scheduler.hpp"
#include "kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp"
#include <torch/library.h>
#include <c10/cuda/CUDAStream.h>
using namespace cute;
using namespace cutlass::fmha::collective;
using namespace cutlass::fmha::kernel;
using namespace cutlass::fmha::device;
/// Plain bundle of integer problem parameters with default values.
/// NOTE(review): field meanings are inferred from their names (b = batch,
/// h / h_k = head counts, q / k = sequence lengths, d = head dim); confirm against users.
struct FmhaOptions {
  int b{1};
  int h{1};
  int h_k{1};
  int q{256};
  int k{256};
  int d{128};
};
/// Plain bundle of integer problem parameters for the MLA variant, with default values.
/// NOTE(review): b, h, h_k, q and k are presumably batch, head counts and
/// sequence lengths (inferred from their names); confirm against users.
struct MlaOptions {
  int b{1};
  int h{1};
  int h_k{1};
  int q{256};
  int k{256};
  int dl{128};  // headdim latent
  int dr{64};   // headdim rope
};
/// Host-side driver that assembles the CUTLASS FMHA / MLA forward operation type from
/// compile-time options and launches it on `at::Tensor` inputs.
///
/// Template parameters:
///   kIsMla                     - select the MLA mainloop/tile shape instead of the FMHA one.
///   kIsMaskTileSchedulerValid  - in the non-persistent case, select
///                                CausalIndividualTileScheduler over IndividualTileScheduler.
///   kIsVarlen                  - use VariableLength sequence extents driven by
///                                cumulative-length arrays.
///   Element_ / ElementOut_     - input and output element types.
///   ActiveMask                 - mask type forwarded to the mainloop.
///   KernelOptions...           - extra options; only Tag::kIsPersistent is inspected here.
template <bool kIsMla, bool kIsMaskTileSchedulerValid, bool kIsVarlen, class Element_,
          class ElementOut_, class ActiveMask, class... KernelOptions>
struct FwdRunner {
  using Element = Element_;
  using ElementAccumulatorQK = float;
  using ElementAccumulatorPV = float;
  using ElementOut = ElementOut_;

  // Tile shapes: MLA uses a (latent, rope) = (128, 64) split head dimension.
  using HeadDimLatent = _128;
  using HeadDim = Shape<HeadDimLatent, _64>;
  using TileShapeMla = Shape<_256, _128, HeadDim>;
  using TileShapeFmha = Shape<_256, _128, _128>;
  using TileShape = std::conditional_t<kIsMla, TileShapeMla, TileShapeFmha>;

  // Problem shape is (Q, K, D, ((H_R, H_K), B)); for MLA, D is the pair (D_latent, D_rope).
  using ProblemShapeRegular = std::conditional_t<
      kIsMla,
      cute::tuple<int, int, cute::tuple<int, int>, cute::tuple<cute::tuple<int, int>, int>>,
      cute::tuple<int, int, int, cute::tuple<cute::tuple<int, int>, int>>>;
  using ProblemShapeVarlen =
      std::conditional_t<kIsMla,
                         cute::tuple<VariableLength, VariableLength, cute::tuple<int, int>,
                                     cute::tuple<cute::tuple<int, int>, int>>,
                         cute::tuple<VariableLength, VariableLength, int,
                                     cute::tuple<cute::tuple<int, int>, int>>>;
  using ProblemShapeType =
      std::conditional_t<kIsVarlen, ProblemShapeVarlen, ProblemShapeRegular>;

  // Strides are (seq, dim, ((head_r, head_k), batch)); K/V use a static zero stride in the
  // head_r slot. LSE has no dim mode.
  using StrideQ = cute::tuple<int, _1, cute::tuple<cute::tuple<int, int>, int>>;
  using StrideK = cute::tuple<int, _1, cute::tuple<cute::tuple<_0, int>, int>>;
  using StrideV = StrideK;
  using StrideO = StrideQ;
  using StrideLSE = cute::tuple<_1, cute::tuple<cute::tuple<int, int>, int>>;

  // Persistent scheduling defaults to true unless overridden through KernelOptions.
  static constexpr bool kIsPersistent =
      find_option_t<Tag::kIsPersistent, true_type, KernelOptions...>::value;
  using TileScheduler = std::conditional_t<
      kIsPersistent,
      std::conditional_t<std::is_same_v<ActiveMask, CausalMask<false>> ||
                             std::is_same_v<ActiveMask, CausalMask<true>>,
                         cutlass::fmha::kernel::CausalPersistentTileScheduler,
                         cutlass::fmha::kernel::PersistentTileScheduler>,
      std::conditional_t<kIsMaskTileSchedulerValid,
                         cutlass::fmha::kernel::CausalIndividualTileScheduler,
                         cutlass::fmha::kernel::IndividualTileScheduler>>;

  static constexpr bool IsOrderLoadEpilogue =
      kIsPersistent && (sizeof(Element) == sizeof(ElementOut));
  using OrderLoadEpilogue = std::conditional_t<IsOrderLoadEpilogue, true_type, false_type>;

  using MainloopMla = cutlass::fmha::collective::Sm100MlaFwdMainloopTmaWarpspecialized<
      Element, ElementAccumulatorQK, ElementAccumulatorPV, TileShapeMla, StrideQ, StrideK,
      StrideV, ActiveMask, Shape<_2, _1, _1>, OrderLoadEpilogue>;
  using OperationMla =
      cutlass::fmha::device::FMHA<cutlass::fmha::kernel::Sm100FmhaFwdKernelTmaWarpspecialized<
          ProblemShapeType, MainloopMla,
          cutlass::fmha::collective::Sm100FmhaFwdEpilogueTmaWarpspecialized<
              ElementOut, ElementAccumulatorPV, typename MainloopMla::TileShapePV, StrideO,
              StrideLSE, OrderLoadEpilogue>,
          TileScheduler, cutlass::fmha::kernel::Sm100MlaFwdCtxKernelWarpspecializedSchedule>>;

  using MainloopFmha = cutlass::fmha::collective::Sm100FmhaFwdMainloopTmaWarpspecialized<
      Element, ElementAccumulatorQK, ElementAccumulatorPV, TileShapeFmha, StrideQ, StrideK,
      StrideV, ActiveMask>;
  using OperationFmha =
      cutlass::fmha::device::FMHA<cutlass::fmha::kernel::Sm100FmhaFwdKernelTmaWarpspecialized<
          ProblemShapeType, MainloopFmha,
          cutlass::fmha::collective::Sm100FmhaFwdEpilogueTmaWarpspecialized<
              ElementOut, ElementAccumulatorPV, typename MainloopFmha::TileShapePV, StrideO,
              StrideLSE>,
          TileScheduler>>;

  using Mainloop = std::conditional_t<kIsMla, MainloopMla, MainloopFmha>;
  using Operation = std::conditional_t<kIsMla, OperationMla, OperationFmha>;

  //
  // Data members
  //

  /// Strides handed to the kernel; populated by run().
  StrideQ stride_Q;
  StrideK stride_K;
  StrideV stride_V;
  StrideO stride_O;
  StrideLSE stride_LSE;

  /// Derives the two problem shapes used in variable-length mode.
  ///
  /// Returns a tuple (shape_for_init, shape_for_launch):
  ///   - shape_for_init  : `problem_size` with batch forced to 1 and the Q/K extents replaced
  ///                       by the total (packed) sequence lengths.
  ///   - shape_for_launch: ProblemShapeType whose Q/K extents are VariableLength objects
  ///                       built from the maximum sequence lengths; D and the head/batch
  ///                       modes are copied from `problem_size`.
  template <class ProblemShape>
  auto initialize_varlen(const ProblemShape &problem_size, int max_seqlen_q, int max_seqlen_kv,
                         int total_seqlen_q, int total_seqlen_kv) {
    ProblemShape problem_size_for_init = problem_size;
    get<3, 1>(problem_size_for_init) = 1;
    get<0>(problem_size_for_init) = total_seqlen_q;
    get<1>(problem_size_for_init) = total_seqlen_kv;

    ProblemShapeType problem_size_for_launch;
    get<0>(problem_size_for_launch) = VariableLength{max_seqlen_q};
    get<1>(problem_size_for_launch) = VariableLength{max_seqlen_kv};
    get<2>(problem_size_for_launch) = get<2>(problem_size);
    get<3>(problem_size_for_launch) = get<3>(problem_size);

    return cute::make_tuple(problem_size_for_init, problem_size_for_launch);
  }

  /// Packs an options struct into (Q, K, D, ((H_R, H_K), B)), where H_R = h / h_k.
  /// For MlaOptions, D is the pair (dl, dr); otherwise it is the scalar d.
  /// The caller must guarantee options.h_k != 0.
  template <class Options>
  static constexpr auto get_problem_shape(const Options &options) {
    int h_r = options.h / options.h_k;
    if constexpr (std::is_same_v<Options, MlaOptions>) {
      return cute::make_tuple(options.q, options.k, cute::make_tuple(options.dl, options.dr),
                              cute::make_tuple(cute::make_tuple(h_r, options.h_k), options.b));
    } else {
      return cute::make_tuple(options.q, options.k, options.d,
                              cute::make_tuple(cute::make_tuple(h_r, options.h_k), options.b));
    }
  }

  /// Builds the problem shape that is passed to the kernel.
  ///
  /// In variable-length mode the Q/K extents become VariableLength objects (seeded with the
  /// max sequence lengths) and their cumulative_length pointers are set from the given
  /// buffers, which are reinterpreted as `int *`. Otherwise the shape comes straight from
  /// `options`.
  ///
  /// Fails with a c10::Error (TORCH_CHECK) when h_k is not positive or does not divide h.
  template <class Options>
  ProblemShapeType initialize(const Options &options, int max_seqlen_q, int max_seqlen_kv,
                              int total_seqlen_q, int total_seqlen_kv,
                              void *cumulative_length_q, void *cumulative_length_kv) {
    // Validate with TORCH_CHECK rather than assert so release builds are protected too:
    // get_problem_shape() divides by h_k.
    TORCH_CHECK(options.h_k > 0, "number of KV heads must be positive, got ", options.h_k);
    TORCH_CHECK(options.h % options.h_k == 0, "number of Q heads (", options.h,
                ") must be a multiple of the number of KV heads (", options.h_k, ")");

    auto problem_shape_in = get_problem_shape(options);
    ProblemShapeType problem_shape;
    if constexpr (kIsVarlen) {
      auto varlen_shapes = initialize_varlen(problem_shape_in, max_seqlen_q, max_seqlen_kv,
                                             total_seqlen_q, total_seqlen_kv);
      // Only the launch shape (element 1) is needed here.
      problem_shape = get<1>(varlen_shapes);
      get<0>(problem_shape).cumulative_length = static_cast<int *>(cumulative_length_q);
      get<1>(problem_shape).cumulative_length = static_cast<int *>(cumulative_length_kv);
    } else {
      problem_shape = problem_shape_in;
    }
    return problem_shape;
  }

  /// Assembles `Operation::Arguments` from the problem shape, the raw tensor pointers and the
  /// strides currently stored in this runner (stride_Q/K/V/O/LSE must already be set).
  /// In variable-length mode the cumulative-length pointers are (re)attached to the shape.
  auto get_arguments(const ProblemShapeType &problem_shape,
                     const cutlass::KernelHardwareInfo &hw_info, float scale_softmax,
                     void *q_ptr, void *k_ptr, void *v_ptr, void *o_ptr, void *lse_ptr,
                     void *cumulative_length_q, void *cumulative_length_kv) {
    auto problem_shape_ = problem_shape;
    if constexpr (kIsVarlen) {
      get<0>(problem_shape_).cumulative_length = static_cast<int *>(cumulative_length_q);
      get<1>(problem_shape_).cumulative_length = static_cast<int *>(cumulative_length_kv);
    }

    typename Operation::Arguments arguments{
        problem_shape_,
        {static_cast<Element *>(q_ptr), stride_Q, static_cast<Element *>(k_ptr), stride_K,
         static_cast<Element *>(v_ptr), stride_V, scale_softmax},
        {static_cast<ElementOut *>(o_ptr), stride_O,
         static_cast<ElementAccumulatorPV *>(lse_ptr), stride_LSE},
        hw_info};
    return arguments;
  }

  /// Computes strides from the tensors, builds the kernel arguments and launches the
  /// operation on the current CUDA stream.
  ///
  /// q/k/v/o are indexed with three strides and must be contiguous in dimension 2; lse is
  /// indexed with two strides and must be contiguous in dimension 0. Dimension 0 of q and k
  /// is taken as the total (packed) sequence length.
  /// NOTE(review): dimension 1 of q/k/v/o is presumably the head dimension - confirm with caller.
  ///
  /// `workspace` is accepted for interface compatibility but is not used: the operation is
  /// initialized with a null workspace pointer.
  template <class Options>
  void run(const Options &options, const cutlass::KernelHardwareInfo &hw_info, at::Tensor q,
           at::Tensor k, at::Tensor v, at::Tensor o, at::Tensor lse, float scale_softmax,
           at::Tensor workspace, at::Tensor cumulative_seqlen_q,
           at::Tensor cumulative_seqlen_kv, int max_seqlen_q, int max_seqlen_kv) {
    if constexpr (kIsVarlen) {
      // The buffers are reinterpreted as `int *` below, so anything but int32 would be
      // silently misread by the kernel.
      TORCH_CHECK(cumulative_seqlen_q.scalar_type() == at::ScalarType::Int,
                  "cumulative_seqlen_q must be an int32 tensor");
      TORCH_CHECK(cumulative_seqlen_kv.scalar_type() == at::ScalarType::Int,
                  "cumulative_seqlen_kv must be an int32 tensor");
    }

    int total_seqlen_q = q.size(0);
    int total_seqlen_kv = k.size(0);

    ProblemShapeType problem_shape =
        initialize(options, max_seqlen_q, max_seqlen_kv, total_seqlen_q, total_seqlen_kv,
                   cumulative_seqlen_q.data_ptr(), cumulative_seqlen_kv.data_ptr());

    int SQ = size<0>(problem_shape);
    int SK = size<1>(problem_shape);
    int H_Q = size<3, 0, 0>(problem_shape);

    int q_stride0 = q.stride(0), q_stride1 = q.stride(1), q_stride2 = q.stride(2);
    int k_stride0 = k.stride(0), k_stride1 = k.stride(1), k_stride2 = k.stride(2);
    int v_stride0 = v.stride(0), v_stride1 = v.stride(1), v_stride2 = v.stride(2);
    int o_stride0 = o.stride(0), o_stride1 = o.stride(1), o_stride2 = o.stride(2);
    int lse_stride0 = lse.stride(0), lse_stride1 = lse.stride(1);

    // The stride types hard-code a unit stride (_1) for these modes.
    TORCH_CHECK(q_stride2 == 1, "q must be contiguous in its last dimension");
    TORCH_CHECK(k_stride2 == 1, "k must be contiguous in its last dimension");
    TORCH_CHECK(v_stride2 == 1, "v must be contiguous in its last dimension");
    TORCH_CHECK(o_stride2 == 1, "o must be contiguous in its last dimension");
    TORCH_CHECK(lse_stride0 == 1, "lse must be contiguous in its first dimension");

    // Layout per tensor: (seq, dim, ((head_r, head_k), batch)). The batch stride assumes
    // each batch occupies SQ (or SK) consecutive rows of dimension 0.
    stride_Q = make_stride(q_stride0, _1{},
                           make_stride(make_stride(q_stride1, H_Q * q_stride1), SQ * q_stride0));
    stride_O = make_stride(o_stride0, _1{},
                           make_stride(make_stride(o_stride1, H_Q * o_stride1), SQ * o_stride0));
    stride_K = make_stride(k_stride0, _1{},
                           make_stride(make_stride(_0{}, k_stride1), SK * k_stride0));
    stride_V = make_stride(v_stride0, _1{},
                           make_stride(make_stride(_0{}, v_stride1), SK * v_stride0));
    stride_LSE =
        make_stride(_1{}, make_stride(make_stride(lse_stride1, lse_stride1 * H_Q), SQ));

    if constexpr (kIsVarlen) {
      // Variable-length mode zeroes every batch stride.
      // NOTE(review): presumably the kernel offsets per batch via cumulative_length - confirm.
      get<2, 1>(stride_Q) = 0;
      get<2, 1>(stride_K) = 0;
      get<2, 1>(stride_V) = 0;
      get<2, 1>(stride_O) = 0;
      get<1, 1>(stride_LSE) = 0;
    }

    typename Operation::Arguments arguments =
        get_arguments(problem_shape, hw_info, scale_softmax, q.data_ptr(), k.data_ptr(),
                      v.data_ptr(), o.data_ptr(), lse.data_ptr(),
                      cumulative_seqlen_q.data_ptr(), cumulative_seqlen_kv.data_ptr());

    Operation op;

    // size_t workspace_size = 0;
    // workspace_size = Operation::get_workspace_size(arguments);

    // todo: if use workspace, need check workspace size first.
    // we don't use workspace in current version.

    CUTLASS_CHECK(op.can_implement(arguments));
    CUTLASS_CHECK(op.initialize(arguments, nullptr));
    CUTLASS_CHECK(op.run(at::cuda::getCurrentCUDAStream()));
  }
};
/// Entry point for the FMHA / MLA forward pass on `at::Tensor` inputs.
///
/// Derives the problem sizes from the tensor shapes, queries the SM count of the device
/// holding `q`, picks the tile-scheduler variant and forwards everything to FwdRunner::run.
///
/// Shape conventions read by this function:
///   - cumulative_seqlen_q.size(0) - 1 is the batch count (must be positive).
///   - q.size(1) / k.size(1) are the query / key-value head counts.
///   - q.size(0) / k.size(0) divided by the batch count give the per-batch sequence lengths
///     stored in the options (integer division).
///   - For MLA: v.size(-1) is the latent head dim and q.size(-1) - v.size(-1) the rope dim;
///     otherwise q.size(-1) is the head dim.
///
/// `workspace` is only forwarded to the runner.
template <class DTypeIn, class DTypeOut, bool kIsVarlen, bool kIsMla, class ActiveMask,
          class... KernelOptions>
void run_fmha_fwd(at::Tensor workspace, at::Tensor q, at::Tensor k, at::Tensor v,
                  at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv, at::Tensor o,
                  at::Tensor lse, float scale_softmax, int max_seqlen_q, int max_seqlen_kv) {
  TORCH_CHECK(q.is_cuda(), "q must be a CUDA tensor");

  // Query the device that actually holds the inputs; a hard-coded device 0 reports the wrong
  // SM count on multi-GPU hosts whose GPUs differ or when running on another device.
  cutlass::KernelHardwareInfo hw_info;
  hw_info.device_id = static_cast<int>(q.get_device());
  hw_info.sm_count =
      cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);

  std::conditional_t<kIsMla, MlaOptions, FmhaOptions> options;
  options.b = cumulative_seqlen_q.size(0) - 1;
  // The batch count is used as a divisor just below.
  TORCH_CHECK(options.b > 0,
              "cumulative_seqlen_q must contain at least 2 entries (batch + 1), got ",
              cumulative_seqlen_q.size(0));
  options.h = q.size(1);
  options.h_k = k.size(1);
  options.q = q.size(0) / options.b;
  options.k = k.size(0) / options.b;
  if constexpr (kIsMla) {
    options.dl = v.size(-1);
    options.dr = q.size(-1) - v.size(-1);
  } else {
    options.d = q.size(-1);
  }

  // The mask-aware individual scheduler is selected only for masked kernels whose head count
  // is a multiple of its TileH; everything else uses the plain variant.
  if (options.h % cutlass::fmha::kernel::CausalIndividualTileScheduler::TileH == 0 &&
      (!std::is_same_v<ActiveMask, NoMask>)) {
    FwdRunner<kIsMla, true, kIsVarlen, DTypeIn, DTypeOut, ActiveMask, KernelOptions...> runner;
    runner.run(options, hw_info, q, k, v, o, lse, scale_softmax, workspace, cumulative_seqlen_q,
               cumulative_seqlen_kv, max_seqlen_q, max_seqlen_kv);
  } else {
    FwdRunner<kIsMla, false, kIsVarlen, DTypeIn, DTypeOut, ActiveMask, KernelOptions...> runner;
    runner.run(options, hw_info, q, k, v, o, lse, scale_softmax, workspace, cumulative_seqlen_q,
               cumulative_seqlen_kv, max_seqlen_q, max_seqlen_kv);
  }
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment