Commit e2e0225c authored by zhanghj2's avatar zhanghj2
Browse files

空kernel可以编译通过

parent 48c6dc42
#pragma once
#include "kernel.h"
#include <cuda_fp8.h>
#include <cutlass/barrier.h>
#include <cute/tensor.hpp>
#include <kerutils/kerutils.cuh>
#include "defines.h"
#include "params.h"
namespace sm100::decode::head64 {
using cutlass::arch::fence_view_async_shared;
using cutlass::arch::NamedBarrier;
using e8m0 = __nv_fp8_e8m0;
using e4m3 = cutlass::float_e4m3_t;
using namespace cute;
enum NamedBarriers : uint32_t {
main_loop_sync = 0,
wg0_sync = 1,
wg0_warp02_sync = 2,
wg0_warp13_sync = 3,
everyone_sync = 4
};
template<ModelType MODEL_TYPE>
struct KernelTemplate {
static constexpr int D_Q = MODEL_TYPE == ModelType::V32 ? 576 : 512;
static constexpr int D_K = D_Q;
static constexpr int D_V = 512;
static constexpr int D_NOPE = MODEL_TYPE == ModelType::V32 ? 512 : 448;
static constexpr int D_ROPE = 64;
static constexpr int QUANT_TILE_SIZE = MODEL_TYPE == ModelType::V32 ? 128 : 64;
static constexpr bool V_HAVE_ROPE = MODEL_TYPE == ModelType::V32 ? false : true;
static constexpr int NUM_SCALES_EACH_TOKEN = MODEL_TYPE == ModelType::V32 ? 4 : 8; // Padding is included
static constexpr int TMA_K_STRIDE = MODEL_TYPE == ModelType::V32 ? D_NOPE+2*D_ROPE+4*(D_NOPE/QUANT_TILE_SIZE) : D_NOPE+2*D_ROPE; // Stride of K's tensormap. This stride must 1) be a factor of the actual stride between tokens 2) large enough to cover the entire KV cache. Since TMA copy's coordinate can only be 32bit signed integers, this number must >= 128, perferrably >= 256. So we set this to 656 for V32 and 576 for MODEL1. Extra padding may be necessary for KV blocks.
static_assert(D_NOPE + D_ROPE == D_Q);
static_assert(V_HAVE_ROPE ? (D_NOPE + D_ROPE == D_V) : (D_NOPE == D_V));
static constexpr int B_H = 64;
static constexpr int B_TOPK = 64;
static constexpr int NUM_BUFS = 2;
static constexpr int NUM_INDEX_BUFS = 4; // Number of buffers for indices (tma_coords) & is_token_valid & scales
static constexpr int NUM_THREADS = 128*3; // 128 exp + 1/32 utcmma + 1/32 raw KV producer + 1/32 rope producer + 32 index+scale+valid_mask producer + 128 dequant
static constexpr float MAX_INIT_VAL = -1e30f; // To avoid (-inf) - (-inf) = NaN
static constexpr int D_Q_SW128 = 512;
static constexpr int D_Q_SW64 = MODEL_TYPE == ModelType::V32 ? 64 : 0;
static_assert(D_Q_SW128 + D_Q_SW64 == D_Q);
static constexpr int K_ROPE_SW = MODEL_TYPE == ModelType::V32 ? 64 : 128; // RoPE part stored in SW64 (for V32) or SW128 (for MODEL1), in bytes
template<
typename Shape_Q_SW128, typename TMA_Q_SW128,
typename Shape_O, typename TMA_O
>
struct TmaParams {
Shape_Q_SW128 shape_Q_SW128; TMA_Q_SW128 tma_Q_SW128;
Shape_O shape_O; TMA_O tma_O;
CUtensorMap tensor_map_q_sw64; // Invalid if D_Q_SW64 == 0
CUtensorMap tensor_map_kv_nope;
CUtensorMap tensor_map_kv_rope;
CUtensorMap tensor_map_extra_kv_nope;
CUtensorMap tensor_map_extra_kv_rope;
};
// Tensor memory columns
struct tmem_cols {
// 0 ~ 256: output
// 256 ~ 256 + 64*D_Q/256: Q
// 400 ~ 464: P
static constexpr int O = 0;
static constexpr int Q = 256;
static constexpr int Q_Tail = 256 + B_H*D_NOPE/2/128;
static constexpr int P = 400;
};
template<int NUM_TILES>
using SmemLayoutQTiles = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_H>, Int<NUM_TILES*64>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
using SmemLayoutQ_SW128 = SmemLayoutQTiles<D_Q_SW128/64>;
using SmemLayoutOBuf = decltype(tile_to_shape(
UMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_H>, Int<D_V>>{}
));
using SmemLayoutOBuf_TMA = decltype(tile_to_shape(
UMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_H>, Int<64>>{}
)); // A TMA tile
static_assert(D_V == 512);
using SmemLayoutOAccumBuf = Layout<
Shape<Int<B_H>, Int<D_V>>,
Stride<Int<520>, _1> // We use stride = 520 here to avoid bank conflict
>;
using SmemLayoutS = decltype(tile_to_shape(
UMMA::Layout_K_INTER_Atom<bf16>{},
Shape<Int<B_H>, Int<B_TOPK>>{},
Step<_1, _2>{}
));
template<int NUM_TILES>
using SmemLayoutKTiles_SW128 = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_H>, Int<64*NUM_TILES>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
template<int NUM_TILES>
using SmemLayoutKTiles_DualGemm_SW128 = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_H*2>, Int<64*NUM_TILES>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
template<int NUM_TILES>
using SmemLayoutKTilesTransposed_SW128 = decltype(composition(
SmemLayoutKTiles_SW128<NUM_TILES>{},
Layout<
Shape<Int<64*NUM_TILES>, Int<B_TOPK>>,
Stride<Int<B_TOPK>, _1>
>{}
));
template<int NUM_TILES>
using SmemLayoutKTiles_SW64 = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_SW64_Atom<bf16>{},
Shape<Int<B_H>, Int<32*NUM_TILES>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
template<int NUM_TILES>
using SmemLayoutKTiles_DualGemm_SW64 = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_SW64_Atom<bf16>{},
Shape<Int<B_H*2>, Int<32*NUM_TILES>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
template<int NUM_TILES>
using SmemLayoutKTilesTransposed_SW64 = decltype(composition(
SmemLayoutKTiles_SW64<NUM_TILES>{},
Layout<
Shape<Int<32*NUM_TILES>, Int<B_TOPK>>,
Stride<Int<B_TOPK>, _1>
>{}
));
struct SharedMemoryPlan {
union {
struct {
array_aligned<bf16, cosize_v<SmemLayoutQ_SW128>> q;
bf16 q_sw64[B_H*D_Q_SW64]; // NOTE D_Q_SW64 may be 0 but array_aligned<bf16, 0> will have a size of 16, so we use array here. The former tensor (`q`) promises its alignment.
union {
array_aligned<bf16, cosize_v<SmemLayoutOBuf>> o_buf;
array_aligned<float, cosize_v<SmemLayoutOAccumBuf>> o_accum_buf;
} o;
} qo;
struct {
struct {
array_aligned<bf16, B_H*D_NOPE> nope; // NoPE part, dequantized
array_aligned<bf16, B_H*D_ROPE> rope; // RoPE part, dequantized. SW64 in v32 mode, SW128 in MODEL1 mode
} dequant[NUM_BUFS];
static_assert(sizeof(dequant) >= sizeof(bf16) * (B_H*D_Q)); // So that Q does not covers raw_nope
array_aligned<e4m3, B_H*D_NOPE> raw_nope[NUM_BUFS]; // Raw (quantized) NoPE part
} kv;
} u;
union {
float4 p_exchange_buf[4][16 * B_TOPK / 4];
array_aligned<bf16, cosize_v<SmemLayoutS>> s;
} s_p;
CUTE_ALIGNAS(16) float rowwise_max_buf[128];
char is_token_valid[NUM_INDEX_BUFS][B_TOPK/8];
int tma_coord[NUM_INDEX_BUFS][B_TOPK];
e8m0 scales[NUM_INDEX_BUFS][B_TOPK][NUM_SCALES_EACH_TOKEN];
array_aligned<uint32_t, 1> tmem_start_addr;
transac_bar_t bar_last_store_done;
transac_bar_t bar_q_tma, bar_q_utccp;
transac_bar_t bar_rope_ready[NUM_BUFS];
transac_bar_t bar_nope_ready[NUM_BUFS];
transac_bar_t bar_raw_ready[NUM_BUFS], bar_raw_free[NUM_BUFS];
transac_bar_t bar_valid_coord_scale_ready[NUM_INDEX_BUFS], bar_valid_coord_scale_free[NUM_INDEX_BUFS];
transac_bar_t bar_qk_done[NUM_BUFS], bar_so_ready[NUM_BUFS], bar_sv_done[NUM_BUFS];
};
using TiledMMA_P = decltype(make_tiled_mma(
SM100_MMA_F16BF16_WS_TS_NOELECT<bf16, bf16, float, B_H, B_TOPK*2, UMMA::Major::K, UMMA::Major::K>{}
)); // *2 for dual gemm
using TiledMMA_O = decltype(make_tiled_mma(
SM100_MMA_F16BF16_WS_SS_NOELECT<bf16, bf16, float, B_H, 256, UMMA::Major::K, UMMA::Major::MN>{}
));
template<typename TmaParam>
static __device__ void
flash_fwd_splitkv_mla_fp8_sparse_kernel_devfunc(const SparseAttnDecodeParams &params, const TmaParam &tma_params);
static void run(const SparseAttnDecodeParams &params);
};
}
\ No newline at end of file
#include "../kernel.cuh"
namespace sm100::decode::head64 {
template
void run_flash_splitkv_mla_fp8_sparse_kernel<ModelType::MODEL1>(const SparseAttnDecodeParams &params);
}
#include "../kernel.cuh"
namespace sm100::decode::head64 {
template
void run_flash_splitkv_mla_fp8_sparse_kernel<ModelType::V32>(const SparseAttnDecodeParams &params);
}
This diff is collapsed.
#pragma once
#include "params.h"
namespace sm100::decode::head64 {
template<ModelType MODEL_TYPE>
void run_flash_splitkv_mla_fp8_sparse_kernel(const SparseAttnDecodeParams &params);
}
#pragma once
#include <cute/tensor.hpp>
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include "defines.h"
namespace sm100 {
using namespace cute;
CUTE_DEVICE
int int4_max(int4 t) {
return max(max(t.x, t.y), max(t.z, t.w));
}
CUTE_DEVICE
int int4_min(int4 t) {
return min(min(t.x, t.y), min(t.z, t.w));
}
// Convert 2x fp8_e4m3 to 2x bf16 with scaling
CUTE_DEVICE
nv_bfloat162 fp8x2_to_bf16x2_with_scale(__nv_fp8x2_e4m3 data, nv_bfloat16 scale) {
// TODO Use native conversion for CUDA >= 13.1
float2 data_float2 = (float2)data;
nv_bfloat162 data_bf16x2 = __float22bfloat162_rn(data_float2);
return nv_bfloat162 {
data_bf16x2.x * scale,
data_bf16x2.y * scale
};
}
}
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/kernel_hardware_info.h"
#include "cutlass/arch/reg_reconfig.h"
#include "cute/tensor.hpp"
namespace cutlass::fmha::collective {
using namespace cute;
template<typename Atom, typename TA, typename TB, typename TC>
CUTE_DEVICE void gemm_reset_zero_acc(Atom& atom, TA const& tA, TB const& tB, TC&& tC) {
constexpr int rA = decltype(rank(tA))::value;
constexpr int rB = decltype(rank(tB))::value;
constexpr int rC = decltype(rank(tC))::value;
static_assert(rA == 3 && rB == 3 && rC == 3);
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tA); k_block++) {
cute::gemm(atom, tA(_,_,k_block), tB(_,_,k_block), tC);
atom.accumulate_ = decltype(atom.accumulate_)::One;
}
}
template<typename Atom, typename TA, typename TB, typename TC>
CUTE_DEVICE void gemm_zero_acc(Atom& atom, TA const& tA, TB const& tB, TC&& tC) {
atom.accumulate_ = decltype(atom.accumulate_)::Zero;
gemm_reset_zero_acc(atom, tA, tB, tC);
}
template<class Layout, class Stages = _1>
CUTE_DEVICE constexpr auto unstageSmemLayout(Layout const& layout, Stages stages = {}) {
return composition(layout, prepend<decltype(rank(layout))::value>(make_layout(stages), _));
}
template<class T>
CUTE_DEVICE T warp_uniform(T a) {
return __shfl_sync(0xffffffff, a, 0);
}
template <class a_type, class b_type, class c_type,
int M, int N, UMMA::Major a_major, UMMA::Major b_major,
UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg, class... TAs, class... TMs>
CUTE_HOST_DEVICE constexpr
auto
to_tiled_mma_sm100_ts(
TiledMMA<MMA_Atom<
MMA_Traits<SM100_MMA_F8F6F4_SS, a_type, b_type, c_type,
cute::C<M>, cute::C<N>,
cute::integral_constant<UMMA::Major, a_major>,
cute::integral_constant<UMMA::Major, b_major>,
cute::integral_constant<UMMA::ScaleIn, a_neg>,
cute::integral_constant<UMMA::ScaleIn, b_neg>>,
TAs...>, TMs...>) {
return TiledMMA<MMA_Atom<
MMA_Traits<SM100_MMA_F8F6F4_TS<a_type, b_type, c_type,
M, N,
a_major, b_major,
a_neg, b_neg, UMMA::Saturate::False>>,
TAs...>, TMs...>{};
}
template <class a_type, class b_type, class c_type,
int M, int N, UMMA::Major a_major, UMMA::Major b_major,
UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg, class... TAs, class... TMs>
CUTE_HOST_DEVICE constexpr
auto
to_tiled_mma_sm100_ts(
TiledMMA<MMA_Atom<
SM100_MMA_F16BF16_SS<a_type, b_type, c_type,
M, N,
a_major,
b_major,
a_neg,
b_neg>,
TAs...>, TMs...>) {
return TiledMMA<MMA_Atom<
SM100_MMA_F16BF16_TS<a_type, b_type, c_type,
M, N,
a_major, b_major,
a_neg, b_neg, UMMA::Saturate::False>,
TAs...>, TMs...>{};
}
template<uint32_t RegCount>
CUTLASS_DEVICE
void warpgroup_reg_set() {
if constexpr (RegCount < 128) {
cutlass::arch::warpgroup_reg_dealloc<RegCount>();
}
else {
cutlass::arch::warpgroup_reg_alloc<RegCount>();
}
}
} // namespace cutlass::fmha::collective
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
namespace cutlass::fmha::collective {
using namespace cute;
struct NoMask {
template<class BlkCoord, class TileShape, class ProblemSize>
CUTLASS_DEVICE
int get_trip_count(
BlkCoord const& blk_coord,
TileShape const& tile_shape,
ProblemSize const& problem_size) {
return ceil_div(get<1>(problem_size), get<1>(tile_shape));
}
template<class BlkCoord, class TileShape, class ProblemSize>
CUTLASS_DEVICE
int get_masked_trip_count(
BlkCoord const& blk_coord,
TileShape const& tile_shape,
ProblemSize const& problem_size) {
return 0;
}
template<class BlkCoord, class TileShape, class ProblemSize>
CUTLASS_DEVICE
int get_unmasked_trip_count(
BlkCoord const& blk_coord,
TileShape const& tile_shape,
ProblemSize const& problem_size) {
return get_trip_count(blk_coord, tile_shape, problem_size);
}
template<class AccQK, class IndexQK, class ProblemSize>
CUTLASS_DEVICE
void apply_mask(
AccQK& acc_qk,
IndexQK const& index_qk,
ProblemSize const& problem_size) {
return;
}
};
struct ResidualMask : NoMask {
using Base = NoMask;
template <class BlkCoord, class TileShape, class ProblemSize>
CUTLASS_DEVICE int get_masked_trip_count(
BlkCoord const& blk_coord,
TileShape const& tile_shape,
ProblemSize const& problem_size) {
if (get<1>(problem_size) % get<1>(tile_shape) != 0) {
return 1;
}
return 0;
}
template<class BlkCoord, class TileShape, class ProblemSize>
CUTLASS_DEVICE
int get_unmasked_trip_count(
BlkCoord const& blk_coord,
TileShape const& tile_shape,
ProblemSize const& problem_size) {
// if the sequence length does not divide the tile size evenly
if (get<1>(problem_size) % get<1>(tile_shape) != 0) {
return get_trip_count(blk_coord, tile_shape, problem_size) - 1;
}
return get_trip_count(blk_coord, tile_shape, problem_size);
}
template<class AccQK, class IndexQK, class ProblemSize>
CUTLASS_DEVICE
void apply_mask(
AccQK& acc_qk,
IndexQK const& index_qk,
ProblemSize const& problem_size) {
// This is useful is seqlen_k % kBlockN != 0 since it masks
// the remaining elements out from softmax.
// d % kHeadDim != 0 or seqlen_q % kBlockM do not suffer from similar
// issues as they are transparently taken care of by TMA and the
// epilogue, if it is instantiated with predication support.
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(acc_qk); i++) {
auto pos = index_qk(i);
if (get<1>(pos) >= get<1>(problem_size)) {
acc_qk(i) = -INFINITY;
}
}
}
};
struct ResidualMaskForBackward : NoMask {
using Base = NoMask;
template <class BlkCoord, class TileShape, class ProblemSize>
CUTLASS_DEVICE int get_masked_trip_count(
BlkCoord const& blk_coord,
TileShape const& tile_shape,
ProblemSize const& problem_size) {
if (get<1>(problem_size) % get<1>(tile_shape) != 0) {
return 1;
}
return 0;
}
template<class BlkCoord, class TileShape, class ProblemSize>
CUTLASS_DEVICE
int get_unmasked_trip_count(
BlkCoord const& blk_coord,
TileShape const& tile_shape,
ProblemSize const& problem_size) {
// if the sequence length does not divide the tile size evenly
if (get<1>(problem_size) % get<1>(tile_shape) != 0) {
return get_trip_count(blk_coord, tile_shape, problem_size) - 1;
}
return get_trip_count(blk_coord, tile_shape, problem_size);
}
template<class AccQK, class IndexQK, class ProblemSize>
CUTLASS_DEVICE
void apply_mask(
AccQK& acc_qk,
IndexQK const& index_qk,
ProblemSize const& problem_size) {
// This is useful is seqlen_k % kBlockN != 0 since it masks
// the remaining elements out from softmax.
// d % kHeadDim != 0 or seqlen_q % kBlockM do not suffer from similar
// issues as they are transparently taken care of by TMA and the
// epilogue, if it is instantiated with predication support.
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(acc_qk); i++) {
auto pos = index_qk(i);
if (! elem_less(pos, select<0,1>(problem_size))) {
acc_qk(i) = -INFINITY;
}
}
}
};
// There are two ways to do causal if N_Q != N_K
// (1) The Q is at the beginning of the matrix
// (2) The Q is at the end of the matrix
template<bool kIsQBegin = true>
struct CausalMask : NoMask {
using Base = NoMask;
static constexpr bool IsQBegin = kIsQBegin;
template<class BlkCoord, class TileShape, class ProblemSize>
CUTLASS_DEVICE
int get_trip_count(
BlkCoord const& blk_coord,
TileShape const& tile_shape,
ProblemSize const& problem_size) {
// See note below on different ways to think about causal attention
// Again, we'd add the offset_q into the max_blocks_q calculation
int max_blocks_k = Base::get_trip_count(blk_coord, tile_shape, problem_size);
if constexpr (IsQBegin) {
int max_blocks_q = ceil_div((get<0>(blk_coord) + 1) * get<0>(tile_shape), get<1>(tile_shape));
return std::min(max_blocks_k, max_blocks_q);
} else {
const int offset_q = get<1>(problem_size) - get<0>(problem_size);
int max_blocks_q = ceil_div((get<0>(blk_coord) + 1) * get<0>(tile_shape) + offset_q, get<1>(tile_shape));
return std::min(max_blocks_k, max_blocks_q);
}
}
template<class BlkCoord, class TileShape, class ProblemSize>
CUTLASS_DEVICE
int get_masked_trip_count(
BlkCoord const& blk_coord,
TileShape const& tile_shape,
ProblemSize const& problem_size) {
int trip_count = get_trip_count(blk_coord, tile_shape, problem_size);
if constexpr (IsQBegin) {
return std::min(trip_count, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape))));
} else {
const int corner_count = int((get<1>(problem_size) % get<1>(tile_shape) || get<0>(problem_size) % get<0>(tile_shape))) ;
return std::min(trip_count, int(ceil_div(get<0>(tile_shape), get<1>(tile_shape))) + corner_count);
}
}
template<class BlkCoord, class TileShape, class ProblemSize>
CUTLASS_DEVICE
int get_unmasked_trip_count(
BlkCoord const& blk_coord,
TileShape const& tile_shape,
ProblemSize const& problem_size) {
return get_trip_count(blk_coord, tile_shape, problem_size) - get_masked_trip_count(blk_coord, tile_shape, problem_size);
}
template<class AccQK, class IndexQK, class ProblemSize>
CUTLASS_DEVICE
void apply_mask(
AccQK& acc_qk,
IndexQK const& index_qk,
ProblemSize const& problem_size) {
// There are two ways to do causal if N_Q != N_K
// (1) is to assume that the Q is at the beginning of the matrix
// - this is the default setting.
// (2) is that it is at the end of the matrix
// - this is usually what we want for inference settings
// where we only compute the next row and use cache for the rest
// - if you'd like this, you only need to set kIsQBegin=false
if constexpr (IsQBegin) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(acc_qk); i++) {
auto pos = index_qk(i);
if ((get<0>(pos) < get<1>(pos)) || (get<1>(pos) >= get<1>(problem_size))) {
acc_qk(i) = -INFINITY;
}
}
} else {
const auto offset_q = get<1>(problem_size) - get<0>(problem_size);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(acc_qk); i++) {
auto pos = index_qk(i);
if ((get<0>(pos) + offset_q < get<1>(pos)) || (get<1>(pos) >= get<1>(problem_size))) {
acc_qk(i) = -INFINITY;
}
}
}
}
};
template<bool kIsQBegin = true>
struct CausalForBackwardMask : CausalMask<kIsQBegin>, ResidualMaskForBackward {
using Base = CausalMask<kIsQBegin>;
template<class AccQK, class IndexQK, class ProblemSize>
CUTLASS_DEVICE
void apply_mask(
AccQK& acc_qk,
IndexQK const& index_qk,
ProblemSize const& problem_size) {
// There are two ways to do causal if N_Q != N_K
// (1) is to assume that the Q is at the beginning of the matrix
// - this is what we demonstrate here
// (2) is that it is at the end of the matrix
// - this is usually what we want for inference settings
// where we only compute the next row and use cache for the rest
// - if you'd like this, you only need to add an offset like so:
// get<0>(pos) + offset_q < get<1>(pos)
int offset_q = 0;
if constexpr (!kIsQBegin) {
offset_q = get<1>(problem_size) - get<0>(problem_size);
}
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(acc_qk); i++) {
auto pos = index_qk(i);
bool masked = (get<0>(pos) + offset_q < get<1>(pos)) || !elem_less(pos, problem_size);
if (masked) {
acc_qk(i) = -INFINITY;
}
}
}
};
struct VariableLength {
int max_length;
int* cumulative_length = nullptr;
int total_length = -1;
CUTE_HOST_DEVICE operator int() const {
return max_length;
}
};
template<class T> struct is_variable_length_impl : std::false_type {};
template<> struct is_variable_length_impl<VariableLength> : std::true_type {};
template<class T> constexpr bool is_variable_length_v = is_variable_length_impl<remove_cvref_t<T>>::value;
template<class Shape, class Idx>
CUTE_HOST_DEVICE
constexpr auto
apply_variable_length(Shape const& shape, Idx const& idx) {
return transform_leaf(shape, [&](auto const& s) {
if constexpr (is_variable_length_v<decltype(s)>) {
return s.cumulative_length[idx+1] - s.cumulative_length[idx];
}
else {
return s;
}
});
}
template<class Shape, class Coord, class Idx>
CUTE_HOST_DEVICE
constexpr auto
apply_variable_length(Shape const& shape, Coord const& coord, Idx const& idx) {
auto new_shape = apply_variable_length(shape, idx);
auto new_coord = transform_leaf(shape, coord, [&](auto const& s, auto const& c) {
if constexpr (is_variable_length_v<decltype(s)>) {
return cute::make_tuple(c, s.cumulative_length[idx]);
}
else {
return c;
}
});
return cute::make_tuple(new_shape, new_coord);
}
template<class Shape, class Coord>
CUTE_HOST_DEVICE
constexpr auto
apply_variable_length_offset(Shape const& shape, Coord const& coord) {
auto idx = back(back(coord));
auto result_shape = transform_leaf(shape, [&](auto const& s) {
if constexpr (is_variable_length_v<decltype(s)>) {
return s.cumulative_length[idx+1] - s.cumulative_length[idx];
}
else {
return s;
}
});
auto result_offset = transform_leaf(coord, shape, [&](auto const& c, auto const& s) {
if constexpr (is_variable_length_v<decltype(s)>) {
return s.cumulative_length[idx];
}
else {
return _0{};
}
});
return cute::make_tuple(result_shape, result_offset);
}
} // namespace cutlass::fmha::collective
namespace cute {
template<>
struct is_integral<cutlass::fmha::collective::VariableLength> : true_type {};
CUTE_HOST_DEVICE
void print(cutlass::fmha::collective::VariableLength a) {
printf("Varlen<%d, %p>", a.max_length, a.cumulative_length);
}
}
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cute/layout.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
namespace cutlass::fmha::collective {
template<
class Element,
class ElementAcc,
class TileShape, // Q, D, _
class StrideO, // Q, D, B
class StrideLSE_, // Q, B
class OrderLoadEpilogue = cute::false_type
>
struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
using Pipeline = cutlass::PipelineAsync<2>;
// using SmemLayoutO = decltypa(make_layout(append<3>(select<0,1>(TileShape_WG{}), _2{})));
using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<
cute::UMMA::Major::K, Element, tuple_element_t<0, TileShape>, tuple_element_t<1, TileShape>>());
// using SmemLayoutAtomO = decltype(make_ordered_layout(select<0,1>(TileShape{}), Step<_1, _0>{}));
using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, replace<2>(TileShape{}, _2{}), Step<_2, _1, _3>{}));
using SmemLayoutO_ = SmemLayoutO;
using StrideLSE = StrideLSE_;
using ElementOut = Element;
static const int NumWarpsEpilogue = 1;
static const int NumWarpsLoad = 1;
struct TensorStorage {
using SmemLayoutO = SmemLayoutO_;
cute::array_aligned<Element, cute::cosize_v<SmemLayoutO>> smem_o;
};
struct Arguments {
Element* ptr_O;
StrideO dO;
ElementAcc* ptr_LSE;
StrideLSE dLSE;
};
using TMA_O = decltype(make_tma_copy(
SM90_TMA_STORE{},
make_tensor((Element*) nullptr, repeat_like(StrideO{}, 0), StrideO{}),
SmemLayoutO{}(_,_,_0{})
));
struct Params {
TMA_O tma_store_o;
ElementAcc* ptr_LSE;
StrideLSE dLSE;
};
// FMHA and MLA have different input ProblemShapes;
// get problem_shape_O according to the input ProblemShape.
template<class ProblemShape>
CUTLASS_DEVICE static constexpr
auto get_problem_shape_O (
ProblemShape const& problem_shape) {
if constexpr (rank_v<decltype(get<2>(ProblemShape{}))> == 2) {
return replace<1>(select<0,2,3>(problem_shape), get<2, 0>(problem_shape));
} else {
return select<0,2,3>(problem_shape);
}
}
template<class ProblemShape>
static Params to_underlying_arguments(
ProblemShape const& problem_shape,
Arguments const& args,
void* workspace = nullptr) {
auto ptr_O = args.ptr_O;
StrideO dO = args.dO;
auto problem_shape_O = get_problem_shape_O(problem_shape);
if constexpr (is_variable_length_v<tuple_element_t<0, ProblemShape>>) {
auto cumulative_length_q = get<0>(problem_shape).cumulative_length;
if (cumulative_length_q != nullptr) {
int max_length_q = get<0>(problem_shape).max_length;
get<0>(problem_shape_O).max_length = max(1, max_length_q);
// for variable sequence lenght, the batch is in units of row_stride
get<2,1>(dO) = get<0>(dO);
get<2,1>(problem_shape_O) = max(1, max_length_q * (1 + get<2,1>(problem_shape_O)));
// offset ptr by the amount we add back in later
ptr_O -= max_length_q * get<0>(dO);
}
} else {
get<0>(problem_shape_O) = max(1, get<0>(problem_shape_O));
}
auto tma_store_o = make_tma_copy(
SM90_TMA_STORE{},
make_tensor(ptr_O, problem_shape_O, dO),
SmemLayoutO{}(_,_,_0{})
);
return {
tma_store_o,
args.ptr_LSE,
args.dLSE
};
}
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const& params) {
cute::prefetch_tma_descriptor(params.tma_store_o.get_tma_descriptor());
}
const Params& params;
CUTLASS_DEVICE Sm100FmhaFwdEpilogueTmaWarpspecialized(const Params& params) : params(params) {}
template<class BlkCoord, class ProblemShape, class ParamsProblemShape>
CUTLASS_DEVICE auto
store(
BlkCoord const& blk_coord_in, ProblemShape const& problem_shape,
Params const& params, ParamsProblemShape const& params_problem_shape,
TensorStorage& shared_storage,
Pipeline& pipeline, typename Pipeline::PipelineState& pipeline_consumer_state) {
BlkCoord blk_coord = blk_coord_in;
uint32_t lane_predicate = cute::elect_one_sync();
using X = Underscore;
int o0_index = 2 * get<0>(blk_coord);
int o1_index = 2 * get<0>(blk_coord) + 1;
Tensor mO_qdl_p = params.tma_store_o.get_tma_tensor(get_problem_shape_O(problem_shape));
// offset mode 0 by (max_length - real_length)
// offset mode 3,1 by cumulative_length + real_length
// the ptr is already offset by - max_length
// so in total this achieves
int offs_0 = 0;
int offs_2_1 = 0;
if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
auto cumulative_length_q = get<0>(params_problem_shape).cumulative_length;
if (cumulative_length_q != nullptr) {
int max_length_q = get<0>(params_problem_shape).max_length;
offs_0 = max_length_q - get<0>(problem_shape);
offs_2_1 = cumulative_length_q[get<2,1>(blk_coord)] + get<0>(problem_shape);
get<2,1>(blk_coord) = 0;
}
}
Tensor mO_qdl = domain_offset(make_coord(offs_0, _0{}, make_coord(_0{}, offs_2_1)), mO_qdl_p);
Tensor gO_qdl = local_tile(mO_qdl, TileShape{}, make_coord(_, _, _), Step<_1, _1, X>{});
Tensor gO = gO_qdl(_, _, _, _0{}, get<2>(blk_coord));
Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutO{});
auto block_tma = params.tma_store_o.get_slice(0);
Tensor tOsO = block_tma.partition_S(sO);
Tensor tOgO = block_tma.partition_D(gO);
auto pipeline_release_state = pipeline_consumer_state;
// O1 O2
// one pipeline: O
// wait from corr, issue tma store on smem
pipeline.consumer_wait(pipeline_consumer_state);
++pipeline_consumer_state;
if (lane_predicate) {
copy(params.tma_store_o, tOsO(_,_,_,_0{}), tOgO(_,_,_,o0_index));
}
tma_store_arrive();
pipeline.consumer_wait(pipeline_consumer_state);
++pipeline_consumer_state;
if (lane_predicate) {
copy(params.tma_store_o, tOsO(_,_,_,_1{}), tOgO(_,_,_,o1_index));
}
tma_store_arrive();
tma_store_wait<1>();
pipeline.consumer_release(pipeline_release_state);
++pipeline_release_state;
tma_store_wait<0>();
if constexpr (cute::is_same_v<OrderLoadEpilogue, cute::true_type>) {
cutlass::arch::NamedBarrier::arrive((NumWarpsLoad + NumWarpsEpilogue) * NumThreadsPerWarp,
cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
}
pipeline.consumer_release(pipeline_release_state);
++pipeline_release_state;
}
};
} // namespace cutlass::fmha::collective
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/arch/memory_sm80.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cute/tensor.hpp"
#include "cute/layout.hpp"
#include "../collective/fmha_common.hpp"
#include "../collective/fmha_fusion.hpp"
namespace cutlass::fmha::collective {
using namespace cute;
template<
class Element,
class StrideQ,
class StrideK,
class StrideV,
class CollectiveMmaQK,
class CollectiveMmaPV,
class SmemLayoutQ,
class SmemLayoutK,
class SmemLayoutV,
class TensorStorage,
class PipelineQ,
class PipelineKV,
class Mask,
class TileShape
>
struct Sm100FmhaLoadTmaWarpspecialized {
using TileShapeQK = typename CollectiveMmaQK::TileShape;
using TileShapePV = typename CollectiveMmaPV::TileShape;
struct Arguments {
const Element* ptr_Q;
StrideQ dQ;
const Element* ptr_K;
StrideK dK;
const Element* ptr_V;
StrideV dV;
};
using TMA_Q = typename CollectiveMmaQK::Params::TMA_A;
using TMA_K = typename CollectiveMmaQK::Params::TMA_B;
using TMA_V = typename CollectiveMmaPV::Params::TMA_B;
struct Params {
TMA_Q tma_load_q;
TMA_K tma_load_k;
TMA_V tma_load_v;
};
template<class ProblemShape>
static Params to_underlying_arguments(
ProblemShape const& problem_shape,
Arguments const& args,
void* workspace) {
auto ptr_Q = args.ptr_Q;
auto ptr_K = args.ptr_K;
auto ptr_V = args.ptr_V;
auto dQ = args.dQ;
auto dK = args.dK;
auto dV = args.dV;
using IntProblemShape = cute::tuple<int, int, int, cute::tuple<cute::tuple<int, int>, int>>;
IntProblemShape problem_shape_qk;
if constexpr (is_variable_length_v<tuple_element_t<0, ProblemShape>>) {
auto cumulative_length_q = get<0>(problem_shape).cumulative_length;
auto cumulative_length_k = get<1>(problem_shape).cumulative_length;
if (cumulative_length_q != nullptr && cumulative_length_k != nullptr ) {
get<0>(problem_shape_qk) = get<0>(problem_shape).total_length;
get<1>(problem_shape_qk) = get<1>(problem_shape).total_length;
get<2>(problem_shape_qk) = get<2>(problem_shape);
get<3>(problem_shape_qk) = get<3>(problem_shape);
}
} else {
problem_shape_qk = problem_shape;
}
get<0>(problem_shape_qk) = max(1, get<0>(problem_shape_qk));
get<1>(problem_shape_qk) = max(1, get<1>(problem_shape_qk));
auto params_qk = CollectiveMmaQK::to_underlying_arguments(
problem_shape_qk,
typename CollectiveMmaQK::Arguments {
ptr_Q, dQ,
ptr_K, dK,
}, /*workspace=*/ nullptr);
auto problem_shape_pv = select<0,2,1,3>(problem_shape_qk);
auto params_pv = CollectiveMmaPV::to_underlying_arguments(
problem_shape_pv,
typename CollectiveMmaPV::Arguments {
ptr_K, dK, // never used, dummy
ptr_V, select<1,0,2>(dV),
}, /*workspace=*/ nullptr);
return Params{
params_qk.tma_load_a,
params_qk.tma_load_b,
params_pv.tma_load_b
};
}
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const& params) {
cute::prefetch_tma_descriptor(params.tma_load_q.get_tma_descriptor());
cute::prefetch_tma_descriptor(params.tma_load_k.get_tma_descriptor());
cute::prefetch_tma_descriptor(params.tma_load_v.get_tma_descriptor());
}
template<class BlkCoord, class ProblemShape, class ParamsProblemShape>
CUTLASS_DEVICE void
load(
BlkCoord const& blk_coord_in, ProblemShape const& problem_shape,
Params const& params, ParamsProblemShape const& params_problem_shape,
TensorStorage& storage,
PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_producer_state,
PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_producer_state) {
BlkCoord blk_coord_q = blk_coord_in;
BlkCoord blk_coord_kv = blk_coord_in;
int mask_tile_count = Mask{}.get_trip_count(blk_coord_in, TileShape{}, problem_shape);
using X = Underscore;
// this one is only executed by one thread, no need to elect_one
// Q1, K1, Q2, V1, K2, V2, K3, V3, ...
// two pipes: Q and KV
// from Memory (prod) to TensorCore (cons)
// compute gQ, sQ
// we load 2*get<0>(blk_coord), and 2*get<0>(blk_coord) + 1
ThrMMA mma_qk = typename CollectiveMmaQK::TiledMma{}.get_slice(0);
Tensor mQ_qdl_p = params.tma_load_q.get_tma_tensor(select<0,2,3>(problem_shape));
int q_offs_0 = 0;
if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
auto cumulative_length_q = get<0>(params_problem_shape).cumulative_length;
if (cumulative_length_q != nullptr) {
q_offs_0 = cumulative_length_q[get<2,1>(blk_coord_q)];
get<2,1>(blk_coord_q) = 0;
}
}
Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, _0{})), mQ_qdl_p);
Tensor gQ_qdl = local_tile(mQ_qdl, TileShapeQK{}, make_coord(_, _, _), Step<_1, X, _1>{});
Tensor tSgQ_qdl = mma_qk.partition_A(gQ_qdl);
Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{});
auto [tQgQ_qdl, tQsQ] = tma_partition(
params.tma_load_q, _0{}, make_layout(_1{}),
group_modes<0,3>(sQ), group_modes<0,3>(tSgQ_qdl)
);
Tensor tQgQ = tQgQ_qdl(_, _, _0{}, get<2>(blk_coord_q));
// compute gK, sK
Tensor mK_kdl_p = params.tma_load_k.get_tma_tensor(select<1,2,3>(problem_shape));
int kv_offs_0 = 0;
if constexpr (is_variable_length_v<tuple_element_t<1, ParamsProblemShape>>) {
auto cumulative_length = get<1>(params_problem_shape).cumulative_length;
if (cumulative_length != nullptr) {
kv_offs_0 = cumulative_length[get<2,1>(blk_coord_kv)];
get<2,1>(blk_coord_kv) = 0;
}
}
Tensor mK_kdl = domain_offset(make_coord(kv_offs_0, _0{}, make_coord(_0{}, _0{})), mK_kdl_p);
Tensor gK_kdl = local_tile(mK_kdl, TileShapeQK{}, make_coord(_, _, _), Step<X, _1, _1>{});
Tensor tSgK_kdl = mma_qk.partition_B(gK_kdl);
Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{});
auto [tKgK_kdl, tKsK] = tma_partition(
params.tma_load_k, _0{}, make_layout(_1{}),
group_modes<0,3>(sK), group_modes<0,3>(tSgK_kdl)
);
Tensor tKgK = tKgK_kdl(_, _, _0{}, get<2>(blk_coord_kv));
// compute gV, sV
ThrMMA mma_pv = typename CollectiveMmaPV::TiledMma{}.get_slice(0);
Tensor mV_dkl_p = params.tma_load_v.get_tma_tensor(select<2,1,3>(problem_shape));
Tensor mV_dkl = domain_offset(make_coord(_0{}, kv_offs_0, make_coord(_0{}, _0{})), mV_dkl_p);
Tensor gV_dkl = local_tile(mV_dkl, TileShapePV{}, make_coord(_, _, _), Step<X, _1, _1>{});
Tensor tOgV_dkl = mma_pv.partition_B(gV_dkl);
Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{});
auto [tVgV_dkl, tVsV] = tma_partition(
params.tma_load_v, _0{}, make_layout(_1{}),
group_modes<0,3>(sV), group_modes<0,3>(tOgV_dkl)
);
auto tVgV = tVgV_dkl(_, _0{}, _, get<2>(blk_coord_kv));
// blk_coord in decomposed in terms of TileShape, not TileShapeQK
// As such, it needs to be transformed as
// (a,b,c): a -> 2*a (Q0) 2*a+1 (Q1)
// b -> 2*a (Ki i even) 2*a+1 (Ki i odd)
uint32_t lane_predicate = cute::elect_one_sync();
// Q1
int q0_index = 2 * get<0>(blk_coord_q);
int q1_index = 2 * get<0>(blk_coord_q) + 1;
pipeline_q.producer_acquire(pipeline_q_producer_state);
if (lane_predicate) {
auto tma_barrier = pipeline_q.producer_get_barrier(pipeline_q_producer_state);
copy(params.tma_load_q.with(*tma_barrier, 0), tQgQ(_, q0_index), tQsQ(_, pipeline_q_producer_state.index()));
}
++pipeline_q_producer_state;
// K1
int k_index = 0;
pipeline_kv.producer_acquire(pipeline_kv_producer_state);
if (lane_predicate) {
auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state);
copy(params.tma_load_k.with(*tma_barrier, 0), tKgK(_, k_index), tKsK(_, pipeline_kv_producer_state.index()));
}
++pipeline_kv_producer_state;
// Q2
pipeline_q.producer_acquire(pipeline_q_producer_state);
if (lane_predicate) {
auto tma_barrier = pipeline_q.producer_get_barrier(pipeline_q_producer_state);
copy(params.tma_load_q.with(*tma_barrier, 0), tQgQ(_, q1_index), tQsQ(_, pipeline_q_producer_state.index()));
}
++pipeline_q_producer_state;
// V1
pipeline_kv.producer_acquire(pipeline_kv_producer_state);
if (lane_predicate) {
auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state);
copy(params.tma_load_v.with(*tma_barrier, 0), tVgV(_, k_index), tVsV(_, pipeline_kv_producer_state.index()));
}
++pipeline_kv_producer_state;
k_index += 1;
// loop:
mask_tile_count -= 1;
for (; mask_tile_count > 0; mask_tile_count -= 1) {
// Ki
pipeline_kv.producer_acquire(pipeline_kv_producer_state);
if (lane_predicate) {
auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state);
copy(params.tma_load_k.with(*tma_barrier, 0), tKgK(_, k_index), tKsK(_, pipeline_kv_producer_state.index()));
}
++pipeline_kv_producer_state;
// Vi
pipeline_kv.producer_acquire(pipeline_kv_producer_state);
if (lane_predicate) {
auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state);
copy(params.tma_load_v.with(*tma_barrier, 0), tVgV(_, k_index), tVsV(_, pipeline_kv_producer_state.index()));
}
++pipeline_kv_producer_state;
k_index += 1;
}
}
};
} // namespace cutlass::fmha::collective
This diff is collapsed.
This diff is collapsed.
#pragma once
enum class MaskMode {
kNone = 0U, // No mask
kCausal = 1U, // Causal mask
kCustom = 2U, // Custom mask
};
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment