Commit 7f67966c authored by Tri Dao

FA3 initial code release

parent b4a9dd6c
......@@ -26,6 +26,42 @@ contains a partial list of places where FlashAttention is being used.
FlashAttention and FlashAttention-2 are free to use and modify (see LICENSE).
Please cite and credit FlashAttention if you use it.
## FlashAttention-3 beta release
FlashAttention-3 is optimized for Hopper GPUs (e.g. H100).
Blogpost: https://tridao.me/blog/2024/flash3/
Paper: https://tridao.me/publications/flash3/flash3.pdf
![FlashAttention-3 speedup on H100 80GB SXM5 with FP16](assets/flash3_fp16_fwd.png)
This is a beta release for testing / benchmarking before we integrate it with the rest of the repo.
Currently released:
- FP16 forward and backward
Coming soon in the next couple of days / next week:
- BF16
- Variable length (FP16, BF16)
- FP8 forward.
Requirements: H100 / H800 GPU, CUDA >= 12.3.
To install:
```sh
cd hopper
python setup.py install
```
To run the test:
```sh
export PYTHONPATH=$PWD
pytest -q -s test_flash_attn.py
```
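Minimal usage sketch (assuming `flash_attn_interface.py` from the `hopper` directory is importable, e.g. via the `PYTHONPATH` export above; fp16 inputs on an H100):
```python
import torch
from flash_attn_interface import flash_attn_func

q, k, v = (torch.randn(2, 1024, 16, 128, dtype=torch.float16, device="cuda")
           for _ in range(3))
out, softmax_lse = flash_attn_func(q, k, v, causal=True)
print(out.shape, softmax_lse.shape)  # (2, 1024, 16, 128) and (2, 16, 1024)
```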
## Installation and features
Requirements:
......
__version__ = "3.0.0.b1"
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
#pragma once
namespace flash {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<bool Varlen=true>
struct BlockInfo {
template<typename Params>
__device__ BlockInfo(const Params &params, const int bidb)
: sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb])
, sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb])
, actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
, seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
, actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
{
}
template <typename index_t>
__forceinline__ __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride;
}
template <typename index_t>
__forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride;
}
const int sum_s_q;
const int sum_s_k;
const int actual_seqlen_q;
// We have to declare seqlen_k_cache before actual_seqlen_k, since members are initialized in
// declaration order; otherwise actual_seqlen_k would read an uninitialized seqlen_k_cache.
const int seqlen_k_cache;
const int actual_seqlen_k;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace flash
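For reference, a hedged sketch of the `cu_seqlens` convention that `BlockInfo` decodes in the variable-length case (names and values here are illustrative only):
```python
# Sequences are packed along the row dimension; cu_seqlens holds cumulative offsets.
import torch

seqlens_q = torch.tensor([5, 3, 8])                                  # three sequences in one batch
cu_seqlens_q = torch.nn.functional.pad(seqlens_q.cumsum(0), (1, 0))  # [0, 5, 8, 16]

bidb = 1
sum_s_q = int(cu_seqlens_q[bidb])                                    # row offset of sequence 1: 5
actual_seqlen_q = int(cu_seqlens_q[bidb + 1] - cu_seqlens_q[bidb])   # its length: 3
# q_offset() then uses sum_s_q * row_stride instead of bidb * batch_stride,
# i.e. the packed Q tensor is addressed by rows rather than by a fixed batch stride.
```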
/******************************************************************************
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
******************************************************************************/
#pragma once
#include <cutlass/cutlass.h>
#include "cute/tensor.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "utils.h"
namespace flash {
using namespace cute;
// template <int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename Element_>
template <typename Ktraits>
struct CollectiveEpilogueFwd {
using Element = typename Ktraits::Element;
static constexpr int kBlockM = Ktraits::kBlockM;
static constexpr int kBlockN = Ktraits::kBlockN;
static constexpr int kHeadDim = Ktraits::kHeadDim;
// using Element = Element_;
// static constexpr int kBlockM = kBlockM_;
// static constexpr int kBlockN = kBlockN_;
// static constexpr int kHeadDim = kHeadDim_;
using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
// static constexpr int kNWarps = kNWarps_;
static constexpr int kNWarps = Ktraits::kNWarps;
static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
static constexpr bool Is_WS = kNWarps >= 12;
static constexpr int NumCopyThreads = !Is_WS ? 0 : cutlass::NumThreadsPerWarpGroup;
static constexpr int NumMmaThreads = kNThreads - NumCopyThreads;
using GmemTiledCopyOTMA = cute::SM90_TMA_STORE;
// These are for storing the output tensor without TMA (e.g., for setting output to zero)
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
static constexpr int kGmemThreadsPerRow = kHeadDim / kGmemElemsPerLoad;
static_assert(NumMmaThreads % kGmemThreadsPerRow == 0, "NumMmaThreads must be a multiple of kGmemThreadsPerRow");
using GmemLayoutAtom = Layout<Shape <Int<NumMmaThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
Stride<Int<kGmemThreadsPerRow>, _1>>;
using GmemTiledCopyO = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
GmemLayoutAtom{},
Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{})); // Val layout, 8 or 16 vals per store
using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{})));
using SmemCopyAtomO = Copy_Atom<cute::SM90_U32x4_STSM_N, Element>;
using SharedStorage = cute::array_aligned<Element, cute::cosize_v<SmemLayoutO>>;
using ShapeO = cute::Shape<int32_t, int32_t, int32_t, int32_t>; // (seqlen_q, d, head, batch)
using StrideO = cute::Stride<int64_t, _1, int64_t, int64_t>;
using StrideLSE = cute::Stride<_1, int64_t, int64_t>; // (seqlen_q, head, batch)
using TMA_O = decltype(make_tma_copy(
GmemTiledCopyOTMA{},
make_tensor(make_gmem_ptr(static_cast<Element*>(nullptr)), repeat_like(StrideO{}, int32_t(0)), StrideO{}),
SmemLayoutO{},
select<0, 2>(TileShape_MNK{}),
_1{})); // no mcast for O
// Host side kernel arguments
struct Arguments {
Element* ptr_O;
ShapeO const shape_O;
StrideO const stride_O;
float* ptr_LSE;
StrideLSE const stride_LSE;
};
// Device side kernel params
struct Params {
Element* ptr_O;
ShapeO const shape_O;
StrideO const stride_O;
float* ptr_LSE;
StrideLSE const stride_LSE;
TMA_O tma_store_O;
};
static Params
to_underlying_arguments(Arguments const& args) {
Tensor mO = make_tensor(make_gmem_ptr(args.ptr_O), args.shape_O, args.stride_O);
TMA_O tma_store_O = make_tma_copy(
GmemTiledCopyOTMA{},
mO,
SmemLayoutO{},
select<0, 2>(TileShape_MNK{}),
_1{}); // no mcast for O
return {args.ptr_O, args.shape_O, args.stride_O, args.ptr_LSE, args.stride_LSE, tma_store_O};
}
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const& epilogue_params) {
cute::prefetch_tma_descriptor(epilogue_params.tma_store_O.get_tma_descriptor());
}
template <typename SharedStorage, typename FrgTensorO, typename FrgTensorLSE, typename TiledMma>
CUTLASS_DEVICE void
store(Params const& epilogue_params,
FrgTensorO const& tOrO,
FrgTensorLSE const& lse,
SharedStorage& shared_storage,
TiledMma tiled_mma,
int thread_idx,
cute::tuple<int32_t, int32_t, int32_t> const& block_coord
) {
auto [m_block, bidh, bidb] = block_coord;
Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutO{});
auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma);
auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx);
Tensor tOrO_out = flash::convert_type<Element>(tOrO);
Tensor taccOrO = smem_thr_copy_O.retile_S(tOrO_out); // ((Atom,AtomNum), MMA_M, MMA_N)
Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N)
// Make sure all WGs have finished reading V
cutlass::arch::NamedBarrier::sync(NumMmaThreads, 0 /*id*/);
cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA
cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp,
cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
Tensor mO = epilogue_params.tma_store_O.get_tma_tensor(epilogue_params.shape_O);
Tensor gO = local_tile(mO(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K)
auto block_tma_O = epilogue_params.tma_store_O.get_slice(_0{});
Tensor tOgO = block_tma_O.partition_D(gO); // (TMA, TMA_M, TMA_K)
Tensor tOsO = block_tma_O.partition_S(sO); // (TMA, TMA_M, TMA_K)
auto shape_LSE = select<0, 2, 3>(epilogue_params.shape_O);
Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_params.ptr_LSE), shape_LSE, epilogue_params.stride_LSE);
Tensor gLSE = local_tile(mLSE(_, bidh, bidb), Shape<Int<kBlockM>>{}, make_coord(m_block));
Tensor caccO = cute::make_identity_tensor(select<0, 2>(TileShape_MNK{}));
auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
Tensor taccOcO = thread_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K)
static_assert(decltype(size<0, 0>(taccOcO))::value == 2);
static_assert(decltype(size<0, 1>(taccOcO))::value == 2);
// taccOcO has shape ((2, 2, V), MMA_M, MMA_K); we only take the row indices.
Tensor taccOcO_row = taccOcO(make_coord(_0{}, _, _0{}), _, _0{});
CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M
if (get<1>(taccOcO_row(_0{})) == 0) {
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) {
const int row = get<0>(taccOcO_row(mi));
if (row < get<0>(shape_LSE) - m_block * kBlockM) { gLSE(row) = lse(mi); }
}
}
if (cutlass::canonical_warp_idx_sync() == kNWarps - 1) {
cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp,
cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
int const lane_predicate = cute::elect_one_sync();
if (lane_predicate) {
cute::copy(epilogue_params.tma_store_O, tOsO, tOgO);
tma_store_arrive();
}
}
}
CUTLASS_DEVICE void
store_tail() {
tma_store_wait<0>();
}
// Write 0 to output and inf to LSE
CUTLASS_DEVICE void
store_zero(
Params const& epilogue_params,
int thread_idx,
cute::tuple<int32_t, int32_t, int32_t> const& block_coord
) {
auto [m_block, bidh, bidb] = block_coord;
Tensor mO = make_tensor(make_gmem_ptr(epilogue_params.ptr_O), epilogue_params.shape_O, epilogue_params.stride_O);
Tensor gO = local_tile(mO(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K)
auto shape_LSE = select<0, 2, 3>(epilogue_params.shape_O);
Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_params.ptr_LSE), shape_LSE, epilogue_params.stride_LSE);
Tensor gLSE = local_tile(mLSE(_, bidh, bidb), Shape<Int<kBlockM>>{}, make_coord(m_block));
GmemTiledCopyO gmem_tiled_copy_O;
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
Tensor tOrO = make_fragment_like(tOgO);
clear(tOrO);
// Construct identity layout for sO
Tensor cO = cute::make_identity_tensor(select<0, 2>(TileShape_MNK{})); // (BLK_M,BLK_K) -> (blk_m,blk_k)
// Repeat the partitioning with identity layouts
Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
#pragma unroll
for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(epilogue_params.shape_O); }
// Clear_OOB_K must be false since we don't want to write zeros to gmem
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, get<0>(epilogue_params.shape_O) - m_block * kBlockM
);
static_assert(kBlockM <= NumMmaThreads);
if (thread_idx < get<0>(shape_LSE) - m_block * kBlockM) { gLSE(thread_idx) = INFINITY; }
}
};
} // namespace flash
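For reference, a hedged single-head sketch (PyTorch, fp32, no masking) of the two quantities this epilogue writes out, the attention output O (to `ptr_O`, via TMA) and the per-query-row logsumexp (to `ptr_LSE`):
```python
import math
import torch

def reference_out_lse(q, k, v, softmax_scale=None):
    # q: (seqlen_q, d), k, v: (seqlen_k, d)
    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(q.shape[-1])
    scores = (q @ k.t()) * softmax_scale        # (seqlen_q, seqlen_k)
    lse = torch.logsumexp(scores, dim=-1)       # what gets stored as LSE
    out = torch.softmax(scores, dim=-1) @ v     # what gets stored as O
    return out, lse
```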
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
#pragma once
#include <cuda.h>
#include <vector>
#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
#endif
#include <ATen/cuda/CUDAGraphsUtils.cuh> // For at::cuda::philox::unpack
#include "cutlass/fast_math.h" // For cutlass::FastDivmod
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Qkv_params {
using index_t = int64_t;
// The QKV matrices.
void *__restrict__ q_ptr;
void *__restrict__ k_ptr;
void *__restrict__ v_ptr;
// The stride between rows of the Q, K and V matrices.
index_t q_batch_stride;
index_t k_batch_stride;
index_t v_batch_stride;
index_t q_row_stride;
index_t k_row_stride;
index_t v_row_stride;
index_t q_head_stride;
index_t k_head_stride;
index_t v_head_stride;
// The number of heads.
int h, h_k;
// In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be
// different from nheads (query).
int h_h_k_ratio; // precompute h / h_k
};
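A small illustrative sketch of the MQA/GQA head mapping that `h_h_k_ratio` precomputes:
```python
# Query head i reads key/value head i // (h / h_k).
h, h_k = 6, 2
h_h_k_ratio = h // h_k
print([i // h_h_k_ratio for i in range(h)])   # [0, 0, 0, 1, 1, 1]
```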
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Flash_fwd_params : public Qkv_params {
// The O matrix (output).
void * __restrict__ o_ptr;
void * __restrict__ oaccum_ptr;
// The stride between rows of O.
index_t o_batch_stride;
index_t o_row_stride;
index_t o_head_stride;
// The pointer to the P matrix.
void * __restrict__ p_ptr;
// The pointer to the softmax sum.
void * __restrict__ softmax_lse_ptr;
void * __restrict__ softmax_lseaccum_ptr;
// The dimensions.
int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim;
cutlass::FastDivmod head_divmod, m_block_divmod;
int total_blocks;
// The scaling factors for the kernel.
float scale_softmax;
float scale_softmax_log2;
uint32_t scale_softmax_log2_half2;
// array of length b+1 holding starting offset of each sequence.
int * __restrict__ cu_seqlens_q;
int * __restrict__ cu_seqlens_k;
// If provided, the actual length of each k sequence.
int * __restrict__ seqused_k;
int *__restrict__ blockmask;
// The K_new and V_new matrices.
void * __restrict__ knew_ptr;
void * __restrict__ vnew_ptr;
// The stride between rows of the Q, K and V matrices.
index_t knew_batch_stride;
index_t vnew_batch_stride;
index_t knew_row_stride;
index_t vnew_row_stride;
index_t knew_head_stride;
index_t vnew_head_stride;
// The cos and sin matrices for rotary embedding.
void * __restrict__ rotary_cos_ptr;
void * __restrict__ rotary_sin_ptr;
// The indices to index into the KV cache.
int * __restrict__ cache_batch_idx;
// Paged KV cache
int * __restrict__ block_table;
index_t block_table_batch_stride;
int page_block_size;
// The dropout probability (probability of keeping an activation).
float p_dropout;
// uint32_t p_dropout_in_uint;
// uint16_t p_dropout_in_uint16_t;
uint8_t p_dropout_in_uint8_t;
// Scale factor of 1 / (1 - p_dropout).
float rp_dropout;
float scale_softmax_rp_dropout;
// Local window size
int window_size_left, window_size_right;
// Random state.
at::PhiloxCudaState philox_args;
// Pointer to the RNG seed (idx 0) and offset (idx 1).
uint64_t * rng_state;
bool is_bf16;
bool is_e4m3;
bool is_causal;
// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
bool is_seqlens_k_cumulative;
bool is_rotary_interleaved;
int num_splits; // For split-KV version
void * __restrict__ alibi_slopes_ptr;
index_t alibi_slopes_batch_stride;
int * __restrict__ tile_count_semaphore;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Flash_bwd_params : public Flash_fwd_params {
// The dO and dQKV matrices.
void *__restrict__ do_ptr;
void *__restrict__ dq_ptr;
void *__restrict__ dk_ptr;
void *__restrict__ dv_ptr;
// To accumulate dQ
void *__restrict__ dq_accum_ptr;
void *__restrict__ dk_accum_ptr;
void *__restrict__ dv_accum_ptr;
// To accumulate dK and dV in case we're splitting the bwd along the seqlen_q dimension
// void *__restrict__ dk_accum_ptr;
// void *__restrict__ dv_accum_ptr;
// The stride between rows of the dO, dQ, dK and dV matrices.
// TD [2022-04-16]: We're using 32-bit indexing to save registers.
// The code probably won't work for arrays larger than 2GB.
index_t do_batch_stride;
index_t do_row_stride;
index_t do_head_stride;
index_t dq_batch_stride;
index_t dk_batch_stride;
index_t dv_batch_stride;
index_t dq_row_stride;
index_t dk_row_stride;
index_t dv_row_stride;
index_t dq_head_stride;
index_t dk_head_stride;
index_t dv_head_stride;
// The pointer to the softmax d sum.
void *__restrict__ dsoftmax_sum;
int *__restrict__ dq_semaphore;
bool deterministic;
index_t dq_accum_split_stride;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, int Headdim> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
/******************************************************************************
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
******************************************************************************/
// Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
#include <torch/python.h>
#include <torch/nn/functional.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cutlass/numeric_types.h>
#include "flash.h"
#include "static_switch.h"
#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
void set_params_fprop(Flash_fwd_params &params,
// sizes
const size_t b,
const size_t seqlen_q,
const size_t seqlen_k,
const size_t seqlen_q_rounded,
const size_t seqlen_k_rounded,
const size_t h,
const size_t h_k,
const size_t d,
const size_t d_rounded,
// device pointers
const at::Tensor q,
const at::Tensor k,
const at::Tensor v,
at::Tensor out,
void *cu_seqlens_q_d,
void *cu_seqlens_k_d,
void *seqused_k,
void *p_d,
void *softmax_lse_d,
float p_dropout,
float softmax_scale,
int window_size_left,
int window_size_right,
bool seqlenq_ngroups_swapped=false) {
// Reset the parameters
params = {};
params.is_bf16 = q.dtype() == torch::kBFloat16;
params.is_e4m3 = q.dtype() == torch::kFloat8_e4m3fn;
// Set the pointers and strides.
params.q_ptr = q.data_ptr();
params.k_ptr = k.data_ptr();
params.v_ptr = v.data_ptr();
// All strides are in elements, not bytes.
params.q_row_stride = q.stride(-3);
params.k_row_stride = k.stride(-3);
params.v_row_stride = v.stride(-3);
params.q_head_stride = q.stride(-2);
params.k_head_stride = k.stride(-2);
params.v_head_stride = v.stride(-2);
params.o_ptr = out.data_ptr();
params.o_row_stride = out.stride(-3);
params.o_head_stride = out.stride(-2);
if (cu_seqlens_q_d == nullptr) {
params.q_batch_stride = q.stride(0);
params.k_batch_stride = k.stride(0);
params.v_batch_stride = v.stride(0);
params.o_batch_stride = out.stride(0);
if (seqlenq_ngroups_swapped) {
params.q_batch_stride *= seqlen_q;
params.o_batch_stride *= seqlen_q;
}
}
params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);
params.seqused_k = static_cast<int *>(seqused_k);
// P = softmax(QK^T)
params.p_ptr = p_d;
// Softmax sum
params.softmax_lse_ptr = softmax_lse_d;
// Set the dimensions.
params.b = b;
params.h = h;
params.h_k = h_k;
params.h_h_k_ratio = h / h_k;
params.seqlen_q = seqlen_q;
params.seqlen_k = seqlen_k;
params.seqlen_q_rounded = seqlen_q_rounded;
params.seqlen_k_rounded = seqlen_k_rounded;
params.d = d;
params.d_rounded = d_rounded;
params.head_divmod = cutlass::FastDivmod(int(h));
// Set the different scale values.
params.scale_softmax = softmax_scale;
params.scale_softmax_log2 = softmax_scale * M_LOG2E;
__half scale_softmax_log2_half = __float2half(params.scale_softmax_log2);
__half2 scale_softmax_log2_half2 = __half2(scale_softmax_log2_half, scale_softmax_log2_half);
params.scale_softmax_log2_half2 = reinterpret_cast<uint32_t&>(scale_softmax_log2_half2);
// Set this to probability of keeping an element to simplify things.
params.p_dropout = 1.f - p_dropout;
// Convert p from float to int so we don't have to convert the random uint to float to compare.
// [Minor] We want to round down since when we do the comparison we use <= instead of <
// params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0));
// params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0));
params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
params.rp_dropout = 1.f / params.p_dropout;
params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
TORCH_CHECK(p_dropout < 1.f);
#ifdef FLASHATTENTION_DISABLE_DROPOUT
TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout.");
#endif
// Causal is the special case where window_size_right == 0 and window_size_left < 0.
// Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
params.is_causal = window_size_left < 0 && window_size_right == 0;
if (window_size_left < 0 && window_size_right >= 0) { window_size_left = seqlen_k; }
if (window_size_left >= 0 && window_size_right < 0) { window_size_right = seqlen_k; }
params.window_size_left = window_size_left;
params.window_size_right = window_size_right;
#ifdef FLASHATTENTION_DISABLE_LOCAL
TORCH_CHECK(params.is_causal || (window_size_left < 0 && window_size_right < 0),
"This flash attention build does not support local attention.");
#endif
params.is_seqlens_k_cumulative = true;
#ifdef FLASHATTENTION_DISABLE_UNEVEN_K
TORCH_CHECK(d == d_rounded, "This flash attention build does not support headdim not being a multiple of 32.");
#endif
}
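A hedged sketch of the window-size convention normalized above: causal is the special case `(left < 0, right == 0)`, and a negative side is widened to `seqlen_k`. The helper `visible_keys` is illustrative only, not part of the API:
```python
# Which key positions a query row can attend to under (window_size_left, window_size_right),
# mirroring the normalization in set_params_fprop and the bottom-right-aligned mask
# described in the flash_attn_func docstring further below.
def visible_keys(i, seqlen_q, seqlen_k, window_size_left, window_size_right):
    if window_size_left < 0 and window_size_right >= 0:
        window_size_left = seqlen_k          # causal (-1, 0) becomes (seqlen_k, 0)
    if window_size_left >= 0 and window_size_right < 0:
        window_size_right = seqlen_k
    if window_size_left < 0 and window_size_right < 0:
        return list(range(seqlen_k))         # full (non-local, non-causal) attention
    shift = seqlen_k - seqlen_q              # bottom-right alignment
    lo = max(0, i + shift - window_size_left)
    hi = min(seqlen_k - 1, i + shift + window_size_right)
    return list(range(lo, hi + 1))

print([visible_keys(i, 2, 5, -1, 0) for i in range(2)])  # [[0, 1, 2, 3], [0, 1, 2, 3, 4]]
```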
void set_params_dgrad(Flash_bwd_params &params,
// sizes
const size_t b,
const size_t seqlen_q,
const size_t seqlen_k,
const size_t seqlen_q_rounded,
const size_t seqlen_k_rounded,
const size_t h,
const size_t h_k,
const size_t d,
const size_t d_rounded,
// device pointers
const at::Tensor q,
const at::Tensor k,
const at::Tensor v,
const at::Tensor out,
const at::Tensor dout,
at::Tensor dq,
at::Tensor dk,
at::Tensor dv,
void *cu_seqlens_q_d,
void *cu_seqlens_k_d,
void *dq_accum_d,
void *dk_accum_d,
void *dv_accum_d,
void *softmax_lse_d,
void *dsoftmax_sum_d,
float p_dropout,
float softmax_scale,
int window_size_left,
int window_size_right,
bool deterministic) {
set_params_fprop(params,
b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded,
q, k, v, out,
cu_seqlens_q_d,
cu_seqlens_k_d,
nullptr,
nullptr,
softmax_lse_d,
p_dropout,
softmax_scale,
window_size_left,
window_size_right);
// Set the pointers and strides.
params.do_ptr = dout.data_ptr();
params.do_row_stride = dout.stride(-3);
params.do_head_stride = dout.stride(-2);
params.dq_ptr = dq.data_ptr();
params.dk_ptr = dk.data_ptr();
params.dv_ptr = dv.data_ptr();
params.dq_row_stride = dq.stride(-3);
params.dk_row_stride = dk.stride(-3);
params.dv_row_stride = dv.stride(-3);
params.dq_head_stride = dq.stride(-2);
params.dk_head_stride = dk.stride(-2);
params.dv_head_stride = dv.stride(-2);
if (cu_seqlens_q_d == nullptr) {
params.do_batch_stride = dout.stride(0);
params.dq_batch_stride = dq.stride(0);
params.dk_batch_stride = dk.stride(0);
params.dv_batch_stride = dv.stride(0);
}
params.dq_accum_ptr = dq_accum_d;
params.dk_accum_ptr = dk_accum_d;
params.dv_accum_ptr = dv_accum_d;
// Softmax sum
params.dsoftmax_sum = dsoftmax_sum_d;
params.deterministic = deterministic;
}
void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
// HEADDIM_SWITCH(params.d, [&] {
// run_mha_fwd_<cutlass::half_t, kHeadSize>(params, stream);
// });
if (!params.is_e4m3) {
if (params.d == 64) {
run_mha_fwd_<cutlass::half_t, 64>(params, stream);
} else if (params.d == 128) {
run_mha_fwd_<cutlass::half_t, 128>(params, stream);
} else {
run_mha_fwd_<cutlass::half_t, 256>(params, stream);
}
} else {
// run_mha_fwd_<cutlass::float_e4m3_t, 128>(params, stream);
}
}
std::vector<at::Tensor>
mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
const float softmax_scale,
bool is_causal) {
auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
TORCH_CHECK(is_sm90, "FlashAttention only supports Hopper GPUs or newer.");
auto q_dtype = q.dtype();
// TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
TORCH_CHECK(q_dtype == torch::kFloat16,
"FlashAttention only support fp16 data type for now");
// TODO: will add e4m3 later
// TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kFloat8_e4m3fn,
// "FlashAttention only support fp16 and bf16 data type");
// "FlashAttention only support fp16 and fp8 (e4m3) data type for now");
TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(q.is_contiguous(), "Input tensor must be contiguous");
TORCH_CHECK(k.is_contiguous(), "Input tensor must be contiguous");
TORCH_CHECK(v.is_contiguous(), "Input tensor must be contiguous");
const auto sizes = q.sizes();
const int batch_size = sizes[0];
int seqlen_q = sizes[1];
int num_heads = sizes[2];
const int head_size_og = sizes[3];
const int seqlen_k = k.size(1);
const int num_heads_k = k.size(2);
TORCH_CHECK(batch_size > 0, "batch size must be postive");
TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
TORCH_CHECK(num_heads == num_heads_k, "We do not support MQA/GQA yet");
TORCH_CHECK(head_size_og == 64 || head_size_og == 128 || head_size_og == 256, "Only support head size 64, 128, and 256 for now");
CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og);
CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og);
at::Tensor q_padded, k_padded, v_padded;
if (head_size_og % 8 != 0) {
q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
} else {
q_padded = q;
k_padded = k;
v_padded = v;
}
at::Tensor out;
if (out_.has_value()) {
out = out_.value();
TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
CHECK_DEVICE(out);
TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
} else {
out = torch::empty_like(q_padded);
}
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
const int head_size = round_multiple(head_size_og, 8);
const int head_size_rounded = round_multiple(head_size, 32);
const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
auto opts = q.options();
auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
at::Tensor p;
Flash_fwd_params params;
set_params_fprop(params,
batch_size,
seqlen_q, seqlen_k,
seqlen_q_rounded, seqlen_k_rounded,
num_heads, num_heads_k,
head_size, head_size_rounded,
q_padded, k_padded, v_padded, out,
/*cu_seqlens_q_d=*/nullptr,
/*cu_seqlens_k_d=*/nullptr,
/*seqused_k=*/nullptr,
nullptr,
softmax_lse.data_ptr(),
/*p_dropout=*/0.f,
softmax_scale,
/*window_size_left=*/-1,
/*window_size_right=*/is_causal ? 0 : -1);
auto tile_count_semaphore = is_causal ? torch::full({1}, 132, opts.dtype(torch::kInt32)) : torch::empty({1}, opts.dtype(torch::kInt32));
params.tile_count_semaphore = tile_count_semaphore.data_ptr<int>();
if (seqlen_k > 0) {
auto stream = at::cuda::getCurrentCUDAStream().stream();
run_mha_fwd(params, stream);
} else {
// If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
out.zero_();
softmax_lse.fill_(std::numeric_limits<float>::infinity());
}
at::Tensor out_padded = out;
if (head_size_og % 8 != 0) {
out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
if (out_.has_value()) { out_.value().copy_(out); }
}
return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p};
}
void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
// FP16_SWITCH(!params.is_bf16, [&] {
// HEADDIM_SWITCH(params.d, [&] {
// run_mha_bwd_<elem_type, kHeadDim>(params, stream);
// });
// });
if (params.d == 64) {
run_mha_bwd_<cutlass::half_t, 64>(params, stream);
} else if (params.d == 128) {
run_mha_bwd_<cutlass::half_t, 128>(params, stream);
} else {
run_mha_bwd_<cutlass::half_t, 256>(params, stream);
}
}
std::vector<at::Tensor>
mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads x head_size_og
const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &softmax_lse, // b x h x seqlen_q
c10::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
c10::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
c10::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
const float softmax_scale,
const bool is_causal) {
#ifdef FLASHATTENTION_DISABLE_BACKWARD
TORCH_CHECK(false, "This flash attention build does not support backward.");
#endif
auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_sm9x = dprops->major == 9 && dprops->minor >= 0;
TORCH_CHECK(is_sm9x, "FlashAttentionHopper only supports Hopper GPUs or newer.");
auto stream = at::cuda::getCurrentCUDAStream().stream();
auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == torch::kFloat16,
// "FlashAttention only support fp16 and bf16 data type");
"FlashAttention only support fp16 data type for now");
TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");
CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
TORCH_CHECK(q.is_contiguous(), "Input tensor must be contiguous");
TORCH_CHECK(k.is_contiguous(), "Input tensor must be contiguous");
TORCH_CHECK(v.is_contiguous(), "Input tensor must be contiguous");
const auto sizes = q.sizes();
const int batch_size = sizes[0];
const int seqlen_q = sizes[1];
const int num_heads = sizes[2];
const int head_size_og = dout.size(3);
const int head_size = sizes[3];
const int seqlen_k = k.size(1);
const int num_heads_k = k.size(2);
TORCH_CHECK(batch_size > 0, "batch size must be positive");
TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
TORCH_CHECK(head_size_og == 64 || head_size_og == 128, "Only support head size 64 and 128 for now");
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
const int head_size_rounded = round_multiple(head_size, 32);
const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8");
CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size);
CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size_og);
at::Tensor dq, dk, dv;
if (dq_.has_value()) {
dq = dq_.value();
TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
CHECK_DEVICE(dq);
TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size);
} else {
dq = torch::empty_like(q);
}
if (dk_.has_value()) {
dk = dk_.value();
TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
CHECK_DEVICE(dk);
TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size);
} else {
dk = torch::empty_like(k);
}
if (dv_.has_value()) {
dv = dv_.value();
TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
CHECK_DEVICE(dv);
TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size);
} else {
dv = torch::empty_like(v);
}
at::Tensor dout_padded;
if (head_size_og % 8 != 0) {
dout_padded = torch::nn::functional::pad(dout, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
} else {
dout_padded = dout;
}
// bool loop = seqlen_k > blocksize_c;
// TODO: change later, for now set to true for simplicity
bool loop = true;
// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
auto opts = q.options();
auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
at::Tensor dq_accum;
at::Tensor dk_accum, dv_accum;
if (loop) {
dq_accum = torch::empty({batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
// dk_accum = torch::zeros({batch_size, seqlen_k_rounded, num_heads_k, head_size_rounded}, opts.dtype(at::kFloat));
// dv_accum = torch::zeros({batch_size, seqlen_k_rounded, num_heads_k, head_size_rounded}, opts.dtype(at::kFloat));
}
at::Tensor dk_expanded, dv_expanded;
if (num_heads_k != num_heads) { // MQA / GQA
dk_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
dv_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
} else {
dk_expanded = dk;
dv_expanded = dv;
}
Flash_bwd_params params;
set_params_dgrad(params,
batch_size,
seqlen_q, seqlen_k,
seqlen_q_rounded, seqlen_k_rounded,
num_heads, num_heads_k,
head_size, head_size_rounded,
q, k, v, out,
dout_padded, dq, dk_expanded, dv_expanded,
nullptr,
nullptr,
loop ? dq_accum.data_ptr() : nullptr,
// loop ? dk_accum.data_ptr() : nullptr,
// loop ? dv_accum.data_ptr() : nullptr,
nullptr,
nullptr,
softmax_lse.data_ptr(),
softmax_d.data_ptr(),
/*p_dropout=*/0.f,
softmax_scale,
/*window_size_left=*/-1,
/*window_size_right=*/-1,
/*deterministic=*/false);
at::Tensor dq_semaphore = torch::zeros({(seqlen_q + 64 - 1) / 64, batch_size, num_heads}, opts.dtype(torch::kInt32));
params.dq_semaphore = dq_semaphore.data_ptr<int>();
// printf("dq_semaphore: %p, [%d, %d, %d]\n", params.dq_semaphore, (seqlen_q + 64 - 1) / 64, batch_size, num_heads);
auto launch = &run_mha_bwd;
if (seqlen_q > 0) {
launch(params, stream);
} else {
// If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
dk_expanded.zero_();
dv_expanded.zero_();
softmax_d.zero_();
}
if (head_size_og % 8 != 0) {
dq = dq.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
dk = dk.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
dv = dv.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
}
return { dq, dk, dv, softmax_d };
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "FlashAttention";
m.def("fwd", &mha_fwd, "Forward pass");
m.def("bwd", &mha_bwd, "Backward pass");
}
# Copyright (c) 2023, Tri Dao.
from typing import Optional, Union
import torch
import torch.nn as nn
# isort: off
# We need to import the CUDA kernels after importing torch
import flashattn_hopper_cuda
# isort: on
def _flash_attn_forward(q, k, v, softmax_scale, causal):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
out, q, k, v, out_padded, softmax_lse, S_dmask = flashattn_hopper_cuda.fwd(
q,
k,
v,
None,
softmax_scale,
causal,
)
return out, q, k, v, out_padded, softmax_lse, S_dmask
def _flash_attn_backward(
dout,
q,
k,
v,
out,
softmax_lse,
dq,
dk,
dv,
softmax_scale,
causal
):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
# dq, dk, dv are allocated by us so they should already be contiguous
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
dq, dk, dv, softmax_d, = flashattn_hopper_cuda.bwd(
dout,
q,
k,
v,
out,
softmax_lse,
dq,
dk,
dv,
softmax_scale,
causal,
)
return dq, dk, dv, softmax_d
class FlashAttnFunc(torch.autograd.Function):
@staticmethod
def forward(
ctx,
q,
k,
v,
softmax_scale,
causal,
):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_forward(
q,
k,
v,
softmax_scale,
causal
)
ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
ctx.softmax_scale = softmax_scale
ctx.causal = causal
return out, softmax_lse
@staticmethod
def backward(ctx, dout, *args):
q, k, v, out, softmax_lse = ctx.saved_tensors
dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
_flash_attn_backward(
dout,
q,
k,
v,
out,
softmax_lse,
dq,
dk,
dv,
ctx.softmax_scale,
ctx.causal,
)
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dk = dk[..., : dout.shape[-1]]
dv = dv[..., : dout.shape[-1]]
return dq, dk, dv, None, None
def flash_attn_func(
q,
k,
v,
softmax_scale=None,
causal=False,
):
"""dropout_p should be set to 0.0 during evaluation
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.
If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1 1 1 1 0
1 1 1 1 1
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
0 0
0 0
0 0
1 0
1 1
If the row of the mask is all zero, the output will be zero.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
Arguments:
q: (batch_size, seqlen, nheads, headdim)
k: (batch_size, seqlen, nheads_k, headdim)
v: (batch_size, seqlen, nheads_k, headdim)
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
which is slightly slower and uses more memory. The forward pass is always deterministic.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
Return:
out: (batch_size, seqlen, nheads, headdim).
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
The output of softmax (possibly with different scaling). It also encodes the dropout
pattern (negative means that location was dropped, nonnegative means it was kept).
"""
return FlashAttnFunc.apply(
q,
k,
v,
softmax_scale,
causal,
)
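Because `FlashAttnFunc` wires the extension into autograd, gradients flow back through `flashattn_hopper_cuda.bwd`. A hedged end-to-end sketch (fp16 inputs, head size 128, matching the checks in `mha_fwd`/`mha_bwd`):
```python
import torch
from flash_attn_interface import flash_attn_func

q, k, v = (torch.randn(1, 512, 8, 128, dtype=torch.float16, device="cuda", requires_grad=True)
           for _ in range(3))
out, softmax_lse = flash_attn_func(q, k, v, causal=True)
out.float().sum().backward()
print(q.grad.shape, k.grad.shape, v.grad.shape)   # each (1, 512, 8, 128)
```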
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_bwd_launch_template.h"
template<>
void run_mha_bwd_<cutlass::half_t, 128>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim128<cutlass::half_t>(params, stream);
}
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_bwd_launch_template.h"
template<>
void run_mha_bwd_<cutlass::half_t, 256>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim256<cutlass::half_t>(params, stream);
}
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_bwd_launch_template.h"
template<>
void run_mha_bwd_<cutlass::half_t, 64>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim64<cutlass::half_t>(params, stream);
}
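Before the backward kernel below, a hedged PyTorch reference (single head, fp32) of the gradients it accumulates; here `D` corresponds to `params.dsoftmax_sum` and `P` to the probabilities re-materialized from the stored LSE:
```python
# Reference attention backward, matching the dS = P * (dP - dP_sum) update in the kernel.
import math
import torch

def reference_bwd(q, k, v, out, dout, lse, softmax_scale=None):
    # q, out, dout: (seqlen_q, d); k, v: (seqlen_k, d); lse: (seqlen_q,)
    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(q.shape[-1])
    S = (q @ k.t()) * softmax_scale       # scores
    P = torch.exp(S - lse[:, None])       # probabilities re-materialized from LSE
    D = (dout * out).sum(dim=-1)          # dsoftmax_sum: one value per query row
    dV = P.t() @ dout
    dP = dout @ v.t()
    dS = P * (dP - D[:, None])
    dQ = (dS @ k) * softmax_scale
    dK = (dS.t() @ q) * softmax_scale
    return dQ, dK, dV
```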
/******************************************************************************
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
******************************************************************************/
#pragma once
#include "cute/tensor.hpp"
#include <cutlass/cutlass.h>
#include <cutlass/arch/reg_reconfig.h>
#include <cutlass/array.h>
#include <cutlass/barrier.h>
#include <cutlass/numeric_types.h>
#include <cutlass/numeric_conversion.h>
#include "cutlass/pipeline/pipeline.hpp"
#include "flash.h"
#include "utils.h"
#include "softmax.h"
namespace flash {
using namespace cute;
template <typename Ktraits, bool Is_causal, typename TiledCopyQ, typename TiledCopydO,
typename TiledCopyK, typename TiledCopyV, typename TiledCopydK, typename TiledCopydV>
__global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1)
compute_dqkv(CUTE_GRID_CONSTANT Flash_bwd_params const params,
CUTE_GRID_CONSTANT TiledCopyQ const tma_load_Q,
CUTE_GRID_CONSTANT TiledCopydO const tma_load_dO,
CUTE_GRID_CONSTANT TiledCopyK const tma_load_K,
CUTE_GRID_CONSTANT TiledCopyV const tma_load_V,
CUTE_GRID_CONSTANT TiledCopydK const tma_store_dK,
CUTE_GRID_CONSTANT TiledCopydV const tma_store_dV) {
using Element = typename Ktraits::Element;
using ElementAccum = typename Ktraits::ElementAccum;
using SoftType = ElementAccum;
using index_t = typename Ktraits::index_t;
using TileShape_MNK = typename Ktraits::TileShape_MNK;
using ClusterShape = typename Ktraits::ClusterShape_MNK;
static constexpr int kNThreads = Ktraits::kNThreads;
// static constexpr int NumMmaThreads = size(typename Ktraits::TiledMmaSdP{});
static constexpr int NumMmaThreads = Ktraits::kNThreads;
static constexpr int kBlockM = Ktraits::kBlockM;
// static constexpr int kBlockN = Ktraits::kBlockN;
// constexpr int kHeadDim = Ktraits::kHeadDim;
static constexpr int kStages = Ktraits::kStages;
static constexpr bool SdP_swapAB = Ktraits::SdP_swapAB;
static constexpr bool dKV_swapAB = Ktraits::dKV_swapAB;
static constexpr bool dQ_swapAB = Ktraits::dQ_swapAB;
static constexpr bool Mma_dQ_is_RS = Ktraits::Mma_dQ_is_RS;
if constexpr (dQ_swapAB) { static_assert(!Mma_dQ_is_RS); }
using MainloopPipeline = typename Ktraits::MainloopPipeline;
using PipelineParams = typename MainloopPipeline::Params;
using PipelineState = typename MainloopPipeline::PipelineState;
extern __shared__ char shared_memory[];
auto &shared_storage = *reinterpret_cast<typename Ktraits::SharedStorage*>(shared_memory);
int const n_block = blockIdx.x;
int const bidb = blockIdx.z; // The block index for the batch.
int const bidh = blockIdx.y; // The block index for the head.
int lane_predicate = cute::elect_one_sync();
int warp_idx = cutlass::canonical_warp_idx_sync();
// Issue Tma Descriptor Prefetch from a single thread
if (warp_idx == 0 && lane_predicate) {
cute::prefetch_tma_descriptor(tma_load_Q.get_tma_descriptor());
cute::prefetch_tma_descriptor(tma_load_dO.get_tma_descriptor());
cute::prefetch_tma_descriptor(tma_load_K.get_tma_descriptor());
cute::prefetch_tma_descriptor(tma_load_V.get_tma_descriptor());
cute::prefetch_tma_descriptor(tma_store_dK.get_tma_descriptor());
cute::prefetch_tma_descriptor(tma_store_dV.get_tma_descriptor());
}
Tensor mQ = tma_load_Q.get_tma_tensor(make_shape(params.seqlen_q, params.d, params.h, params.b));
Tensor mdO = tma_load_dO.get_tma_tensor(make_shape(params.seqlen_q, params.d, params.h, params.b));
Tensor mK = tma_load_K.get_tma_tensor(make_shape(params.seqlen_k, params.d, params.h, params.b));
Tensor mV = tma_load_V.get_tma_tensor(make_shape(params.seqlen_k, params.d, params.h, params.b));
Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr)),
make_shape(params.b, params.h, params.seqlen_q),
make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{}));
Tensor mdPsum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.dsoftmax_sum)),
make_shape(params.b, params.h, params.seqlen_q),
make_stride(params.h * params.seqlen_q_rounded, params.seqlen_q_rounded, _1{}));
Tensor mdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.dq_accum_ptr)),
make_shape(params.seqlen_q, params.d, params.h, params.b),
make_stride(params.d * params.h, _1{}, params.d, params.d * params.h * params.seqlen_q_rounded));
Tensor gQ = local_tile(mQ(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (M, K, _)
Tensor gdO = local_tile(mdO(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (M, K, _)
Tensor gK = local_tile(mK(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (N, K)
Tensor gV = local_tile(mV(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (N, K)
Tensor gdQaccum = local_tile(mdQaccum(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (M, K, _)
// if (cute::thread0()) { print(tma_load_K); printf("\n"); }
// if (cute::thread0()) { print(mK); printf("\n"); print(gK); printf("\n"); }
typename Ktraits::GmemTiledCopydQaccum gmem_tiled_copy_dQaccum;
auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(threadIdx.x);
Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum);
// Construct SMEM tensors.
Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Ktraits::SmemLayoutQ{});
Tensor sdO = make_tensor(make_smem_ptr(shared_storage.smem_do.data()), typename Ktraits::SmemLayoutdO{});
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Ktraits::SmemLayoutK{});
Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), typename Ktraits::SmemLayoutV{});
Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Ktraits::SmemLayoutP{});
Tensor sdS = make_tensor(make_smem_ptr(shared_storage.smem_ds.data()), typename Ktraits::SmemLayoutdS{});
Tensor sQt = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Ktraits::SmemLayoutQt{});
Tensor sdOt = make_tensor(make_smem_ptr(shared_storage.smem_do.data()), typename Ktraits::SmemLayoutdOt{});
Tensor sKt = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Ktraits::SmemLayoutKt{});
Tensor sPt = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Ktraits::SmemLayoutPt{});
Tensor sdSt = make_tensor(make_smem_ptr(shared_storage.smem_ds.data()), typename Ktraits::SmemLayoutdSt{});
// Prepare the TMA loads
uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
auto block_tma_Q = tma_load_Q.get_slice(cluster_local_block_id.y);
auto block_tma_dO = tma_load_dO.get_slice(cluster_local_block_id.y);
auto block_tma_K = tma_load_K.get_slice(_0{});
auto block_tma_V = tma_load_V.get_slice(_0{});
Tensor tQgQ = block_tma_Q.partition_S(gQ); // (TMA, TMA_M, TMA_K, k)
Tensor tQsQ = block_tma_Q.partition_D(sQ); // (TMA, TMA_M, TMA_K, PIPE)
Tensor tdOgdO = block_tma_dO.partition_S(gdO); // (TMA, TMA_M, TMA_K, k)
Tensor tdOsdO = block_tma_dO.partition_D(sdO); // (TMA, TMA_M, TMA_K, PIPE)
Tensor tKgK = block_tma_K.partition_S(gK); // (TMA, TMA_N, TMA_K)
Tensor tKsK = block_tma_K.partition_D(sK); // (TMA, TMA_N, TMA_K)
Tensor tVgV = block_tma_V.partition_S(gV); // (TMA, TMA_N, TMA_K)
Tensor tVsV = block_tma_V.partition_D(sV); // (TMA, TMA_N, TMA_K)
// if (cute::thread0()) { print(tQgQ); printf("\n"); print(tQsQ); printf("\n"); }
// if (cute::thread0()) { print(tKgK); printf("\n"); print(tKsK); printf("\n"); }
// Set the bytes transferred in this TMA transaction (may involve multiple issues)
constexpr uint32_t TmaTransactionBytesQ = static_cast<uint32_t>(size<0>(sQ) * size<1>(sQ) * cutlass::sizeof_bits_v<Element> / 8);
constexpr uint32_t TmaTransactionBytesdO = static_cast<uint32_t>(size<0>(sdO) * size<1>(sdO) * cutlass::sizeof_bits_v<Element> / 8);
static_assert(TmaTransactionBytesQ == TmaTransactionBytesdO);
constexpr uint32_t TmaTransactionBytesK = static_cast<uint32_t>(size<0>(sK) * size<1>(sK) * cutlass::sizeof_bits_v<Element> / 8);
constexpr uint32_t TmaTransactionBytesV = static_cast<uint32_t>(size<0>(sV) * size<1>(sV) * cutlass::sizeof_bits_v<Element> / 8);
static_assert(TmaTransactionBytesK == TmaTransactionBytesV);
// Obtain warp index
int thread_idx = int(threadIdx.x);
int warp_group_thread_idx = thread_idx % cutlass::NumThreadsPerWarpGroup;
// int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
PipelineParams pipeline_params;
pipeline_params.transaction_bytes = TmaTransactionBytesQ;
pipeline_params.role = MainloopPipeline::ThreadCategory::ProducerConsumer;
pipeline_params.is_leader = warp_group_thread_idx == 0;
pipeline_params.num_consumers = NumMmaThreads;
if (warp_idx == 0 && lane_predicate) {
shared_storage.barrier_K.init(1 /*numThreads*/);
shared_storage.barrier_V.init(1 /*numThreads*/);
}
// cutlass::arch::fence_barrier_init();
// We're counting on pipeline_q to call fence_barrier_init();
MainloopPipeline pipeline_q(shared_storage.pipeline_q, pipeline_params, ClusterShape{});
MainloopPipeline pipeline_do(shared_storage.pipeline_do, pipeline_params, ClusterShape{});
// We need this to guarantee that the Pipeline init is visible
// To all producers and consumer blocks in the Cluster
if constexpr (size(ClusterShape{}) > 1) {
cute::cluster_arrive_relaxed();
cute::cluster_wait();
} else {
__syncthreads();
}
// State variables used for iterating the circular buffer
// smem_pipe_read / release is used by the consumer of SMEM data - i.e MMA
// smem_pipe_write is used by the producer of SMEM data - i.e TMA
PipelineState smem_pipe_read_q, smem_pipe_read_do;
PipelineState smem_pipe_release_q, smem_pipe_release_do;
PipelineState smem_pipe_write_q = cutlass::make_producer_start_state<MainloopPipeline>();
PipelineState smem_pipe_write_do = cutlass::make_producer_start_state<MainloopPipeline>();
// Copy K tile and V tile from GMEM to SMEM.
if (warp_idx == 0 && lane_predicate) {
shared_storage.barrier_K.arrive_and_expect_tx(TmaTransactionBytesK);
copy(tma_load_K.with(reinterpret_cast<cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.barrier_K), 0 /*mcast_mask*/), tKgK, tKsK);
shared_storage.barrier_V.arrive_and_expect_tx(TmaTransactionBytesV);
copy(tma_load_V.with(reinterpret_cast<cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.barrier_V), 0 /*mcast_mask*/), tVgV, tVsV);
}
// if (cute::thread0()) { print_tensor(sQ); printf("\n"); } __syncthreads();
int m_block = cute::ceil_div(params.seqlen_q, kBlockM) - 1;
uint16_t mcast_mask_qdo = 0;
if constexpr (cute::is_same_v<typename Ktraits::GmemTiledCopyQdO, SM90_TMA_LOAD_MULTICAST>) {
auto block_layout = Layout<ClusterShape>{}; // (m,n) -> block_id
for (int n = 0; n < size<1>(block_layout); ++n) {
mcast_mask_qdo |= (uint16_t(1) << block_layout(n, cluster_local_block_id.x, _0{}));
}
}
// Issue TmaLoads (Prologue fetches)
if (warp_idx == 0 && lane_predicate) {
// Issue the prologue loads
CUTLASS_PRAGMA_UNROLL
for (int stage = 0; stage < kStages && stage <= m_block; ++stage) {
pipeline_q.producer_acquire(smem_pipe_write_q);
copy(tma_load_Q.with(*pipeline_q.producer_get_barrier(smem_pipe_write_q), mcast_mask_qdo), tQgQ(_, _, _, m_block - stage), tQsQ(_, _, _, stage));
++smem_pipe_write_q;
pipeline_do.producer_acquire(smem_pipe_write_do);
copy(tma_load_dO.with(*pipeline_do.producer_get_barrier(smem_pipe_write_do), mcast_mask_qdo), tdOgdO(_, _, _, m_block - stage), tdOsdO(_, _, _, stage));
++smem_pipe_write_do;
}
}
Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape<Int<kBlockM>>{}, make_coord(m_block));
Tensor gdPsum = local_tile(mdPsum(bidb, bidh, _), Shape<Int<kBlockM>>{}, make_coord(m_block));
// Initialize matmul objects.
typename Ktraits::TiledMmaSdP tiledMmaSdP;
auto threadMmaSdP = tiledMmaSdP.get_thread_slice(threadIdx.x);
typename Ktraits::TiledMmadKV tiledMmadKV;
auto threadMmadKV = tiledMmadKV.get_thread_slice(threadIdx.x);
typename Ktraits::TiledMmadQ tiledMmadQ;
auto threadMmadQ = tiledMmadQ.get_thread_slice(threadIdx.x);
// Allocate accumulator
Tensor tdKrdK = partition_fragment_C(tiledMmadKV, select<!dKV_swapAB ? 1 : 2, !dKV_swapAB? 2 : 1>(TileShape_MNK{}));
Tensor tdVrdV = partition_fragment_C(tiledMmadKV, select<!dKV_swapAB ? 1 : 2, !dKV_swapAB? 2 : 1>(TileShape_MNK{}));
auto smem_tiled_copy_PdS = make_tiled_copy_C(typename Ktraits::SmemCopyAtomPdS{}, tiledMmaSdP);
auto smem_thr_copy_PdS = smem_tiled_copy_PdS.get_thread_slice(threadIdx.x);
if constexpr (!SdP_swapAB) {
Tensor tPsP = smem_thr_copy_PdS.partition_D(sP); // ((Atom,AtomNum),PIPE_M,PIPE_N)
Tensor tdSsdS = smem_thr_copy_PdS.partition_D(sdS); // ((Atom,AtomNum),PIPE_M,PIPE_N)
// Allocate "fragments/descriptors"
Tensor tSrQ = threadMmaSdP.partition_fragment_A(sQ);
Tensor tSrK = threadMmaSdP.partition_fragment_B(sK);
Tensor tdPrdO = threadMmaSdP.partition_fragment_A(sdO);
Tensor tdPrV = threadMmaSdP.partition_fragment_B(sV);
Tensor caccS = make_identity_tensor(select<0, 1>(TileShape_MNK{})); // (BLK_M,BLK_N) -> (blk_m,blk_n)
Tensor taccScS = threadMmaSdP.partition_C(caccS); // (MMA,MMA_N,MMA_N)
static_assert(decltype(size<0, 0>(taccScS))::value == 2);
static_assert(decltype(size<0, 1>(taccScS))::value == 2);
// taccScS has shape ((2, 2, V), MMA_M, MMA_N); we only take the row indices.
Tensor taccScS_row = taccScS(make_coord(_0{}, _, _0{}), _, _0{});
Tensor lse = make_tensor<ElementAccum>(Shape<Int<decltype(size(taccScS_row))::value>>{});
Tensor dP_sum = make_fragment_like(lse);
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) {
const int row = get<0>(taccScS_row(mi));
lse(mi) = row < params.seqlen_q - m_block * kBlockM ? gLSE(row) : INFINITY;
dP_sum(mi) = row < params.seqlen_q - m_block * kBlockM ? gdPsum(row) : 0;
}
// if (cute::thread0()) { print_tensor(dP_sum); printf("\n"); }
// We want LSE = inf if the row is OOB. In that case Q would be zero, K would be zero,
// and scores would be zero. With LSE = 0, probs will be all 1's, and when we multiply
// with V (which would be zero), we're fine. However, with ALiBi, we might modify these
// scores, and probs can become NaN. If instead we set LSE = inf for OOB rows, probs are always 0.
clear(tdKrdK);
clear(tdVrdV);
shared_storage.barrier_K.wait(0);
shared_storage.barrier_V.wait(0);
__syncthreads();
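// Main loop over query tiles, from the last m_block down to 0. Each iteration recomputes
// P = exp2(scale_log2 * S - LSE), forms dS = P * (dP - dP_sum), accumulates dK and dV in
// registers, and atomically adds this tile's dQ contribution to the fp32 dQaccum buffer in gmem.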
// #pragma unroll 2
CUTLASS_PRAGMA_NO_UNROLL
for (; m_block >= 0; --m_block) {
Tensor tSrS = partition_fragment_C(tiledMmaSdP, select<0, 1>(TileShape_MNK{}));
pipeline_q.consumer_wait(smem_pipe_read_q);
__syncwarp();
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiledMmaSdP, tSrQ(_, _, _, smem_pipe_read_q.index()), tSrK, tSrS);
Tensor tdPrdP = partition_fragment_C(tiledMmaSdP, select<0, 1>(TileShape_MNK{}));
pipeline_do.consumer_wait(smem_pipe_read_do);
__syncwarp();
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiledMmaSdP, tdPrdO(_, _, _, smem_pipe_read_do.index()), tdPrV, tdPrdP);
warpgroup_wait<1>();
// Reshape tSrS from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
Tensor scores = make_tensor(tSrS.data(), flash::convert_layout_acc_rowcol(tSrS.layout()));
flash::scale_apply_exp2</*Scale=*/true, /*scale_max=*/false>(scores, lse, params.scale_softmax_log2);
// if (cute::thread0()) { print_tensor(scores); printf("\n"); }
// Convert scores from fp32 to fp16/bf16
Tensor rP = flash::convert_type<Element>(tSrS);
Tensor tPaP = smem_thr_copy_PdS.retile_S(rP); // ((Atom,AtomNum), MMA_N, MMA_N)
cute::copy(smem_tiled_copy_PdS, tPaP, tPsP);
int const warp_group_idx = cutlass::canonical_warp_group_idx();
cutlass::arch::NamedBarrier::arrive(kNThreads, warp_group_idx /*id*/);
warpgroup_wait<0>();
// Reshape tdPrdP from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
Tensor dS = make_tensor(tdPrdP.data(), scores.layout());
// if (cute::thread0()) { print_tensor(dS); printf("\n"); }
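// Softmax backward in registers: dS(mi, ni) = P(mi, ni) * (dP(mi, ni) - dP_sum(mi)), with P held
// in `scores` and the raw dP accumulator viewed through `dS`.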
#pragma unroll
for (int mi = 0; mi < size<0>(dS); ++mi) {
#pragma unroll
for (int ni = 0; ni < size<1>(dS); ++ni) { dS(mi, ni) = scores(mi, ni) * (dS(mi, ni) - dP_sum(mi)); }
}
Tensor rdS = flash::convert_type<Element>(tdPrdP);
Tensor tdSadS = smem_thr_copy_PdS.retile_S(rdS); // ((Atom,AtomNum), MMA_N, MMA_N)
cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS);
// cutlass::arch::NamedBarrier::arrive(kNThreads, 1 /*id*/);
cutlass::arch::NamedBarrier::arrive(kNThreads, 2 + warp_group_idx /*id*/);
// if (cute::thread0()) { print_tensor(dS); printf("\n"); }
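// Load LSE and dP_sum for the next (lower) m_block before issuing the dQ/dK/dV GEMMs for this tile.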
if (m_block > 0) {
gLSE.data() = gLSE.data() + (-int(kBlockM));
gdPsum.data() = gdPsum.data() + (-int(kBlockM));
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) {
const int row = get<0>(taccScS_row(mi));
lse(mi) = gLSE(row);
dP_sum(mi) = gdPsum(row);
}
}
Tensor tdQrdQ = partition_fragment_C(tiledMmadQ, select<!dQ_swapAB ? 0 : 2, !dQ_swapAB ? 2 : 0>(TileShape_MNK{}));
if constexpr (Mma_dQ_is_RS) {
static_assert(!dQ_swapAB);
Tensor tdQrdS = make_tensor(rdS.data(), convert_layout_acc_Aregs<typename Ktraits::TiledMmadQ>(tdPrdP.layout()));
Tensor tdQrK = threadMmadQ.partition_fragment_B(sKt);
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiledMmadQ, tdQrdS, tdQrK, tdQrdQ);
// if (cute::thread0()) { print(tdQrdS); printf("\n"); print(tdQrK); printf("\n"); print(tdQrdQ); printf("\n"); }
}
// warpgroup_wait<0>();
// Tensor dQ_tmp = make_tensor(tdQrdQ.data(), flash::convert_layout_acc_rowcol(tdQrdQ.layout()));
// if (cute::thread0()) { print_tensor(dQ_tmp); printf("\n"); }
// if (cute::thread0()) { print_tensor(sK); printf("\n"); }
// if (cute::thread0()) { print_tensor(sKt); printf("\n"); } __syncthreads();
// __syncthreads(); // Without this I'm getting race condition, I thought the barrier would be enough
// SMEM fence to make sure sP is written before it's read by WGMMA
cutlass::arch::fence_view_async_shared();
cutlass::arch::NamedBarrier::sync(kNThreads, 1 - warp_group_idx /*id*/);
if constexpr (!dKV_swapAB) {
Tensor tdVrP = threadMmadKV.partition_fragment_A(sPt);
Tensor tdVrdO = threadMmadKV.partition_fragment_B(sdOt);
flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiledMmadKV, tdVrP, tdVrdO(_, _, _, smem_pipe_read_do.index()), tdVrdV);
} else {
Tensor tdVrP = threadMmadKV.partition_fragment_B(sPt);
Tensor tdVrdO = threadMmadKV.partition_fragment_A(sdOt);
flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiledMmadKV, tdVrdO(_, _, _, smem_pipe_read_do.index()), tdVrP, tdVrdV);
}
++smem_pipe_read_do;
// warpgroup_wait<0>();
// Tensor dV_tmp = make_tensor(tdVrdV.data(), flash::convert_layout_acc_rowcol(tdVrdV.layout()));
// if (cute::thread0()) { print_tensor(dV_tmp); printf("\n"); }
cutlass::arch::fence_view_async_shared();
cutlass::arch::NamedBarrier::sync(kNThreads, 2 + 1 - warp_group_idx /*id*/);
if constexpr (!Mma_dQ_is_RS) {
if constexpr (!dQ_swapAB) {
Tensor tdQrdS = threadMmadQ.partition_fragment_A(sdS);
Tensor tdQrK = threadMmadQ.partition_fragment_B(sKt);
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiledMmadQ, tdQrdS, tdQrK, tdQrdQ);
} else {
Tensor tdQrdS = threadMmadQ.partition_fragment_B(sdS);
Tensor tdQrK = threadMmadQ.partition_fragment_A(sKt);
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiledMmadQ, tdQrK, tdQrdS, tdQrdQ);
}
}
// warpgroup_wait<0>();
// Tensor dQ_tmp = make_tensor(tdQrdQ.data(), flash::convert_layout_acc_rowcol(tdQrdQ.layout()));
// if (cute::thread0()) { print_tensor(dQ_tmp); printf("\n"); }
// if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(dQ_tmp); printf("\n"); }
if constexpr (!dKV_swapAB) {
Tensor tdKrdS = threadMmadKV.partition_fragment_A(sdSt);
Tensor tdKrQ = threadMmadKV.partition_fragment_B(sQt);
flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiledMmadKV, tdKrdS, tdKrQ(_, _, _, smem_pipe_read_q.index()), tdKrdK);
} else {
Tensor tdKrdS = threadMmadKV.partition_fragment_B(sdSt);
Tensor tdKrQ = threadMmadKV.partition_fragment_A(sQt);
flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiledMmadKV, tdKrQ(_, _, _, smem_pipe_read_q.index()), tdKrdS, tdKrdK);
}
++smem_pipe_read_q;
// Tensor dK_tmp = make_tensor(tdKrdK.data(), flash::convert_layout_acc_rowcol(tdKrdK.layout()));
// if (cute::thread0()) { print_tensor(dK_tmp); printf("\n"); }
warpgroup_wait<Mma_dQ_is_RS ? 2 : 1>();
// if (cute::thread0()) { print(tdQrdQ); printf("\n"); print(tdQgdQaccum); printf("\n"); }
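// Accumulate this tile's dQ contribution into the fp32 dQaccum buffer with atomicAdd, since CTAs
// working on different K/V tiles all contribute to the same query rows.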
Tensor tdQrdQ_atomic = recast<float4>(tdQrdQ);
Tensor tdQgdQaccum_atomic = recast<float4>(tdQgdQaccum(_, _, _, m_block));
#pragma unroll
for (int i = 0; i < size(tdQrdQ_atomic); ++i) { atomicAdd(&tdQgdQaccum_atomic(i), tdQrdQ_atomic(i)); }
// for (int i = 0; i < size(tdQrdQ_atomic); ++i) { tdQgdQaccum_atomic(i) = tdQrdQ_atomic(i); }
warpgroup_wait<0>();
pipeline_do.consumer_release(smem_pipe_release_do); // release dO
++smem_pipe_release_do;
pipeline_q.consumer_release(smem_pipe_release_q); // release Q
++smem_pipe_release_q;
int const lane_predicate = cute::elect_one_sync();
int const warp_idx = cutlass::canonical_warp_idx_sync();
if (warp_idx == 0 && lane_predicate && m_block >= kStages) {
pipeline_q.producer_acquire(smem_pipe_write_q);
copy(tma_load_Q.with(*pipeline_q.producer_get_barrier(smem_pipe_write_q), mcast_mask_qdo), tQgQ(_, _, _, m_block - kStages), tQsQ(_, _, _, smem_pipe_write_q.index()));
++smem_pipe_write_q;
pipeline_do.producer_acquire(smem_pipe_write_do);
copy(tma_load_dO.with(*pipeline_do.producer_get_barrier(smem_pipe_write_do), mcast_mask_qdo), tdOgdO(_, _, _, m_block - kStages), tdOsdO(_, _, _, smem_pipe_write_do.index()));
++smem_pipe_write_do;
}
}
} else { // SdP_swapAB
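// With SdP_swapAB, K and V are the A operands of the S/dP GEMMs, so the accumulators come out
// transposed relative to the branch above; the row bookkeeping below is adjusted accordingly.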
Tensor tdSsdS = smem_thr_copy_PdS.partition_D(sdSt); // ((Atom,AtomNum),PIPE_M,PIPE_N)
// Allocate "fragments/descriptors"
Tensor tSrQ = threadMmaSdP.partition_fragment_B(sQ);
Tensor tSrK = threadMmaSdP.partition_fragment_A(sK);
Tensor tdPrdO = threadMmaSdP.partition_fragment_B(sdO);
Tensor tdPrV = threadMmaSdP.partition_fragment_A(sV);
Tensor caccS = make_identity_tensor(select<1, 0>(TileShape_MNK{})); // (BLK_M,BLK_N) -> (blk_m,blk_n)
Tensor taccScS = threadMmaSdP.partition_C(caccS); // (MMA,MMA_N,MMA_M)
static_assert(decltype(size<0, 0>(taccScS))::value == 2);
static_assert(decltype(size<0, 1>(taccScS))::value == 2);
// taccScS has shape ((2, 2, V), MMA_N, MMA_M); we only take the row indices.
Tensor taccScS_row = taccScS(make_coord(_, _0{}, _), _0{}, _);
Tensor lse = make_tensor<ElementAccum>(Shape<Int<decltype(size(taccScS_row))::value>>{});
Tensor dP_sum = make_fragment_like(lse);
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) {
const int row = get<1>(taccScS_row(mi));
lse(mi) = row < params.seqlen_q - m_block * kBlockM ? gLSE(row) : INFINITY;
dP_sum(mi) = row < params.seqlen_q - m_block * kBlockM ? gdPsum(row) : 0;
}
// cute::fill(lse, 1);
// cute::fill(dP_sum, 1);
// if (cute::thread0()) { print_tensor(dP_sum); printf("\n"); }
// We want LSE = inf if the row is OOB. In that case Q would be zero, K would be zero,
// and scores would be zero. With LSE = 0, probs will be all 1's, and when we multiply
// with V (which would be zero), we're fine. However, with ALiBi, we might modify these
// scores, and probs can become NaN. If instead we set LSE = inf for OOB rows, probs are always 0.
clear(tdKrdK);
clear(tdVrdV);
shared_storage.barrier_K.wait(0);
shared_storage.barrier_V.wait(0);
// #pragma unroll 2
CUTLASS_PRAGMA_NO_UNROLL
for (; m_block >= 0; --m_block) {
Tensor tSrS = partition_fragment_C(tiledMmaSdP, select<1, 0>(TileShape_MNK{}));
pipeline_q.consumer_wait(smem_pipe_read_q);
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiledMmaSdP, tSrK, tSrQ(_, _, _, smem_pipe_read_q.index()), tSrS);
Tensor tdPrdP = partition_fragment_C(tiledMmaSdP, select<1, 0>(TileShape_MNK{}));
pipeline_do.consumer_wait(smem_pipe_read_do);
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiledMmaSdP, tdPrV, tdPrdO(_, _, _, smem_pipe_read_do.index()), tdPrdP);
warpgroup_wait<1>();
// Reshape tSrS from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N))
Tensor scores = make_tensor(tSrS.data(), flash::convert_layout_acc_transposed_rowcol(tSrS.layout()));
flash::scale_apply_exp2</*Scale=*/true, /*scale_max=*/false>(scores, lse, params.scale_softmax_log2);
// if (cute::thread0()) { print_tensor(scores); printf("\n"); }
// Convert scores from fp32 to fp16/bf16
Tensor rP = flash::convert_type<Element>(tSrS);
static_assert(!dKV_swapAB);
Tensor tdVrP = make_tensor(rP.data(), convert_layout_acc_Aregs<typename Ktraits::TiledMmadKV>(tSrS.layout()));
Tensor tdVrdO = threadMmadKV.partition_fragment_B(sdOt);
flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiledMmadKV, tdVrP, tdVrdO(_, _, _, smem_pipe_read_do.index()), tdVrdV);
++smem_pipe_read_do;
// warpgroup_wait<0>();
// Tensor dV_tmp = make_tensor(tdVrdV.data(), flash::convert_layout_acc_rowcol(tdVrdV.layout()));
// if (cute::thread0()) { print_tensor(dV_tmp); printf("\n"); }
warpgroup_wait<1>();
// Reshape tdPrdP from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N))
Tensor dS = make_tensor(tdPrdP.data(), scores.layout());
#pragma unroll
for (int mi = 0; mi < size<0>(dS); ++mi) {
#pragma unroll
for (int ni = 0; ni < size<1>(dS); ++ni) { dS(mi, ni) = scores(mi, ni) * (dS(mi, ni) - dP_sum(mi)); }
}
// if (cute::thread0()) { print_tensor(dS); printf("\n"); }
Tensor rdS = flash::convert_type<Element>(tdPrdP);
Tensor tdSadS = smem_thr_copy_PdS.retile_S(rdS); // ((Atom,AtomNum), MMA_N, MMA_N)
cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS);
if (m_block > 0) {
gLSE.data() = gLSE.data() + (-int(kBlockM));
gdPsum.data() = gdPsum.data() + (-int(kBlockM));
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) {
const int row = get<1>(taccScS_row(mi));
lse(mi) = gLSE(row);
dP_sum(mi) = gdPsum(row);
}
}
Tensor tdKrdS = make_tensor(rdS.data(), convert_layout_acc_Aregs<typename Ktraits::TiledMmadKV>(tdPrdP.layout()));
Tensor tdKrQ = threadMmadKV.partition_fragment_B(sQt);
flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiledMmadKV, tdKrdS, tdKrQ(_, _, _, smem_pipe_read_q.index()), tdKrdK);
++smem_pipe_read_q;
// warpgroup_wait<0>();
// Tensor dK_tmp = make_tensor(tdKrdK.data(), flash::convert_layout_acc_rowcol(tdKrdK.layout()));
// if (cute::thread0()) { print_tensor(dK_tmp); printf("\n"); }
// SMEM fence to make sure sP is written before it's read by WGMMA
cutlass::arch::fence_view_async_shared();
// cutlass::arch::NamedBarrier::sync(kNThreads, 0 /*id*/);
__syncthreads();
static_assert(!Mma_dQ_is_RS);
Tensor tdQrdQ = partition_fragment_C(tiledMmadQ, select<!dQ_swapAB ? 0 : 2, !dQ_swapAB ? 2 : 0>(TileShape_MNK{}));
if constexpr (!dQ_swapAB) {
Tensor tdQrdS = threadMmadQ.partition_fragment_A(sdS);
Tensor tdQrK = threadMmadQ.partition_fragment_B(sKt);
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiledMmadQ, tdQrdS, tdQrK, tdQrdQ);
} else {
Tensor tdQrdS = threadMmadQ.partition_fragment_B(sdS);
Tensor tdQrK = threadMmadQ.partition_fragment_A(sKt);
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiledMmadQ, tdQrK, tdQrdS, tdQrdQ);
}
// warpgroup_wait<0>();
// Tensor dQ_tmp = make_tensor(tdQrdQ.data(), flash::convert_layout_acc_rowcol(tdQrdQ.layout()));
// if (cute::thread0()) { print_tensor(dQ_tmp); printf("\n"); }
warpgroup_wait<0>();
// if (cute::thread0()) { print(tdQrdQ); printf("\n"); print(tdQgdQaccum); printf("\n"); }
Tensor tdQrdQ_atomic = recast<float4>(tdQrdQ);
Tensor tdQgdQaccum_atomic = recast<float4>(tdQgdQaccum(_, _, _, m_block));
#pragma unroll
for (int i = 0; i < size(tdQrdQ_atomic); ++i) { atomicAdd(&tdQgdQaccum_atomic(i), tdQrdQ_atomic(i)); }
// for (int i = 0; i < size(tdQrdQ_atomic); ++i) { tdQgdQaccum_atomic(i) = tdQrdQ_atomic(i); }
pipeline_do.consumer_release(smem_pipe_release_do); // release dO
++smem_pipe_release_do;
pipeline_q.consumer_release(smem_pipe_release_q); // release Q
++smem_pipe_release_q;
int const lane_predicate = cute::elect_one_sync();
int const warp_idx = cutlass::canonical_warp_idx_sync();
if (warp_idx == 0 && lane_predicate && m_block >= kStages) {
pipeline_q.producer_acquire(smem_pipe_write_q);
copy(tma_load_Q.with(*pipeline_q.producer_get_barrier(smem_pipe_write_q), mcast_mask_qdo), tQgQ(_, _, _, m_block - kStages), tQsQ(_, _, _, smem_pipe_write_q.index()));
++smem_pipe_write_q;
pipeline_do.producer_acquire(smem_pipe_write_do);
copy(tma_load_dO.with(*pipeline_do.producer_get_barrier(smem_pipe_write_do), mcast_mask_qdo), tdOgdO(_, _, _, m_block - kStages), tdOsdO(_, _, _, smem_pipe_write_do.index()));
++smem_pipe_write_do;
}
}
}
// Epilogue
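// Scale dK by the softmax scale (dV needs no scaling), convert both to the output element type,
// stage them through shared memory, and TMA-store the tiles to gmem.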
#pragma unroll
for (int i = 0; i < size(tdKrdK); ++i) { tdKrdK(i) *= params.scale_softmax; }
Tensor tdKrdK_out = convert_type<Element>(tdKrdK);
Tensor tdVrdV_out = convert_type<Element>(tdVrdV);
Tensor sdK = make_tensor(make_smem_ptr(shared_storage.smem_dk.data()), typename Ktraits::SmemLayoutdK{});
Tensor sdV = make_tensor(make_smem_ptr(shared_storage.smem_dv.data()), typename Ktraits::SmemLayoutdV{});
Tensor sdKt = make_tensor(make_smem_ptr(shared_storage.smem_dk.data()), typename Ktraits::SmemLayoutdKt{});
Tensor sdVt = make_tensor(make_smem_ptr(shared_storage.smem_dv.data()), typename Ktraits::SmemLayoutdVt{});
auto smem_tiled_copy_dKV = make_tiled_copy_C(typename Ktraits::SmemCopyAtomdKV{}, tiledMmadKV);
auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(threadIdx.x);
Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(tdKrdK_out); // ((Atom,AtomNum), MMA_M, MMA_N)
Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(tdVrdV_out); // ((Atom,AtomNum), MMA_M, MMA_N)
__syncthreads();
if constexpr (!dKV_swapAB) {
Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(sdK); // ((Atom,AtomNum),PIPE_M,PIPE_N)
Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(sdV); // ((Atom,AtomNum),PIPE_M,PIPE_N)
cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK);
cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV);
} else {
Tensor taccdKsdKt = smem_thr_copy_dKV.partition_D(sdKt); // ((Atom,AtomNum),PIPE_M,PIPE_N)
Tensor taccdVsdVt = smem_thr_copy_dKV.partition_D(sdVt); // ((Atom,AtomNum),PIPE_M,PIPE_N)
cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdKt);
cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdVt);
}
cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA
Tensor mdK = tma_store_dK.get_tma_tensor(make_shape(params.seqlen_k, params.d, params.h, params.b));
Tensor mdV = tma_store_dV.get_tma_tensor(make_shape(params.seqlen_k, params.d, params.h, params.b));
Tensor gdK = local_tile(mdK(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
Tensor gdV = local_tile(mdV(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
auto block_tma_dK = tma_store_dK.get_slice(_0{});
auto block_tma_dV = tma_store_dV.get_slice(_0{});
Tensor tdKgdK = block_tma_dK.partition_D(gdK); // (TMA, TMA_M, TMA_K)
Tensor tdKsdK = block_tma_dK.partition_S(sdK); // (TMA, TMA_M, TMA_K)
Tensor tdVgdV = block_tma_dV.partition_D(gdV); // (TMA, TMA_M, TMA_K)
Tensor tdVsdV = block_tma_dV.partition_S(sdV); // (TMA, TMA_M, TMA_K)
__syncthreads(); // ensure all threads have issued their async fence
lane_predicate = cute::elect_one_sync();
warp_idx = cutlass::canonical_warp_idx_sync();
if (warp_idx == 0 && lane_predicate) {
cute::copy(tma_store_dV, tdVsdV, tdVgdV);
cute::copy(tma_store_dK, tdKsdK, tdKgdK);
tma_store_arrive();
}
tma_store_wait<0>();
// To make sure remote SMEM doesn't get destroyed
if constexpr (size(ClusterShape{}) > 1) {
cute::cluster_arrive();
cute::cluster_wait();
}
}
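// Backward kernel parallelized over query tiles: each CTA owns one m_block, keeps dQ in registers
// across the whole loop over n_blocks, and accumulates dK/dV into fp32 gmem buffers with atomicAdd.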
template <typename Ktraits, bool Is_causal, typename TiledCopyQ, typename TiledCopydO,
typename TiledCopyK, typename TiledCopyV, typename TiledCopydQ, typename TiledCopydK, typename TiledCopydV>
__global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1)
compute_dqkv_seqqpar(CUTE_GRID_CONSTANT Flash_bwd_params const params,
CUTE_GRID_CONSTANT TiledCopyQ const tma_load_Q,
CUTE_GRID_CONSTANT TiledCopydO const tma_load_dO,
CUTE_GRID_CONSTANT TiledCopyK const tma_load_K,
CUTE_GRID_CONSTANT TiledCopyV const tma_load_V,
CUTE_GRID_CONSTANT TiledCopydQ const tma_store_dQ,
CUTE_GRID_CONSTANT TiledCopydK const tma_store_dK,
CUTE_GRID_CONSTANT TiledCopydV const tma_store_dV) {
using Element = typename Ktraits::Element;
using ElementAccum = typename Ktraits::ElementAccum;
using SoftType = ElementAccum;
using index_t = typename Ktraits::index_t;
using TileShape_MNK = typename Ktraits::TileShape_MNK;
using ClusterShape = typename Ktraits::ClusterShape_MNK;
static constexpr int kNThreads = Ktraits::kNThreads;
static constexpr int NumMmaThreads = Ktraits::kNThreads;
static constexpr int kBlockM = Ktraits::kBlockM;
static constexpr int kBlockN = Ktraits::kBlockN;
// constexpr int kHeadDim = Ktraits::kHeadDim;
static constexpr int kStages = Ktraits::kStages;
static constexpr bool SdP_swapAB = Ktraits::SdP_swapAB;
static constexpr bool dKV_swapAB = Ktraits::dKV_swapAB;
static constexpr bool dQ_swapAB = Ktraits::dQ_swapAB;
static constexpr bool Mma_dQ_is_RS = Ktraits::Mma_dQ_is_RS;
if constexpr (dQ_swapAB) { static_assert(!Mma_dQ_is_RS); }
using MainloopPipeline = typename Ktraits::MainloopPipeline;
using PipelineParams = typename MainloopPipeline::Params;
using PipelineState = typename MainloopPipeline::PipelineState;
extern __shared__ char shared_memory[];
auto &shared_storage = *reinterpret_cast<typename Ktraits::SharedStorage*>(shared_memory);
int const m_block = blockIdx.x;
int const bidb = blockIdx.z; // The block index for the batch.
int const bidh = blockIdx.y; // The block index for the head.
int lane_predicate = cute::elect_one_sync();
int warp_idx = cutlass::canonical_warp_idx_sync();
// Issue Tma Descriptor Prefetch from a single thread
if (warp_idx == 0 && lane_predicate) {
cute::prefetch_tma_descriptor(tma_load_Q.get_tma_descriptor());
cute::prefetch_tma_descriptor(tma_load_dO.get_tma_descriptor());
cute::prefetch_tma_descriptor(tma_load_K.get_tma_descriptor());
cute::prefetch_tma_descriptor(tma_load_V.get_tma_descriptor());
cute::prefetch_tma_descriptor(tma_store_dQ.get_tma_descriptor());
}
Tensor mQ = tma_load_Q.get_tma_tensor(make_shape(params.seqlen_q, params.d, params.h, params.b));
Tensor mdO = tma_load_dO.get_tma_tensor(make_shape(params.seqlen_q, params.d, params.h, params.b));
Tensor mK = tma_load_K.get_tma_tensor(make_shape(params.seqlen_k, params.d, params.h, params.b));
Tensor mV = tma_load_V.get_tma_tensor(make_shape(params.seqlen_k, params.d, params.h, params.b));
Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr)),
make_shape(params.b, params.h, params.seqlen_q),
make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{}));
Tensor mdPsum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.dsoftmax_sum)),
make_shape(params.b, params.h, params.seqlen_q),
make_stride(params.h * params.seqlen_q_rounded, params.seqlen_q_rounded, _1{}));
Tensor mdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.dk_accum_ptr)),
make_shape(params.seqlen_k, params.d, params.h, params.b),
make_stride(params.d * params.h, _1{}, params.d, params.d * params.h * params.seqlen_k));
Tensor mdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.dv_accum_ptr)),
make_shape(params.seqlen_k, params.d, params.h, params.b),
make_stride(params.d * params.h, _1{}, params.d, params.d * params.h * params.seqlen_k));
Tensor gQ = local_tile(mQ(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K)
Tensor gdO = local_tile(mdO(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K)
Tensor gK = local_tile(mK(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _)
Tensor gV = local_tile(mV(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _)
Tensor gdKaccum = local_tile(mdKaccum(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _)
Tensor gdVaccum = local_tile(mdVaccum(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _)
typename Ktraits::GmemTiledCopydQaccum gmem_tiled_copy_dKVaccum;
auto gmem_thr_copy_dKVaccum = gmem_tiled_copy_dKVaccum.get_thread_slice(threadIdx.x);
Tensor tdKgdKaccum = gmem_thr_copy_dKVaccum.partition_D(gdKaccum);
Tensor tdVgdVaccum = gmem_thr_copy_dKVaccum.partition_D(gdVaccum);
// Construct SMEM tensors.
Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Ktraits::SmemLayoutQ{});
Tensor sdO = make_tensor(make_smem_ptr(shared_storage.smem_do.data()), typename Ktraits::SmemLayoutdO{});
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Ktraits::SmemLayoutK{});
Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), typename Ktraits::SmemLayoutV{});
Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Ktraits::SmemLayoutP{});
Tensor sdS = make_tensor(make_smem_ptr(shared_storage.smem_ds.data()), typename Ktraits::SmemLayoutdS{});
Tensor sQt = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Ktraits::SmemLayoutQt{});
Tensor sdOt = make_tensor(make_smem_ptr(shared_storage.smem_do.data()), typename Ktraits::SmemLayoutdOt{});
Tensor sKt = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Ktraits::SmemLayoutKt{});
Tensor sPt = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Ktraits::SmemLayoutPt{});
Tensor sdSt = make_tensor(make_smem_ptr(shared_storage.smem_ds.data()), typename Ktraits::SmemLayoutdSt{});
// Prepare the TMA loads
uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
auto block_tma_Q = tma_load_Q.get_slice(_0{});
auto block_tma_dO = tma_load_dO.get_slice(_0{});
auto block_tma_K = tma_load_K.get_slice(cluster_local_block_id.x);
auto block_tma_V = tma_load_V.get_slice(cluster_local_block_id.x);
Tensor tQgQ = block_tma_Q.partition_S(gQ); // (TMA, TMA_M, TMA_K)
Tensor tQsQ = block_tma_Q.partition_D(sQ); // (TMA, TMA_M, TMA_K)
Tensor tdOgdO = block_tma_dO.partition_S(gdO); // (TMA, TMA_M, TMA_K)
Tensor tdOsdO = block_tma_dO.partition_D(sdO); // (TMA, TMA_M, TMA_K)
Tensor tKgK = block_tma_K.partition_S(gK); // (TMA, TMA_N, TMA_K, k)
Tensor tKsK = block_tma_K.partition_D(sK); // (TMA, TMA_N, TMA_K, PIPE)
Tensor tVgV = block_tma_V.partition_S(gV); // (TMA, TMA_N, TMA_K, k)
Tensor tVsV = block_tma_V.partition_D(sV); // (TMA, TMA_N, TMA_K, PIPE)
// Set the bytes transferred in this TMA transaction (may involve multiple issues)
constexpr uint32_t TmaTransactionBytesQ = static_cast<uint32_t>(size<0>(sQ) * size<1>(sQ) * cutlass::sizeof_bits_v<Element> / 8);
constexpr uint32_t TmaTransactionBytesdO = static_cast<uint32_t>(size<0>(sdO) * size<1>(sdO) * cutlass::sizeof_bits_v<Element> / 8);
static_assert(TmaTransactionBytesQ == TmaTransactionBytesdO);
constexpr uint32_t TmaTransactionBytesK = static_cast<uint32_t>(size<0>(sK) * size<1>(sK) * cutlass::sizeof_bits_v<Element> / 8);
constexpr uint32_t TmaTransactionBytesV = static_cast<uint32_t>(size<0>(sV) * size<1>(sV) * cutlass::sizeof_bits_v<Element> / 8);
static_assert(TmaTransactionBytesK == TmaTransactionBytesV);
// Obtain thread index within the warp group
int thread_idx = int(threadIdx.x);
int warp_group_thread_idx = thread_idx % cutlass::NumThreadsPerWarpGroup;
// int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
PipelineParams pipeline_params;
pipeline_params.transaction_bytes = TmaTransactionBytesK;
pipeline_params.role = MainloopPipeline::ThreadCategory::ProducerConsumer;
pipeline_params.is_leader = warp_group_thread_idx == 0;
pipeline_params.num_consumers = NumMmaThreads;
if (warp_idx == 0 && lane_predicate) {
shared_storage.barrier_Q.init(1 /*numThreads*/);
shared_storage.barrier_dO.init(1 /*numThreads*/);
}
// cutlass::arch::fence_barrier_init();
// We're counting on pipeline_k to call fence_barrier_init();
MainloopPipeline pipeline_k(shared_storage.pipeline_k, pipeline_params, ClusterShape{});
MainloopPipeline pipeline_v(shared_storage.pipeline_v, pipeline_params, ClusterShape{});
// We need this to guarantee that the Pipeline init is visible
// To all producers and consumer blocks in the Cluster
if constexpr (size(ClusterShape{}) > 1) {
cute::cluster_arrive_relaxed();
cute::cluster_wait();
} else {
__syncthreads();
}
// State variables used for iterating the circular buffer
// smem_pipe_read / release is used by the consumer of SMEM data - i.e. MMA
// smem_pipe_write is used by the producer of SMEM data - i.e. TMA
PipelineState smem_pipe_read_k, smem_pipe_read_v;
PipelineState smem_pipe_release_k, smem_pipe_release_v;
PipelineState smem_pipe_write_k = cutlass::make_producer_start_state<MainloopPipeline>();
PipelineState smem_pipe_write_v = cutlass::make_producer_start_state<MainloopPipeline>();
// Copy K tile and V tile from GMEM to SMEM.
if (warp_idx == 0 && lane_predicate) {
shared_storage.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ);
copy(tma_load_Q.with(reinterpret_cast<cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.barrier_Q), 0 /*mcast_mask*/), tQgQ, tQsQ);
shared_storage.barrier_dO.arrive_and_expect_tx(TmaTransactionBytesdO);
copy(tma_load_dO.with(reinterpret_cast<cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.barrier_dO), 0 /*mcast_mask*/), tdOgdO, tdOsdO);
}
// if (cute::thread0()) { print_tensor(sQ); printf("\n"); } __syncthreads();
int n_block = cute::ceil_div(params.seqlen_k, kBlockN) - 1;
uint16_t mcast_mask_kv = 0;
if constexpr (cute::is_same_v<typename Ktraits::GmemTiledCopyKV, SM90_TMA_LOAD_MULTICAST>) {
auto block_layout = Layout<ClusterShape>{}; // (m,n) -> block_id
for (int m = 0; m < size<0>(block_layout); ++m) {
mcast_mask_kv |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, _0{}));
}
}
// Issue TmaLoads (Prologue fetches)
if (warp_idx == 0 && lane_predicate) {
// Issue the prologue loads
CUTLASS_PRAGMA_UNROLL
for (int stage = 0; stage < kStages && stage <= n_block; ++stage) {
pipeline_k.producer_acquire(smem_pipe_write_k);
copy(tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv), tKgK(_, _, _, n_block - stage), tKsK(_, _, _, stage));
++smem_pipe_write_k;
pipeline_v.producer_acquire(smem_pipe_write_v);
copy(tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv), tVgV(_, _, _, n_block - stage), tVsV(_, _, _, stage));
++smem_pipe_write_v;
}
}
Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape<Int<kBlockM>>{}, make_coord(m_block));
Tensor gdPsum = local_tile(mdPsum(bidb, bidh, _), Shape<Int<kBlockM>>{}, make_coord(m_block));
// Initialize matmul objects.
typename Ktraits::TiledMmaSdP tiledMmaSdP;
auto threadMmaSdP = tiledMmaSdP.get_thread_slice(threadIdx.x);
typename Ktraits::TiledMmadKV tiledMmadKV;
auto threadMmadKV = tiledMmadKV.get_thread_slice(threadIdx.x);
typename Ktraits::TiledMmadQ tiledMmadQ;
auto threadMmadQ = tiledMmadQ.get_thread_slice(threadIdx.x);
// Allocate accumulator
Tensor tdQrdQ = partition_fragment_C(tiledMmadQ, select<!dQ_swapAB ? 0 : 2, !dQ_swapAB ? 2 : 0>(TileShape_MNK{}));
clear(tdQrdQ);
auto smem_tiled_copy_PdS = make_tiled_copy_C(typename Ktraits::SmemCopyAtomPdS{}, tiledMmaSdP);
auto smem_thr_copy_PdS = smem_tiled_copy_PdS.get_thread_slice(threadIdx.x);
if constexpr (!SdP_swapAB) {
Tensor tPsP = smem_thr_copy_PdS.partition_D(sP); // ((Atom,AtomNum),PIPE_M,PIPE_N)
Tensor tdSsdS = smem_thr_copy_PdS.partition_D(sdS); // ((Atom,AtomNum),PIPE_M,PIPE_N)
// Allocate "fragments/descriptors"
Tensor tSrQ = threadMmaSdP.partition_fragment_A(sQ);
Tensor tSrK = threadMmaSdP.partition_fragment_B(sK);
Tensor tdPrdO = threadMmaSdP.partition_fragment_A(sdO);
Tensor tdPrV = threadMmaSdP.partition_fragment_B(sV);
Tensor caccS = make_identity_tensor(select<0, 1>(TileShape_MNK{})); // (BLK_M,BLK_N) -> (blk_m,blk_n)
Tensor taccScS = threadMmaSdP.partition_C(caccS); // (MMA,MMA_M,MMA_N)
static_assert(decltype(size<0, 0>(taccScS))::value == 2);
static_assert(decltype(size<0, 1>(taccScS))::value == 2);
// taccScS has shape ((2, 2, V), MMA_M, MMA_N); we only take the row indices.
Tensor taccScS_row = taccScS(make_coord(_0{}, _, _0{}), _, _0{});
Tensor lse = make_tensor<ElementAccum>(Shape<Int<decltype(size(taccScS_row))::value>>{});
Tensor dP_sum = make_fragment_like(lse);
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) {
const int row = get<0>(taccScS_row(mi));
lse(mi) = row < params.seqlen_q - m_block * kBlockM ? gLSE(row) : INFINITY;
dP_sum(mi) = row < params.seqlen_q - m_block * kBlockM ? gdPsum(row) : 0;
}
// if (cute::thread0()) { print_tensor(lse); printf("\n"); }
// if (cute::thread0()) { print_tensor(dP_sum); printf("\n"); }
// We want LSE = inf if the row is OOB. In that case Q would be zero, K would be zero,
// and scores would be zero. With LSE = 0, probs will be all 1's, and when we multiply
// with V (which would be zero), we're fine. However, with ALiBi, we might modify these
// scores, and probs can become NaN. If instead we set LSE = inf for OOB rows, probs are always 0.
shared_storage.barrier_Q.wait(0);
shared_storage.barrier_dO.wait(0);
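// Main loop over K/V tiles, from the last n_block down to 0: recompute P and dP, keep accumulating
// dQ in registers, and atomically add this tile's dK/dV contributions to the gmem accumulators.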
// #pragma unroll 2
CUTLASS_PRAGMA_NO_UNROLL
for (; n_block >= 0; --n_block) {
// Otherwise we might have WG0 still waiting on the NamedBarrier while WG1 has already
// started the next iteration and begun flipping the same NamedBarrier.
__syncthreads();
Tensor tSrS = partition_fragment_C(tiledMmaSdP, select<0, 1>(TileShape_MNK{}));
pipeline_k.consumer_wait(smem_pipe_read_k);
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiledMmaSdP, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
Tensor tdPrdP = partition_fragment_C(tiledMmaSdP, select<0, 1>(TileShape_MNK{}));
pipeline_v.consumer_wait(smem_pipe_read_v);
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiledMmaSdP, tdPrdO, tdPrV(_, _, _, smem_pipe_read_v.index()), tdPrdP);
++smem_pipe_read_v;
warpgroup_wait<1>();
// Reshape tSrS from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
Tensor scores = make_tensor(tSrS.data(), flash::convert_layout_acc_rowcol(tSrS.layout()));
flash::scale_apply_exp2</*Scale=*/true, /*scale_max=*/false>(scores, lse, params.scale_softmax_log2);
// if (cute::thread0()) { print_tensor(scores); printf("\n"); }
// Convert scores from fp32 to fp16/bf16
Tensor rP = flash::convert_type<Element>(tSrS);
Tensor tPaP = smem_thr_copy_PdS.retile_S(rP); // ((Atom,AtomNum), MMA_N, MMA_N)
cute::copy(smem_tiled_copy_PdS, tPaP, tPsP);
int const warp_group_idx = cutlass::canonical_warp_group_idx();
cutlass::arch::NamedBarrier::arrive(kNThreads, warp_group_idx /*id*/);
warpgroup_wait<0>();
// Reshape tdPrdP from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
Tensor dS = make_tensor(tdPrdP.data(), scores.layout());
// if (cute::thread0()) { print_tensor(dS); printf("\n"); }
#pragma unroll
for (int mi = 0; mi < size<0>(dS); ++mi) {
#pragma unroll
for (int ni = 0; ni < size<1>(dS); ++ni) { dS(mi, ni) = scores(mi, ni) * (dS(mi, ni) - dP_sum(mi)); }
}
Tensor rdS = flash::convert_type<Element>(tdPrdP);
Tensor tdSadS = smem_thr_copy_PdS.retile_S(rdS); // ((Atom,AtomNum), MMA_N, MMA_N)
cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS);
cutlass::arch::NamedBarrier::arrive(kNThreads, 2 + warp_group_idx /*id*/);
// if (cute::thread0()) { print_tensor(dS); printf("\n"); }
if constexpr (Mma_dQ_is_RS) {
static_assert(!dQ_swapAB);
Tensor tdQrdS = make_tensor(rdS.data(), convert_layout_acc_Aregs<typename Ktraits::TiledMmadQ>(tdPrdP.layout()));
Tensor tdQrK = threadMmadQ.partition_fragment_B(sKt);
flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiledMmadQ, tdQrdS, tdQrK(_, _, _, smem_pipe_read_k.index()), tdQrdQ);
// if (cute::thread0()) { print(tdQrdS); printf("\n"); print(tdQrK); printf("\n"); print(tdQrdQ); printf("\n"); }
}
// warpgroup_wait<0>();
// Tensor dQ_tmp = make_tensor(tdQrdQ.data(), flash::convert_layout_acc_rowcol(tdQrdQ.layout()));
// if (cute::thread0()) { print_tensor(dQ_tmp); printf("\n"); }
// if (cute::thread0()) { print_tensor(sK); printf("\n"); }
// if (cute::thread0()) { print_tensor(sKt); printf("\n"); } __syncthreads();
// if (cute::thread0()) { printf("before barrier sync 0\n"); }
// SMEM fence to make sure sP is written before it's read by WGMMA
cutlass::arch::fence_view_async_shared();
cutlass::arch::NamedBarrier::sync(kNThreads, 1 - warp_group_idx /*id*/);
// if (cute::thread0()) { printf("after barrier sync 0\n"); }
Tensor tdVrdV = partition_fragment_C(tiledMmadKV, select<!dKV_swapAB ? 1 : 2, !dKV_swapAB? 2 : 1>(TileShape_MNK{}));
if constexpr (!dKV_swapAB) {
Tensor tdVrP = threadMmadKV.partition_fragment_A(sPt);
Tensor tdVrdO = threadMmadKV.partition_fragment_B(sdOt);
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiledMmadKV, tdVrP, tdVrdO, tdVrdV);
} else {
Tensor tdVrP = threadMmadKV.partition_fragment_B(sPt);
Tensor tdVrdO = threadMmadKV.partition_fragment_A(sdOt);
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiledMmadKV, tdVrdO, tdVrP, tdVrdV);
}
// warpgroup_wait<0>();
// Tensor dV_tmp = make_tensor(tdVrdV.data(), flash::convert_layout_acc_rowcol(tdVrdV.layout()));
// if (cute::thread0()) { print_tensor(dV_tmp); printf("\n"); }
cutlass::arch::fence_view_async_shared();
cutlass::arch::NamedBarrier::sync(kNThreads, 2 + 1 - warp_group_idx /*id*/);
if constexpr (!Mma_dQ_is_RS) {
if constexpr (!dQ_swapAB) {
Tensor tdQrdS = threadMmadQ.partition_fragment_A(sdS);
Tensor tdQrK = threadMmadQ.partition_fragment_B(sKt);
flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiledMmadQ, tdQrdS, tdQrK(_, _, _, smem_pipe_read_k.index()), tdQrdQ);
} else {
Tensor tdQrdS = threadMmadQ.partition_fragment_B(sdS);
Tensor tdQrK = threadMmadQ.partition_fragment_A(sKt);
flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiledMmadQ, tdQrK(_, _, _, smem_pipe_read_k.index()), tdQrdS, tdQrdQ);
}
}
++smem_pipe_read_k;
// warpgroup_wait<0>();
// Tensor dQ_tmp = make_tensor(tdQrdQ.data(), flash::convert_layout_acc_rowcol(tdQrdQ.layout()));
// if (cute::thread0()) { print_tensor(dQ_tmp); printf("\n"); }
// if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(dQ_tmp); printf("\n"); }
Tensor tdKrdK = partition_fragment_C(tiledMmadKV, select<!dKV_swapAB ? 1 : 2, !dKV_swapAB? 2 : 1>(TileShape_MNK{}));
if constexpr (!dKV_swapAB) {
Tensor tdKrdS = threadMmadKV.partition_fragment_A(sdSt);
Tensor tdKrQ = threadMmadKV.partition_fragment_B(sQt);
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiledMmadKV, tdKrdS, tdKrQ, tdKrdK);
} else {
Tensor tdKrdS = threadMmadKV.partition_fragment_B(sdSt);
Tensor tdKrQ = threadMmadKV.partition_fragment_A(sQt);
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiledMmadKV, tdKrQ, tdKrdS, tdKrdK);
}
// warpgroup_wait<0>();
// Tensor dK_tmp = make_tensor(tdKrdK.data(), flash::convert_layout_acc_rowcol(tdKrdK.layout()));
// if (cute::thread0()) { print_tensor(dK_tmp); printf("\n"); }
warpgroup_wait<Mma_dQ_is_RS ? 1 : 2>();
// if (cute::thread0()) { print(tdQrdQ); printf("\n"); print(tdQgdQaccum); printf("\n"); }
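// dV (and, below, dK) partial results go to fp32 gmem accumulators via atomicAdd, since every
// m_block CTA contributes to the same K/V rows.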
Tensor tdVrdV_atomic = recast<float4>(tdVrdV);
Tensor tdVgdVaccum_atomic = recast<float4>(tdVgdVaccum(_, _, _, n_block));
#pragma unroll
for (int i = 0; i < size(tdVrdV_atomic); ++i) { atomicAdd(&tdVgdVaccum_atomic(i), tdVrdV_atomic(i)); }
// for (int i = 0; i < size(tdVrdV_atomic); ++i) { tdVgdVaccum_atomic(i) = tdVrdV_atomic(i); }
warpgroup_wait<0>();
Tensor tdKrdK_atomic = recast<float4>(tdKrdK);
Tensor tdKgdKaccum_atomic = recast<float4>(tdKgdKaccum(_, _, _, n_block));
#pragma unroll
for (int i = 0; i < size(tdKrdK_atomic); ++i) { atomicAdd(&tdKgdKaccum_atomic(i), tdKrdK_atomic(i)); }
pipeline_v.consumer_release(smem_pipe_release_v); // release V
++smem_pipe_release_v;
pipeline_k.consumer_release(smem_pipe_release_k); // release K
++smem_pipe_release_k;
int const lane_predicate = cute::elect_one_sync();
int const warp_idx = cutlass::canonical_warp_idx_sync();
if (warp_idx == 0 && lane_predicate && n_block >= kStages) {
pipeline_k.producer_acquire(smem_pipe_write_k);
copy(tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv), tKgK(_, _, _, n_block - kStages), tKsK(_, _, _, smem_pipe_write_k.index()));
++smem_pipe_write_k;
pipeline_v.producer_acquire(smem_pipe_write_v);
copy(tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv), tVgV(_, _, _, n_block - kStages), tVsV(_, _, _, smem_pipe_write_v.index()));
++smem_pipe_write_v;
}
}
} else { // SdP_swapAB
Tensor tdSsdS = smem_thr_copy_PdS.partition_D(sdSt); // ((Atom,AtomNum),PIPE_M,PIPE_N)
// Allocate "fragments/descriptors"
Tensor tSrQ = threadMmaSdP.partition_fragment_B(sQ);
Tensor tSrK = threadMmaSdP.partition_fragment_A(sK);
Tensor tdPrdO = threadMmaSdP.partition_fragment_B(sdO);
Tensor tdPrV = threadMmaSdP.partition_fragment_A(sV);
Tensor caccS = make_identity_tensor(select<1, 0>(TileShape_MNK{})); // (BLK_M,BLK_N) -> (blk_m,blk_n)
Tensor taccScS = threadMmaSdP.partition_C(caccS); // (MMA,MMA_N,MMA_M)
static_assert(decltype(size<0, 0>(taccScS))::value == 2);
static_assert(decltype(size<0, 1>(taccScS))::value == 2);
// taccScS has shape ((2, 2, V), MMA_N, MMA_M); we only take the row indices.
Tensor taccScS_row = taccScS(make_coord(_, _0{}, _), _0{}, _);
Tensor lse = make_tensor<ElementAccum>(Shape<Int<decltype(size(taccScS_row))::value>>{});
Tensor dP_sum = make_fragment_like(lse);
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) {
const int row = get<1>(taccScS_row(mi));
lse(mi) = row < params.seqlen_q - m_block * kBlockM ? gLSE(row) : INFINITY;
dP_sum(mi) = row < params.seqlen_q - m_block * kBlockM ? gdPsum(row) : 0;
}
// if (cute::thread0()) { print_tensor(taccScS_row); printf("\n"); }
// cute::fill(lse, 1);
// cute::fill(dP_sum, 1);
// if (cute::thread0()) { print_tensor(dP_sum); printf("\n"); }
// We want LSE = inf if the row is OOB. In that case Q would be zero, K would be zero,
// and scores would be zero. With LSE = 0, probs will be all 1's, and when we multiply
// with V (which would be zero), we're fine. However, with ALiBi, we might modify these
// scores, and probs can become NaN. If instead we set LSE = inf for OOB rows, probs are always 0.
clear(tdQrdQ);
shared_storage.barrier_Q.wait(0);
shared_storage.barrier_dO.wait(0);
// #pragma unroll 2
CUTLASS_PRAGMA_NO_UNROLL
for (; n_block >= 0; --n_block) {
Tensor tSrS = partition_fragment_C(tiledMmaSdP, select<1, 0>(TileShape_MNK{}));
pipeline_k.consumer_wait(smem_pipe_read_k);
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiledMmaSdP, tSrK(_, _, _, smem_pipe_read_k.index()), tSrQ, tSrS);
Tensor tdPrdP = partition_fragment_C(tiledMmaSdP, select<1, 0>(TileShape_MNK{}));
pipeline_v.consumer_wait(smem_pipe_read_v);
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiledMmaSdP, tdPrV(_, _, _, smem_pipe_read_v.index()), tdPrdO, tdPrdP);
++smem_pipe_read_v;
warpgroup_wait<1>();
// Reshape tSrS from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N))
Tensor scores = make_tensor(tSrS.data(), flash::convert_layout_acc_transposed_rowcol(tSrS.layout()));
// if (cute::thread0()) { print_tensor(lse); printf("\n"); }
flash::scale_apply_exp2</*Scale=*/true, /*scale_max=*/false>(scores, lse, params.scale_softmax_log2);
// if (cute::thread0()) { print_tensor(scores); printf("\n"); }
// Convert scores from fp32 to fp16/bf16
Tensor rP = flash::convert_type<Element>(tSrS);
static_assert(!dKV_swapAB);
Tensor tdVrdV = partition_fragment_C(tiledMmadKV, select<1, 2>(TileShape_MNK{}));
Tensor tdVrP = make_tensor(rP.data(), convert_layout_acc_Aregs<typename Ktraits::TiledMmadKV>(tSrS.layout()));
Tensor tdVrdO = threadMmadKV.partition_fragment_B(sdOt);
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiledMmadKV, tdVrP, tdVrdO, tdVrdV);
// warpgroup_wait<0>();
// Tensor dV_tmp = make_tensor(tdVrdV.data(), flash::convert_layout_acc_rowcol(tdVrdV.layout()));
// if (cute::thread0()) { print_tensor(dV_tmp); printf("\n"); }
warpgroup_wait<1>();
// Reshape tdPrdP from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N))
Tensor dS = make_tensor(tdPrdP.data(), scores.layout());
#pragma unroll
for (int mi = 0; mi < size<0>(dS); ++mi) {
#pragma unroll
for (int ni = 0; ni < size<1>(dS); ++ni) { dS(mi, ni) = scores(mi, ni) * (dS(mi, ni) - dP_sum(mi)); }
}
// if (cute::thread0()) { print_tensor(dS); printf("\n"); }
Tensor rdS = flash::convert_type<Element>(tdPrdP);
Tensor tdSadS = smem_thr_copy_PdS.retile_S(rdS); // ((Atom,AtomNum), MMA_N, MMA_N)
cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS);
Tensor tdKrdK = partition_fragment_C(tiledMmadKV, select<1, 2>(TileShape_MNK{}));
Tensor tdKrdS = make_tensor(rdS.data(), convert_layout_acc_Aregs<typename Ktraits::TiledMmadKV>(tdPrdP.layout()));
Tensor tdKrQ = threadMmadKV.partition_fragment_B(sQt);
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiledMmadKV, tdKrdS, tdKrQ, tdKrdK);
// warpgroup_wait<0>();
// Tensor dK_tmp = make_tensor(tdKrdK.data(), flash::convert_layout_acc_rowcol(tdKrdK.layout()));
// if (cute::thread0()) { print_tensor(dK_tmp); printf("\n"); }
warpgroup_wait<1>();
// if (cute::thread0()) { print(tdQrdQ); printf("\n"); print(tdQgdQaccum); printf("\n"); }
Tensor tdVrdV_atomic = recast<float4>(tdVrdV);
Tensor tdVgdVaccum_atomic = recast<float4>(tdVgdVaccum(_, _, _, n_block));
#pragma unroll
for (int i = 0; i < size(tdVrdV_atomic); ++i) { atomicAdd(&tdVgdVaccum_atomic(i), tdVrdV_atomic(i)); }
// for (int i = 0; i < size(tdVrdV_atomic); ++i) { tdVgdVaccum_atomic(i) = tdVrdV_atomic(i); }
// SMEM fence to make sure sP is written before it's read by WGMMA
cutlass::arch::fence_view_async_shared();
// cutlass::arch::NamedBarrier::sync(kNThreads, 0 /*id*/);
__syncthreads();
static_assert(!Mma_dQ_is_RS);
if constexpr (!dQ_swapAB) {
Tensor tdQrdS = threadMmadQ.partition_fragment_A(sdS);
Tensor tdQrK = threadMmadQ.partition_fragment_B(sKt);
flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiledMmadQ, tdQrdS, tdQrK(_, _, _, smem_pipe_read_k.index()), tdQrdQ);
} else {
Tensor tdQrdS = threadMmadQ.partition_fragment_B(sdS);
Tensor tdQrK = threadMmadQ.partition_fragment_A(sKt);
flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiledMmadQ, tdQrK(_, _, _, smem_pipe_read_k.index()), tdQrdS, tdQrdQ);
}
++smem_pipe_read_k;
// warpgroup_wait<0>();
// Tensor dQ_tmp = make_tensor(tdQrdQ.data(), flash::convert_layout_acc_rowcol(tdQrdQ.layout()));
// if (cute::thread0()) { print_tensor(dQ_tmp); printf("\n"); }
warpgroup_wait<1>();
// if (cute::thread0()) { print(tdQrdQ); printf("\n"); print(tdQgdQaccum); printf("\n"); }
Tensor tdKrdK_atomic = recast<float4>(tdKrdK);
Tensor tdKgdKaccum_atomic = recast<float4>(tdKgdKaccum(_, _, _, n_block));
#pragma unroll
for (int i = 0; i < size(tdKrdK_atomic); ++i) { atomicAdd(&tdKgdKaccum_atomic(i), tdKrdK_atomic(i)); }
// for (int i = 0; i < size(tdKrdK_atomic); ++i) { tdKgdKaccum_atomic(i) = tdKrdK_atomic(i); }
warpgroup_wait<0>();
pipeline_v.consumer_release(smem_pipe_release_v); // release V
++smem_pipe_release_v;
pipeline_k.consumer_release(smem_pipe_release_k); // release K
++smem_pipe_release_k;
int const lane_predicate = cute::elect_one_sync();
int const warp_idx = cutlass::canonical_warp_idx_sync();
if (warp_idx == 0 && lane_predicate && n_block >= kStages) {
pipeline_k.producer_acquire(smem_pipe_write_k);
copy(tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv), tKgK(_, _, _, n_block - kStages), tKsK(_, _, _, smem_pipe_write_k.index()));
++smem_pipe_write_k;
pipeline_v.producer_acquire(smem_pipe_write_v);
copy(tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv), tVgV(_, _, _, n_block - kStages), tVsV(_, _, _, smem_pipe_write_v.index()));
++smem_pipe_write_v;
}
}
}
// Epilogue
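// Scale dQ by the softmax scale, convert to the output element type, stage it through shared
// memory, and TMA-store the tile (dK/dV were already accumulated in gmem inside the loop).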
#pragma unroll
for (int i = 0; i < size(tdQrdQ); ++i) { tdQrdQ(i) *= params.scale_softmax; }
// if (cute::thread0()) { print_tensor(tdQrdQ); }
Tensor tdQrdQ_out = convert_type<Element>(tdQrdQ);
Tensor sdQ = make_tensor(make_smem_ptr(shared_storage.smem_dq.data()), typename Ktraits::SmemLayoutdQ{});
Tensor sdQt = make_tensor(make_smem_ptr(shared_storage.smem_dq.data()), typename Ktraits::SmemLayoutdQt{});
auto smem_tiled_copy_dQ = make_tiled_copy_C(typename Ktraits::SmemCopyAtomdQ{}, tiledMmadQ);
auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(threadIdx.x);
Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(tdQrdQ_out); // ((Atom,AtomNum), MMA_M, MMA_N)
__syncthreads();
if constexpr (!dQ_swapAB) {
Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ); // ((Atom,AtomNum),PIPE_M,PIPE_N)
cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ);
} else {
Tensor taccdQsdQt = smem_thr_copy_dQ.partition_D(sdQt); // ((Atom,AtomNum),PIPE_M,PIPE_N)
cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQt);
}
cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA
Tensor mdQ = tma_store_dQ.get_tma_tensor(make_shape(params.seqlen_q, params.d, params.h, params.b));
Tensor gdQ = local_tile(mdQ(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K)
auto block_tma_dQ = tma_store_dQ.get_slice(_0{});
Tensor tdQgdQ = block_tma_dQ.partition_D(gdQ); // (TMA, TMA_M, TMA_K)
Tensor tdQsdQ = block_tma_dQ.partition_S(sdQ); // (TMA, TMA_M, TMA_K)
__syncthreads(); // ensure all threads have issued their async fence
// if (cute::thread0()) { print_tensor(sdQ); }
lane_predicate = cute::elect_one_sync();
warp_idx = cutlass::canonical_warp_idx_sync();
if (warp_idx == 0 && lane_predicate) {
cute::copy(tma_store_dQ, tdQsdQ, tdQgdQ);
tma_store_arrive();
}
tma_store_wait<0>();
// To make sure remote SMEM doesn't get destroyed
if constexpr (size(ClusterShape{}) > 1) {
cute::cluster_arrive();
cute::cluster_wait();
}
}
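// Warp-specialized backward kernel: warp-group 0 is a dedicated producer that only issues TMA
// traffic, while the remaining warp-groups run the MMAs as consumers.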
template <typename Ktraits, bool Is_causal, typename TiledCopyQ, typename TiledCopydO,
typename TiledCopyK, typename TiledCopyV, typename TiledCopydK, typename TiledCopydV, typename TiledCopydQ, typename TiledCopyAdddQ>
__global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1)
compute_dqkv_ws(CUTE_GRID_CONSTANT Flash_bwd_params const params,
CUTE_GRID_CONSTANT TiledCopyQ const tma_load_Q,
CUTE_GRID_CONSTANT TiledCopydO const tma_load_dO,
CUTE_GRID_CONSTANT TiledCopyK const tma_load_K,
CUTE_GRID_CONSTANT TiledCopyV const tma_load_V,
CUTE_GRID_CONSTANT TiledCopydK const tma_store_dK,
CUTE_GRID_CONSTANT TiledCopydV const tma_store_dV,
CUTE_GRID_CONSTANT TiledCopydQ const tma_store_dQ,
CUTE_GRID_CONSTANT TiledCopyAdddQ const tma_reduce_add_dQ) {
using Element = typename Ktraits::Element;
using ElementAccum = typename Ktraits::ElementAccum;
using SoftType = ElementAccum;
using index_t = typename Ktraits::index_t;
using TileShape_MNK = typename Ktraits::TileShape_MNK;
using ClusterShape = typename Ktraits::ClusterShape_MNK;
static_assert(Ktraits::Is_WS);
// static constexpr int kNThreads = Ktraits::kNThreads;
// static constexpr int NumMmaThreads = size(typename Ktraits::TiledMmaSdP{});
static constexpr int NumMmaThreads = 256;
static constexpr int NumCopyThreads = cutlass::NumThreadsPerWarpGroup;
static constexpr int kBlockM = Ktraits::kBlockM;
static constexpr int kNThreadsdQ = Ktraits::kNThreadsdQ;
// static constexpr int kBlockN = Ktraits::kBlockN;
// constexpr int kHeadDim = Ktraits::kHeadDim;
// static constexpr int kStages = Ktraits::kStages;
static constexpr bool SdP_swapAB = Ktraits::SdP_swapAB;
static constexpr bool dKV_swapAB = Ktraits::dKV_swapAB;
static constexpr bool dQ_swapAB = Ktraits::dQ_swapAB;
if constexpr (SdP_swapAB) { static_assert(!dKV_swapAB); }
static constexpr bool Mma_dQ_is_RS = Ktraits::Mma_dQ_is_RS;
if constexpr (dQ_swapAB) { static_assert(!Mma_dQ_is_RS); }
using MainloopPipeline = typename Ktraits::MainloopPipeline;
using PipelineParams = typename MainloopPipeline::Params;
using PipelineState = typename MainloopPipeline::PipelineState;
extern __shared__ char shared_memory[];
auto &shared_storage = *reinterpret_cast<typename Ktraits::SharedStorage*>(shared_memory);
int lane_predicate = cute::elect_one_sync();
int warp_idx = cutlass::canonical_warp_idx_sync();
// Issue Tma Descriptor Prefetch from a single thread
if (warp_idx == 0 && lane_predicate) {
cute::prefetch_tma_descriptor(tma_load_Q.get_tma_descriptor());
cute::prefetch_tma_descriptor(tma_load_dO.get_tma_descriptor());
cute::prefetch_tma_descriptor(tma_load_K.get_tma_descriptor());
cute::prefetch_tma_descriptor(tma_load_V.get_tma_descriptor());
cute::prefetch_tma_descriptor(tma_store_dK.get_tma_descriptor());
cute::prefetch_tma_descriptor(tma_store_dV.get_tma_descriptor());
cute::prefetch_tma_descriptor(tma_store_dQ.get_tma_descriptor());
cute::prefetch_tma_descriptor(tma_reduce_add_dQ.get_tma_descriptor());
}
// Construct SMEM tensors.
Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Ktraits::SmemLayoutQ{});
Tensor sdO = make_tensor(make_smem_ptr(shared_storage.smem_do.data()), typename Ktraits::SmemLayoutdO{});
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Ktraits::SmemLayoutK{});
Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), typename Ktraits::SmemLayoutV{});
Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Ktraits::SmemLayoutP{});
Tensor sdS = make_tensor(make_smem_ptr(shared_storage.smem_ds.data()), typename Ktraits::SmemLayoutdS{});
Tensor sQt = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Ktraits::SmemLayoutQt{});
Tensor sdOt = make_tensor(make_smem_ptr(shared_storage.smem_do.data()), typename Ktraits::SmemLayoutdOt{});
Tensor sKt = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Ktraits::SmemLayoutKt{});
Tensor sPt = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Ktraits::SmemLayoutPt{});
Tensor sdSt = make_tensor(make_smem_ptr(shared_storage.smem_ds.data()), typename Ktraits::SmemLayoutdSt{});
Tensor sdQ = make_tensor(make_smem_ptr(shared_storage.smem_dqacc.data()), typename Ktraits::SmemLayoutdQacc{});
Tensor sdQ2 = make_tensor(make_smem_ptr(shared_storage.smem_dqacc.data()), typename Ktraits::SmemLayoutdQacc2{});
Tensor sdQt = make_tensor(make_smem_ptr(shared_storage.smem_dqacc.data()), typename Ktraits::SmemLayoutdQacct{});
// Set the bytes transferred in this TMA transaction (may involve multiple issues)
constexpr uint32_t TmaTransactionBytesQ = static_cast<uint32_t>(size<0>(sQ) * size<1>(sQ) * cutlass::sizeof_bits_v<Element> / 8);
constexpr uint32_t TmaTransactionBytesdO = static_cast<uint32_t>(size<0>(sdO) * size<1>(sdO) * cutlass::sizeof_bits_v<Element> / 8);
static_assert(TmaTransactionBytesQ == TmaTransactionBytesdO);
constexpr uint32_t TmaTransactionBytesK = static_cast<uint32_t>(size<0>(sK) * size<1>(sK) * cutlass::sizeof_bits_v<Element> / 8);
constexpr uint32_t TmaTransactionBytesV = static_cast<uint32_t>(size<0>(sV) * size<1>(sV) * cutlass::sizeof_bits_v<Element> / 8);
static_assert(TmaTransactionBytesK == TmaTransactionBytesV);
// Obtain thread index within the warp group
int thread_idx = int(threadIdx.x);
int warp_group_thread_idx = thread_idx % cutlass::NumThreadsPerWarpGroup;
// int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
PipelineParams pipeline_params;
pipeline_params.transaction_bytes = TmaTransactionBytesQ;
int warp_group_idx = cutlass::canonical_warp_group_idx();
if (warp_group_idx == 0) {
pipeline_params.role = MainloopPipeline::ThreadCategory::Producer;
} else {
pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer;
}
pipeline_params.is_leader = warp_group_thread_idx == 0;
pipeline_params.num_consumers = NumMmaThreads;
if (warp_idx == 0 && lane_predicate) {
shared_storage.barrier_K.init(1 /*numThreads*/);
shared_storage.barrier_V.init(1 /*numThreads*/);
}
// cutlass::arch::fence_barrier_init();
// We're counting on pipeline_q to call fence_barrier_init();
MainloopPipeline pipeline_q(shared_storage.pipeline_q, pipeline_params, ClusterShape{});
MainloopPipeline pipeline_do(shared_storage.pipeline_do, pipeline_params, ClusterShape{});
// We need this to guarantee that the Pipeline init is visible
// To all producers and consumer blocks in the Cluster
if constexpr (size(ClusterShape{}) > 1) {
cute::cluster_arrive_relaxed();
cute::cluster_wait();
} else {
__syncthreads();
}
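// Producer warp-group: warp 0 loads K and streams Q/dO tiles through the pipelines, warp 1 loads V
// and handles the dQ write-back to gmem; it frees most of its registers since it only issues copies.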
if (warp_group_idx == 0) { // Producer
// method in cutlass/arch/reg_reconfig.h
// calls setmaxnreg.dec.sync.aligned.u32
cutlass::arch::warpgroup_reg_dealloc<24>();
int const n_block = blockIdx.x;
int const bidb = blockIdx.z; // The block index for the batch.
int const bidh = blockIdx.y; // The block index for the head.
int m_block = cute::ceil_div(params.seqlen_q, kBlockM) - 1;
int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
int lane_predicate = cute::elect_one_sync();
// if (warp_idx_in_warpgroup == 0 && lane_predicate) {
if (warp_idx_in_warpgroup == 0) { // Load K, and do TMA on Q and dO
Tensor mQ = tma_load_Q.get_tma_tensor(make_shape(params.seqlen_q, params.d, params.h, params.b));
Tensor mdO = tma_load_dO.get_tma_tensor(make_shape(params.seqlen_q, params.d, params.h, params.b));
Tensor mK = tma_load_K.get_tma_tensor(make_shape(params.seqlen_k, params.d, params.h, params.b));
Tensor gQ = local_tile(mQ(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (M, K, _)
Tensor gdO = local_tile(mdO(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (M, K, _)
Tensor gK = local_tile(mK(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (N, K)
// Prepare the TMA loads
uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
auto block_tma_Q = tma_load_Q.get_slice(cluster_local_block_id.y);
auto block_tma_dO = tma_load_dO.get_slice(cluster_local_block_id.y);
auto block_tma_K = tma_load_K.get_slice(_0{});
Tensor tQgQ = block_tma_Q.partition_S(gQ); // (TMA, TMA_M, TMA_K, k)
Tensor tQsQ = block_tma_Q.partition_D(sQ); // (TMA, TMA_M, TMA_K, PIPE)
Tensor tdOgdO = block_tma_dO.partition_S(gdO); // (TMA, TMA_M, TMA_K, k)
Tensor tdOsdO = block_tma_dO.partition_D(sdO); // (TMA, TMA_M, TMA_K, PIPE)
Tensor tKgK = block_tma_K.partition_S(gK); // (TMA, TMA_N, TMA_K)
Tensor tKsK = block_tma_K.partition_D(sK); // (TMA, TMA_N, TMA_K)
PipelineState smem_pipe_write_q = cutlass::make_producer_start_state<MainloopPipeline>();
PipelineState smem_pipe_write_do = cutlass::make_producer_start_state<MainloopPipeline>();
uint16_t mcast_mask_qdo = 0;
if constexpr (cute::is_same_v<typename Ktraits::GmemTiledCopyQdO, SM90_TMA_LOAD_MULTICAST>) {
auto block_layout = Layout<ClusterShape>{}; // (m,n) -> block_id
for (int n = 0; n < size<1>(block_layout); ++n) {
mcast_mask_qdo |= (uint16_t(1) << block_layout(n, cluster_local_block_id.x, _0{}));
}
}
if (lane_predicate) {
// Copy the K tile from GMEM to SMEM (the V tile is loaded by warp 1 below).
shared_storage.barrier_K.arrive_and_expect_tx(TmaTransactionBytesK);
copy(tma_load_K.with(reinterpret_cast<cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.barrier_K), 0 /*mcast_mask*/), tKgK, tKsK);
#pragma unroll 2
for (; m_block >= 0; --m_block) {
pipeline_q.producer_acquire(smem_pipe_write_q);
copy(tma_load_Q.with(*pipeline_q.producer_get_barrier(smem_pipe_write_q), mcast_mask_qdo), tQgQ(_, _, _, m_block), tQsQ(_, _, _, smem_pipe_write_q.index()));
++smem_pipe_write_q;
pipeline_do.producer_acquire(smem_pipe_write_do);
copy(tma_load_dO.with(*pipeline_do.producer_get_barrier(smem_pipe_write_do), mcast_mask_qdo), tdOgdO(_, _, _, m_block), tdOsdO(_, _, _, smem_pipe_write_do.index()));
++smem_pipe_write_do;
}
// Tail loop
/* This helps avoid early exit of blocks in Cluster
* Waits for all stages to either be released (all
* Consumer UNLOCKs), or if the stage was never used
* then would just be acquired since the phase was
* still inverted from make_producer_start_state
*/
pipeline_q.producer_tail(smem_pipe_write_q);
pipeline_do.producer_tail(smem_pipe_write_do);
}
} else if (warp_idx_in_warpgroup == 1) { // Load V, and do TMA_REDUCE_ADD on dQ
Tensor mV = tma_load_V.get_tma_tensor(make_shape(params.seqlen_k, params.d, params.h, params.b));
Tensor gV = local_tile(mV(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (N, K)
auto block_tma_V = tma_load_V.get_slice(_0{});
Tensor tVgV = block_tma_V.partition_S(gV); // (TMA, TMA_N, TMA_K)
Tensor tVsV = block_tma_V.partition_D(sV); // (TMA, TMA_N, TMA_K)
if (lane_predicate) {
shared_storage.barrier_V.arrive_and_expect_tx(TmaTransactionBytesV);
copy(tma_load_V.with(reinterpret_cast<cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.barrier_V), 0 /*mcast_mask*/), tVgV, tVsV);
}
Tensor mdQaccum = tma_store_dQ.get_tma_tensor(make_shape(params.seqlen_q, params.d, params.h, params.b));
Tensor gdQaccum = local_tile(mdQaccum(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (M, K, _)
auto block_tma_dQ = tma_store_dQ.get_slice(_0{});
Tensor tdQgdQ = block_tma_dQ.partition_D(gdQaccum); // (TMA, TMA_M, TMA_K)
Tensor tdQsdQ = block_tma_dQ.partition_S(sdQ); // (TMA, TMA_M, TMA_K)
Tensor tdQsdQ2 = block_tma_dQ.partition_S(sdQ2); // (TMA, TMA_M, TMA_K, 2)
int *lock_ptr = params.dq_semaphore + bidb * params.h + bidh;
using Barrier = cutlass::GenericBarrier<cutlass::detail::SyncwarpSync>;
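// NamedBarrier ids 0 and 2 form a ping-pong between this warp and the MMA warpgroups:
// id 0 means "sdQ is empty" (the previous TMA_REDUCE_ADD has completed), id 2 means
// "sdQ is full" (a new dQ tile has been written to smem and can be reduced into gmem).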
// cutlass::arch::NamedBarrier::arrive(kNThreadsdQ + 32, 1 /*id*/); // sdQ empty, ready to be written to
cutlass::arch::NamedBarrier::arrive(kNThreadsdQ + 32, 0 /*id*/); // sdQ empty, ready to be written to
// cutlass::arch::NamedBarrier::arrive(kNThreadsdQ + 32, 0 + (m_block + 1) % 2 /*id*/); // sdQ empty, ready to be written to
// cutlass::arch::NamedBarrier::arrive(kNThreadsdQ + 32, 0 + m_block % 2 /*id*/); // sdQ empty, ready to be written to
// if (n_block == 0) { // Use TMA_STORE
if (false) { // Use TMA_STORE
#pragma unroll 2
for (; m_block >= 0; --m_block) {
cutlass::arch::NamedBarrier::sync(kNThreadsdQ + 32, 2 /*id*/); // sdQ full, to be written to gmem
// cutlass::arch::NamedBarrier::sync(kNThreadsdQ + 32, 2 + m_block % 2 /*id*/); // sdQ full, to be written to gmem
if (lane_predicate) {
cute::copy(tma_store_dQ, tdQsdQ, tdQgdQ(_, _, _, m_block));
// cute::copy(tma_store_dQ, tdQsdQ2(_, _, _, m_block % 2), tdQgdQ(_, _, _, m_block));
tma_store_arrive();
}
tma_store_wait<0>();
Barrier::arrive_inc(lock_ptr, threadIdx.x % 32, m_block * params.b * params.h);
cutlass::arch::NamedBarrier::arrive(kNThreadsdQ + 32, 0 /*id*/); // sdQ empty, ready to be written to
// cutlass::arch::NamedBarrier::arrive(kNThreadsdQ + 32, 0 + m_block % 2 /*id*/); // sdQ empty, ready to be written to
}
} else { // Use TMA_REDUCE_ADD
#pragma unroll 2
for (; m_block >= 0; --m_block) {
// Barrier::wait_eq(lock_ptr, threadIdx.x % 32, m_block * params.b * params.h, n_block);
// Barrier::wait_lt(lock_ptr, threadIdx.x % 32, m_block * params.b * params.h, 1);
cutlass::arch::NamedBarrier::sync(kNThreadsdQ + 32, 2 /*id*/); // sdQ full, to be written to gmem
// cutlass::arch::NamedBarrier::sync(kNThreadsdQ + 32, 2 + m_block % 2 /*id*/); // sdQ full, to be written to gmem
if (lane_predicate) {
cute::copy(tma_reduce_add_dQ, tdQsdQ, tdQgdQ(_, _, _, m_block));
// cute::copy(tma_reduce_add_dQ, tdQsdQ2(_, _, _, m_block % 2), tdQgdQ(_, _, _, m_block));
tma_store_arrive();
}
tma_store_wait<0>();
// Barrier::arrive_inc(lock_ptr, threadIdx.x % 32, m_block * params.b * params.h);
// cutlass::arch::NamedBarrier::arrive(kNThreadsdQ + 32, 0 + m_block % 2 /*id*/); // sdQ empty, ready to be written to
cutlass::arch::NamedBarrier::arrive(kNThreadsdQ + 32, 0 /*id*/); // sdQ empty, ready to be written to
}
}
// } else if (warp_idx_in_warpgroup == 2) { // Load LSE and dPSum
// Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr)),
// make_shape(params.b, params.h, params.seqlen_q),
// make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{}));
// Tensor mdPsum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.dsoftmax_sum)),
// make_shape(params.b, params.h, params.seqlen_q),
// make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{}));
// Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape<Int<kBlockM>>{}, make_coord(_)); // (M, _)
// Tensor gdPsum = local_tile(mdPsum(bidb, bidh, _), Shape<Int<kBlockM>>{}, make_coord(_)); // (M, _)
// Tensor sLSE = make_tensor(make_smem_ptr(shared_storage.smem_lse.data()), Shape<Int<kBlockM>>{});
// Tensor sdPsum = make_tensor(make_smem_ptr(shared_storage.smem_dpsum.data()), Shape<Int<kBlockM>>{});
// #pragma unroll 2
// for (; m_block >= 0; --m_block) {
// cutlass::arch::NamedBarrier::sync(NumMmaThreads + 32, 3 /*id*/); // sLSE and sdPsum are empty
// #pragma unroll
// for (int i = 0; i < cute::ceil_div(kBlockM, 32); ++i) {
// int idx = threadIdx.x % 32 + i * 32;
// sLSE(idx) = idx < params.seqlen_q - m_block * kBlockM ? gLSE(idx, m_block) : INFINITY;
// sdPsum(idx) = idx < params.seqlen_q - m_block * kBlockM ? gdPsum(idx, m_block) : 0;
// }
// // sLSE and sdPsum are ready for WG 1
// cutlass::arch::NamedBarrier::arrive(128 + 32, 3 + 1 /*id*/);
// // sLSE and sdPsum are ready for WG 2
// cutlass::arch::NamedBarrier::arrive(128 + 32, 3 + 2 /*id*/);
// }
}
} else { // Consumers
// method in cutlass/arch/reg_reconfig.h
// calls setmaxnreg.inc.sync.aligned.u32
cutlass::arch::warpgroup_reg_alloc<240>();
// State variables used for iterating the circular buffer
// smem_pipe_read / release is used by the consumer of SMEM data, i.e. the MMA warpgroups
// smem_pipe_write is used by the producer of SMEM data, i.e. TMA
PipelineState smem_pipe_read_q, smem_pipe_read_do;
PipelineState smem_pipe_release_q, smem_pipe_release_do;
int m_block = cute::ceil_div(params.seqlen_q, kBlockM) - 1;
const int m_block_max = m_block;
int bidb = blockIdx.z; // The block index for the batch.
int bidh = blockIdx.y; // The block index for the head.
Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr)),
make_shape(params.b, params.h, params.seqlen_q),
make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{}));
Tensor mdPsum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.dsoftmax_sum)),
make_shape(params.b, params.h, params.seqlen_q),
make_stride(params.h * params.seqlen_q_rounded, params.seqlen_q_rounded, _1{}));
Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape<Int<kBlockM>>{}, make_coord(m_block));
Tensor gdPsum = local_tile(mdPsum(bidb, bidh, _), Shape<Int<kBlockM>>{}, make_coord(m_block));
Tensor sLSE = make_tensor(make_smem_ptr(shared_storage.smem_lse.data()), Shape<Int<kBlockM>>{});
Tensor sdPsum = make_tensor(make_smem_ptr(shared_storage.smem_dpsum.data()), Shape<Int<kBlockM>>{});
typename Ktraits::RmemTiledCopydQacc rmem_tiled_copy_dQaccum;
// auto rmem_thr_copy_dQaccum = rmem_tiled_copy_dQaccum.get_thread_slice((threadIdx.x - NumCopyThreads) % kNThreadsdQ);
auto rmem_thr_copy_dQaccum = rmem_tiled_copy_dQaccum.get_thread_slice(threadIdx.x - NumCopyThreads);
Tensor tdQsdQaccum = rmem_thr_copy_dQaccum.partition_D(sdQ);
Tensor tdQsdQaccum2 = rmem_thr_copy_dQaccum.partition_D(sdQ2);
// Initialize matmul objects.
typename Ktraits::TiledMmaSdP tiledMmaSdP;
auto threadMmaSdP = tiledMmaSdP.get_thread_slice(threadIdx.x - NumCopyThreads);
typename Ktraits::TiledMmadKV tiledMmadKV;
auto threadMmadKV = tiledMmadKV.get_thread_slice(threadIdx.x - NumCopyThreads);
typename Ktraits::TiledMmadQ tiledMmadQ;
// auto threadMmadQ = tiledMmadQ.get_thread_slice((threadIdx.x - NumCopyThreads) % kNThreadsdQ);
auto threadMmadQ = tiledMmadQ.get_thread_slice(threadIdx.x - NumCopyThreads);
// Allocate accumulator
Tensor tdKrdK = partition_fragment_C(tiledMmadKV, select<!dKV_swapAB ? 1 : 2, !dKV_swapAB ? 2 : 1>(TileShape_MNK{}));
Tensor tdVrdV = partition_fragment_C(tiledMmadKV, select<!dKV_swapAB ? 1 : 2, !dKV_swapAB ? 2 : 1>(TileShape_MNK{}));
auto smem_tiled_copy_PdS = make_tiled_copy_C(typename Ktraits::SmemCopyAtomPdS{}, tiledMmaSdP);
auto smem_thr_copy_PdS = smem_tiled_copy_PdS.get_thread_slice(threadIdx.x - NumCopyThreads);
// auto smem_tiled_copy_dQ = make_tiled_copy_C(typename Ktraits::SmemCopyAtomdQ{}, tiledMmadQ);
// auto smem_tiled_copy_dQ = make_tiled_copy_C(Copy_Atom<cute::SM90_U32x4_STSM_N, ElementAccum>{}, tiledMmadQ);
// auto smem_tiled_copy_dQ = make_tiled_copy_C(Copy_Atom<DefaultCopy, ElementAccum>{}, tiledMmadQ);
// auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(threadIdx.x - NumCopyThreads);
Tensor tdSsdS = smem_thr_copy_PdS.partition_D(sdSt); // ((Atom,AtomNum),PIPE_M,PIPE_N)
if constexpr (SdP_swapAB) {
// Allocate "fragments/descriptors"
Tensor tSrQ = threadMmaSdP.partition_fragment_B(sQ);
Tensor tSrK = threadMmaSdP.partition_fragment_A(sK);
Tensor tdPrdO = threadMmaSdP.partition_fragment_B(sdO);
Tensor tdPrV = threadMmaSdP.partition_fragment_A(sV);
Tensor caccS = make_identity_tensor(select<1, 0>(TileShape_MNK{})); // (BLK_M,BLK_N) -> (blk_m,blk_n)
Tensor taccScS = threadMmaSdP.partition_C(caccS);  // (MMA,MMA_M,MMA_N)
static_assert(decltype(size<0, 0>(taccScS))::value == 2);
static_assert(decltype(size<0, 1>(taccScS))::value == 2);
// taccScS has shape ((2, 2, V), MMA_M, MMA_N); we only take the row indices.
Tensor taccScS_row = taccScS(make_coord(_, _0{}, _), _0{}, _);
static constexpr int kStatsPerThread = cute::ceil_div(decltype(size(taccScS_row))::value, 8);
static constexpr bool kStatsDivisibleBy8 = decltype(size(taccScS_row))::value % 8 == 0;
Tensor lse = make_tensor<ElementAccum>(Shape<Int<decltype(size(taccScS_row))::value>>{});
// Tensor lse = make_tensor<ElementAccum>(Shape<Int<kStatsPerThread>>{});
Tensor dP_sum = make_fragment_like(lse);
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) {
const int row = get<1>(taccScS_row(mi));
lse(mi) = row < params.seqlen_q - m_block * kBlockM ? gLSE(row) : INFINITY;
dP_sum(mi) = row < params.seqlen_q - m_block * kBlockM ? gdPsum(row) : 0;
}
// #pragma unroll
// for (int mi = 0; mi < size(lse); ++mi) {
// const int row_idx = mi * 8 + (threadIdx.x % 32) / 4;
// const int row = kStatsDivisibleBy8 || row_idx < size(taccScS_row) ? get<1>(taccScS_row(row_idx)) : 0;
// lse(mi) = row < params.seqlen_q - m_block * kBlockM ? gLSE(row) : INFINITY;
// dP_sum(mi) = row < params.seqlen_q - m_block * kBlockM ? gdPsum(row) : 0;
// }
// if (blockIdx.x == 0 && blockIdx.z == 1 && threadIdx.x == 128) { print_tensor(dP_sum); printf("\n"); }
// Trying to spread LSE and dPSum across threads in a warp but it's slower
// const int row_idx = mi * 8 + (threadIdx.x % 32) / 4;
// const int row = get<1>(taccScS_row(row_idx)); // TODO: what if row_idx is outside the range?
// cute::fill(lse, 1);
// cute::fill(dP_sum, 1);
// if (cute::thread0()) { print_tensor(dP_sum); printf("\n"); }
// We want LSE = inf if the row is OOB. In that case Q would be zero, K would be zero,
// and scores would be zero. With LSE = 0, probs will be all 1's, and when we multiply
// with V (which would be zero), we're fine. However, with ALiBi, we might modify these
// scores, and probs can become NaN. Instead, if we set LSE = inf for OOB rows, probs are always 0.
// cutlass::arch::NamedBarrier::arrive(NumMmaThreads + 32, 3 /*id*/); // sLSE and sdPsum are empty
clear(tdKrdK);
clear(tdVrdV);
shared_storage.barrier_K.wait(0);
shared_storage.barrier_V.wait(0);
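// Main loop over m_blocks, from the last block down to 0. Each iteration recomputes
// P = exp(S * softmax_scale - LSE) with S = Q @ K^T (computed transposed here since SdP_swapAB),
// computes dP = dO @ V^T and dS = P * (dP - dP_sum), then accumulates dV += P^T @ dO and
// dK += dS^T @ Q. dS is staged in smem so that the dQ = dS @ K GEMM for this m_block can run
// at the start of the next iteration, overlapping with the next S / dP GEMMs.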
// #pragma unroll 2
CUTLASS_PRAGMA_NO_UNROLL
for (; m_block >= 0; --m_block) {
// Putting this dQ block at the beginning of the loop gives an extra 10 TFLOPs.
// It does make the code uglier; not sure if it's worth it.
if (m_block < m_block_max) {
// SMEM fence to make sure sdS is written before it's read by WGMMA
cutlass::arch::fence_view_async_shared();
// dS is already written to smem, and the smem for dQ is empty (from warp 1 doing TMA_REDUCE_ADD)
// int warp_group_idx = cutlass::canonical_warp_group_idx();
// if (warp_group_idx == 1 + (m_block + 1) % 2) {
// // cutlass::arch::NamedBarrier::arrive(NumMmaThreads + 32, 0 + (m_block + 1) % 2 /*id*/);
// cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 4);
// } else {
// // cutlass::arch::NamedBarrier::sync(NumMmaThreads + 32, 0 + (m_block + 1) % 2 /*id*/);
// cutlass::arch::NamedBarrier::sync(NumMmaThreads, 4);
// static_assert(!Mma_dQ_is_RS);
// Tensor tdQrdQ = partition_fragment_C(tiledMmadQ, select<!dQ_swapAB ? 0 : 2, !dQ_swapAB ? 2 : 0>(TileShape_MNK{}));
// if constexpr (!dQ_swapAB) {
// Tensor tdQrdS = threadMmadQ.partition_fragment_A(sdS);
// Tensor tdQrK = threadMmadQ.partition_fragment_B(sKt);
// flash::gemm</*zero_init=*/true, /*wg_wait=*/0>(tiledMmadQ, tdQrdS, tdQrK, tdQrdQ);
// } else {
// Tensor tdQrdS = threadMmadQ.partition_fragment_B(sdS);
// Tensor tdQrK = threadMmadQ.partition_fragment_A(sKt);
// flash::gemm</*zero_init=*/true, /*wg_wait=*/0>(tiledMmadQ, tdQrK, tdQrdS, tdQrdQ);
// }
// Tensor taccdQrdQ = rmem_thr_copy_dQaccum.retile_S(tdQrdQ); // ((Atom,AtomNum), MMA_M, MMA_N)
// cutlass::arch::NamedBarrier::sync(NumMmaThreads / 2 + 32, 0 + (m_block + 1) % 2 /*id*/);
// cute::copy(rmem_tiled_copy_dQaccum, taccdQrdQ, tdQsdQaccum2(_, _, _, (m_block + 1) % 2));
// cutlass::arch::fence_view_async_shared();
// cutlass::arch::NamedBarrier::arrive(NumMmaThreads / 2 + 32, 2 + (m_block + 1) % 2 /*id*/); // sdQ ready to be written to gmem
// }
// cutlass::arch::NamedBarrier::sync(NumMmaThreads + 32, 0 + (m_block + 1) % 2 /*id*/);
cutlass::arch::NamedBarrier::sync(NumMmaThreads + 32, 0 /*id*/);
static_assert(!Mma_dQ_is_RS);
Tensor tdQrdQ = partition_fragment_C(tiledMmadQ, select<!dQ_swapAB ? 0 : 2, !dQ_swapAB ? 2 : 0>(TileShape_MNK{}));
// Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(tdQrdQ); // ((Atom,AtomNum), MMA_M, MMA_N)
if constexpr (!dQ_swapAB) {
Tensor tdQrdS = threadMmadQ.partition_fragment_A(sdS);
Tensor tdQrK = threadMmadQ.partition_fragment_B(sKt);
flash::gemm</*zero_init=*/true, /*wg_wait=*/0>(tiledMmadQ, tdQrdS, tdQrK, tdQrdQ);
} else {
Tensor tdQrdS = threadMmadQ.partition_fragment_B(sdS);
Tensor tdQrK = threadMmadQ.partition_fragment_A(sKt);
flash::gemm</*zero_init=*/true, /*wg_wait=*/0>(tiledMmadQ, tdQrK, tdQrdS, tdQrdQ);
}
// Tensor taccdQsdQt = smem_thr_copy_dQ.partition_D(sdQt); // ((Atom,AtomNum),PIPE_M,PIPE_N)
// cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQt);
// Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ); // ((Atom,AtomNum),PIPE_M,PIPE_N)
// cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ);
Tensor taccdQrdQ = rmem_thr_copy_dQaccum.retile_S(tdQrdQ); // ((Atom,AtomNum), MMA_M, MMA_N)
// cutlass::arch::NamedBarrier::sync(NumMmaThreads + 32, 1 /*id*/); // sdQ empty, ready to be written to
cute::copy(rmem_tiled_copy_dQaccum, taccdQrdQ, tdQsdQaccum);
// cute::copy(rmem_tiled_copy_dQaccum, taccdQrdQ, tdQsdQaccum2(_, _, _, (m_block + 1) % 2));
cutlass::arch::fence_view_async_shared();
cutlass::arch::NamedBarrier::arrive(NumMmaThreads + 32, 2 /*id*/); // sdQ ready to be written to gmem
// cutlass::arch::NamedBarrier::arrive(NumMmaThreads + 32, 2 + (m_block + 1) % 2 /*id*/); // sdQ ready to be written to gmem
}
Tensor tSrS = partition_fragment_C(tiledMmaSdP, select<1, 0>(TileShape_MNK{}));
pipeline_q.consumer_wait(smem_pipe_read_q);
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiledMmaSdP, tSrK, tSrQ(_, _, _, smem_pipe_read_q.index()), tSrS);
Tensor tdPrdP = partition_fragment_C(tiledMmaSdP, select<1, 0>(TileShape_MNK{}));
pipeline_do.consumer_wait(smem_pipe_read_do);
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiledMmaSdP, tdPrV, tdPrdO(_, _, _, smem_pipe_read_do.index()), tdPrdP);
// sLSE and sdPsum are done loading for WG 1 or 2
// cutlass::arch::NamedBarrier::sync(128 + 32, 3 + cutlass::canonical_warp_group_idx() /*id*/);
// Tensor lse = make_tensor<ElementAccum>(Shape<Int<decltype(size(taccScS_row))::value>>{});
// #pragma unroll
// for (int mi = 0; mi < size(lse); ++mi) { lse(mi) = sLSE(get<1>(taccScS_row(mi))); }
warpgroup_wait<1>();
// Reshape tSrS from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N))
Tensor scores = make_tensor(tSrS.data(), flash::convert_layout_acc_transposed_rowcol(tSrS.layout()));
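// Recompute the probabilities P_ij = exp(S_ij * softmax_scale - LSE_i) (evaluated with exp2).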
flash::scale_apply_exp2</*Scale=*/true, /*scale_max=*/false>(scores, lse, params.scale_softmax_log2);
// #pragma unroll
// for (int mi = 0; mi < size<0>(lse); ++mi) { lse(mi) *= float(M_LOG2E); }
// #pragma unroll
// for (int mi = 0; mi < size<0>(scores); ++mi) {
// // const float lse_scaled = lse(mi) * float(M_LOG2E);
// const float lse_scaled = __shfl_sync(0xffffffff, lse(mi / 8), (mi % 8) * 4 + (threadIdx.x % 4));
// // const float lse_scaled = __shfl_xor_sync(0xffffffff, lse(mi / 8), 1 << (mi % 4)) * float(M_LOG2E);
// // const float lse_scaled = lse(mi);
// #pragma unroll
// for (int ni = 0; ni < size<1>(scores); ++ni) {
// scores(mi, ni) = exp2f(scores(mi, ni) * params.scale_softmax_log2 - lse_scaled);
// }
// }
// if (blockIdx.x == 0 && blockIdx.z == 1 && threadIdx.x == 128) { print_tensor(scores); printf("\n"); }
// Tensor dP_sum = make_fragment_like(lse);
// sLSE and sdPsum are done loading for WG 1 or 2
// cutlass::arch::NamedBarrier::sync(128 + 32, 3 + cutlass::canonical_warp_group_idx() /*id*/);
// #pragma unroll
// for (int mi = 0; mi < size(dP_sum); ++mi) { dP_sum(mi) = sdPsum(get<1>(taccScS_row(mi))); }
// Convert scores from fp32 to fp16/bf16
Tensor rP = flash::convert_type<Element>(tSrS);
warpgroup_wait<0>();
// Reshape tdPrdP from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N))
Tensor dS = make_tensor(tdPrdP.data(), scores.layout());
// if (blockIdx.x == 0 && blockIdx.z == 1 && threadIdx.x == 128) { print_tensor(dS); printf("\n"); }
// if (blockIdx.x == 0 && blockIdx.z == 1 && threadIdx.x == 128) { print_tensor(dP_sum); printf("\n"); }
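// Softmax backward: dS_ij = P_ij * (dP_ij - dP_sum_i), where dP_sum_i = sum_j dO_ij * O_ij.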
#pragma unroll
for (int mi = 0; mi < size<0>(dS); ++mi) {
// const float dP_sum_cur = __shfl_sync(0xffffffff, dP_sum(mi / 8), (mi % 8) * 4 + (threadIdx.x % 4));
// const float dP_sum_cur = __shfl_xor_sync(0xffffffff, dP_sum(mi / 8), 1 << (mi % 4));
#pragma unroll
for (int ni = 0; ni < size<1>(dS); ++ni) { dS(mi, ni) = scores(mi, ni) * (dS(mi, ni) - dP_sum(mi)); }
// for (int ni = 0; ni < size<1>(dS); ++ni) { dS(mi, ni) = scores(mi, ni) * (dS(mi, ni) - dP_sum_cur); }
}
// sLSE and sdPsum are done processing, can load for the next iteration
// cutlass::arch::NamedBarrier::arrive(NumMmaThreads + 32, 3 /*id*/);
// if (blockIdx.x == 0 && blockIdx.z == 1 && threadIdx.x == 128) { print_tensor(dS); printf("\n"); }
Tensor rdS = flash::convert_type<Element>(tdPrdP);
Tensor tdSadS = smem_thr_copy_PdS.retile_S(rdS); // ((Atom,AtomNum), MMA_N, MMA_N)
cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS);
if (m_block > 0) {
gLSE.data() = gLSE.data() + (-int(kBlockM));
gdPsum.data() = gdPsum.data() + (-int(kBlockM));
}
// #pragma unroll
// for (int mi = 0; mi < size(lse); ++mi) {
// // const int row = get<1>(taccScS_row(mi));
// const int row_idx = mi * 8 + (threadIdx.x % 32) / 4;
// const int row = kStatsDivisibleBy8 || row_idx < size(taccScS_row) ? get<1>(taccScS_row(row_idx)) : 0;
// lse(mi) = gLSE(row);
// dP_sum(mi) = gdPsum(row);
// }
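// With SdP_swapAB each thread's rows come in adjacent pairs, so the LSE and dP_sum
// for the next m_block can be loaded with vectorized float2 reads.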
Tensor lse_float2 = recast<float2>(lse);
Tensor dP_sum_float2 = recast<float2>(dP_sum);
#pragma unroll
for (int mi = 0; mi < size(lse) / 2; ++mi) {
const int row = get<1>(taccScS_row(mi * 2));
lse_float2(mi) = *reinterpret_cast<float2*>(&(gLSE(row)));
dP_sum_float2(mi) = *reinterpret_cast<float2*>(&(gdPsum(row)));
}
static_assert(!dKV_swapAB);
Tensor tdVrP = make_tensor(rP.data(), convert_layout_acc_Aregs<typename Ktraits::TiledMmadKV>(tSrS.layout()));
Tensor tdVrdO = threadMmadKV.partition_fragment_B(sdOt);
flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiledMmadKV, tdVrP, tdVrdO(_, _, _, smem_pipe_read_do.index()), tdVrdV);
++smem_pipe_read_do;
// warpgroup_wait<0>();
// Tensor dV_tmp = make_tensor(tdVrdV.data(), flash::convert_layout_acc_rowcol(tdVrdV.layout()));
// if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(dV_tmp); printf("\n"); }
Tensor tdKrdS = make_tensor(rdS.data(), convert_layout_acc_Aregs<typename Ktraits::TiledMmadKV>(tdPrdP.layout()));
Tensor tdKrQ = threadMmadKV.partition_fragment_B(sQt);
flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiledMmadKV, tdKrdS, tdKrQ(_, _, _, smem_pipe_read_q.index()), tdKrdK);
++smem_pipe_read_q;
// warpgroup_wait<0>();
// Tensor dK_tmp = make_tensor(tdKrdK.data(), flash::convert_layout_acc_rowcol(tdKrdK.layout()));
// if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(dK_tmp); printf("\n"); }
pipeline_do.consumer_release(smem_pipe_release_do);  // release dO
++smem_pipe_release_do;
pipeline_q.consumer_release(smem_pipe_release_q);  // release Q
++smem_pipe_release_q;
// warpgroup_wait<0>();
// Tensor dQ_tmp = make_tensor(tdQrdQ.data(), flash::convert_layout_acc_rowcol(tdQrdQ.layout()));
// if (cute::thread0()) { print_tensor(dQ_tmp); printf("\n"); }
}
{
// SMEM fence to make sure sdS is written before it's read by WGMMA
cutlass::arch::fence_view_async_shared();
// dS is already written to smem, and the smem for dQ is empty (from warp 1 doing TMA_REDUCE_ADD)
// int warp_group_idx = cutlass::canonical_warp_group_idx();
// if (warp_group_idx == 1 + (m_block + 1) % 2) {
// // cutlass::arch::NamedBarrier::arrive(NumMmaThreads + 32, 0 + (m_block + 1) % 2 /*id*/);
// cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 4);
// } else {
// // cutlass::arch::NamedBarrier::sync(NumMmaThreads + 32, 0 + (m_block + 1) % 2 /*id*/);
// cutlass::arch::NamedBarrier::sync(NumMmaThreads, 4);
// static_assert(!Mma_dQ_is_RS);
// Tensor tdQrdQ = partition_fragment_C(tiledMmadQ, select<!dQ_swapAB ? 0 : 2, !dQ_swapAB ? 2 : 0>(TileShape_MNK{}));
// if constexpr (!dQ_swapAB) {
// Tensor tdQrdS = threadMmadQ.partition_fragment_A(sdS);
// Tensor tdQrK = threadMmadQ.partition_fragment_B(sKt);
// flash::gemm</*zero_init=*/true, /*wg_wait=*/0>(tiledMmadQ, tdQrdS, tdQrK, tdQrdQ);
// } else {
// Tensor tdQrdS = threadMmadQ.partition_fragment_B(sdS);
// Tensor tdQrK = threadMmadQ.partition_fragment_A(sKt);
// flash::gemm</*zero_init=*/true, /*wg_wait=*/0>(tiledMmadQ, tdQrK, tdQrdS, tdQrdQ);
// }
// Tensor taccdQrdQ = rmem_thr_copy_dQaccum.retile_S(tdQrdQ); // ((Atom,AtomNum), MMA_M, MMA_N)
// cutlass::arch::NamedBarrier::sync(NumMmaThreads / 2 + 32, 0 + (m_block + 1) % 2 /*id*/);
// cute::copy(rmem_tiled_copy_dQaccum, taccdQrdQ, tdQsdQaccum2(_, _, _, (m_block + 1) % 2));
// cutlass::arch::fence_view_async_shared();
// cutlass::arch::NamedBarrier::arrive(NumMmaThreads / 2 + 32, 2 + (m_block + 1) % 2 /*id*/); // sdQ ready to be written to gmem
// Tensor dQ_tmp = make_tensor(tdQrdQ.data(), flash::convert_layout_acc_rowcol(tdQrdQ.layout()));
// // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(dQ_tmp); printf("\n"); }
// // if (blockIdx.x == 0 && threadIdx.x == 128) { print(taccdQrdQ); printf("\n"); print(tdQsdQaccum2); printf("\n"); }
// }
cutlass::arch::NamedBarrier::sync(NumMmaThreads + 32, 0 /*id*/);
// cutlass::arch::NamedBarrier::sync(NumMmaThreads + 32, 0 + 0 % 2 /*id*/);
// cutlass::arch::NamedBarrier::sync(NumMmaThreads, 0 /*id*/);
static_assert(!Mma_dQ_is_RS);
Tensor tdQrdQ = partition_fragment_C(tiledMmadQ, select<!dQ_swapAB ? 0 : 2, !dQ_swapAB ? 2 : 0>(TileShape_MNK{}));
if constexpr (!dQ_swapAB) {
Tensor tdQrdS = threadMmadQ.partition_fragment_A(sdS);
Tensor tdQrK = threadMmadQ.partition_fragment_B(sKt);
flash::gemm</*zero_init=*/true, /*wg_wait=*/0>(tiledMmadQ, tdQrdS, tdQrK, tdQrdQ);
} else {
Tensor tdQrdS = threadMmadQ.partition_fragment_B(sdS);
Tensor tdQrK = threadMmadQ.partition_fragment_A(sKt);
flash::gemm</*zero_init=*/true, /*wg_wait=*/0>(tiledMmadQ, tdQrK, tdQrdS, tdQrdQ);
}
// Tensor dQ_tmp = make_tensor(tdQrdQ.data(), flash::convert_layout_acc_rowcol(tdQrdQ.layout()));
// if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(dQ_tmp); printf("\n"); }
Tensor taccdQrdQ = rmem_thr_copy_dQaccum.retile_S(tdQrdQ); // ((Atom,AtomNum), MMA_M, MMA_N)
// cutlass::arch::NamedBarrier::sync(NumMmaThreads + 32, 1 /*id*/); // sdQ empty, ready to be written to
cute::copy(rmem_tiled_copy_dQaccum, taccdQrdQ, tdQsdQaccum);
// cute::copy(rmem_tiled_copy_dQaccum, taccdQrdQ, tdQsdQaccum2(_, _, _, 0 % 2));
cutlass::arch::fence_view_async_shared();
cutlass::arch::NamedBarrier::arrive(NumMmaThreads + 32, 2 /*id*/); // sdQ ready to be written to gmem
// cutlass::arch::NamedBarrier::arrive(NumMmaThreads + 32, 2 + 0 % 2 /*id*/); // sdQ ready to be written to gmem
// if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(sdQ); printf("\n"); }
}
} else { // !SdP_swapAB
Tensor tPsP = smem_thr_copy_PdS.partition_D(sP); // ((Atom,AtomNum),PIPE_M,PIPE_N)
Tensor tdSsdS = smem_thr_copy_PdS.partition_D(sdS); // ((Atom,AtomNum),PIPE_M,PIPE_N)
// Allocate "fragments/descriptors"
Tensor tSrQ = threadMmaSdP.partition_fragment_A(sQ);
Tensor tSrK = threadMmaSdP.partition_fragment_B(sK);
Tensor tdPrdO = threadMmaSdP.partition_fragment_A(sdO);
Tensor tdPrV = threadMmaSdP.partition_fragment_B(sV);
Tensor caccS = make_identity_tensor(select<0, 1>(TileShape_MNK{})); // (BLK_M,BLK_N) -> (blk_m,blk_n)
Tensor taccScS = threadMmaSdP.partition_C(caccS);  // (MMA,MMA_M,MMA_N)
static_assert(decltype(size<0, 0>(taccScS))::value == 2);
static_assert(decltype(size<0, 1>(taccScS))::value == 2);
// taccScS has shape ((2, 2, V), MMA_M, MMA_N); we only take the row indices.
Tensor taccScS_row = taccScS(make_coord(_0{}, _, _0{}), _, _0{});
Tensor lse = make_tensor<ElementAccum>(Shape<Int<decltype(size(taccScS_row))::value>>{});
Tensor dP_sum = make_fragment_like(lse);
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) {
const int row = get<0>(taccScS_row(mi));
lse(mi) = row < params.seqlen_q - m_block * kBlockM ? gLSE(row) : INFINITY;
dP_sum(mi) = row < params.seqlen_q - m_block * kBlockM ? gdPsum(row) : 0;
}
// if (cute::thread0()) { print_tensor(dP_sum); printf("\n"); }
// We want LSE = inf if the row is OOB. In that case Q would be zero, K would be zero,
// and scores would be zero. With LSE = 0, probs will be all 1's, and when we multiply
// with V (which would be zero), we're fine. However, with ALiBi, we might modify these
// scores, and probs can become NaN. Instead, if we set LSE = inf for OOB rows, probs are always 0.
clear(tdKrdK);
clear(tdVrdV);
shared_storage.barrier_K.wait(0);
shared_storage.barrier_V.wait(0);
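// Main loop; same structure as the SdP_swapAB branch above, except that P is also staged
// through smem (sP), since here the dK/dV GEMMs read both P and dS from smem, and the two
// MMA warpgroups synchronize their sP/sdS writes with each other via NamedBarriers 4-7.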
CUTLASS_PRAGMA_NO_UNROLL
for (; m_block >= 0; --m_block) {
Tensor tSrS = partition_fragment_C(tiledMmaSdP, select<0, 1>(TileShape_MNK{}));
pipeline_q.consumer_wait(smem_pipe_read_q);
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiledMmaSdP, tSrQ(_, _, _, smem_pipe_read_q.index()), tSrK, tSrS);
Tensor tdPrdP = partition_fragment_C(tiledMmaSdP, select<0, 1>(TileShape_MNK{}));
pipeline_do.consumer_wait(smem_pipe_read_do);
// if (blockIdx.x == 0 && blockIdx.z == 0 && threadIdx.x == 128) { printf("After dO wait\n"); }
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiledMmaSdP, tdPrdO(_, _, _, smem_pipe_read_do.index()), tdPrV, tdPrdP);
warpgroup_wait<1>();
// Reshape tSrS from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N))
Tensor scores = make_tensor(tSrS.data(), flash::convert_layout_acc_rowcol(tSrS.layout()));
flash::scale_apply_exp2</*Scale=*/true, /*scale_max=*/false>(scores, lse, params.scale_softmax_log2);
// if (blockIdx.x == 0 && blockIdx.z == 0 && threadIdx.x == 128) { print_tensor(scores); printf("\n"); }
// Convert scores from fp32 to fp16/bf16
Tensor rP = flash::convert_type<Element>(tSrS);
Tensor tPaP = smem_thr_copy_PdS.retile_S(rP); // ((Atom,AtomNum), MMA_N, MMA_N)
cutlass::arch::NamedBarrier::sync(NumMmaThreads, 8 /*id*/);
cute::copy(smem_tiled_copy_PdS, tPaP, tPsP);
int const warp_group_idx = cutlass::canonical_warp_group_idx() - 1;
cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 4 + warp_group_idx /*id*/);
// if (blockIdx.x == 0 && blockIdx.z == 0 && (threadIdx.x == 128 || threadIdx.x == 256)) { printf("After barrier arrive 4, tidx = %d\n", threadIdx.x); }
warpgroup_wait<0>();
// Reshape tdPrdP from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N))
Tensor dS = make_tensor(tdPrdP.data(), scores.layout());
// if (blockIdx.x == 0 && blockIdx.z == 1 && threadIdx.x == 128) { print_tensor(dS); printf("\n"); }
// if (blockIdx.x == 0 && blockIdx.z == 1 && threadIdx.x == 128) { print_tensor(dP_sum); printf("\n"); }
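// Softmax backward: dS_ij = P_ij * (dP_ij - dP_sum_i), where dP_sum_i = sum_j dO_ij * O_ij.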
#pragma unroll
for (int mi = 0; mi < size<0>(dS); ++mi) {
#pragma unroll
for (int ni = 0; ni < size<1>(dS); ++ni) { dS(mi, ni) = scores(mi, ni) * (dS(mi, ni) - dP_sum(mi)); }
}
// if (blockIdx.x == 0 && blockIdx.z == 0 && threadIdx.x == 128) { print_tensor(dS); printf("\n"); }
Tensor rdS = flash::convert_type<Element>(tdPrdP);
Tensor tdSadS = smem_thr_copy_PdS.retile_S(rdS); // ((Atom,AtomNum), MMA_N, MMA_N)
cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS);
cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 6 + warp_group_idx /*id*/);
// if (blockIdx.x == 0 && blockIdx.z == 0 && (threadIdx.x == 128 || threadIdx.x == 256)) { printf("After barrier arrive 6, tidx = %d\n", threadIdx.x); }
if (m_block > 0) {
gLSE.data() = gLSE.data() + (-int(kBlockM));
gdPsum.data() = gdPsum.data() + (-int(kBlockM));
}
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) {
const int row = get<0>(taccScS_row(mi));
lse(mi) = gLSE(row);
dP_sum(mi) = gdPsum(row);
}
Tensor tdQrdQ = partition_fragment_C(tiledMmadQ, select<!dQ_swapAB ? 0 : 2, !dQ_swapAB ? 2 : 0>(TileShape_MNK{}));
if constexpr (Mma_dQ_is_RS) {
static_assert(!dQ_swapAB);
Tensor tdQrdS = make_tensor(rdS.data(), convert_layout_acc_Aregs<typename Ktraits::TiledMmadQ>(tdPrdP.layout()));
Tensor tdQrK = threadMmadQ.partition_fragment_B(sKt);
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiledMmadQ, tdQrdS, tdQrK, tdQrdQ);
// if (cute::thread0()) { print(tdQrdS); printf("\n"); print(tdQrK); printf("\n"); print(tdQrdQ); printf("\n"); }
}
cutlass::arch::fence_view_async_shared();
// if (blockIdx.x == 0 && blockIdx.z == 0 && (threadIdx.x == 128 || threadIdx.x == 256)) { printf("Before barrier sync 4, tidx = %d\n", threadIdx.x); }
cutlass::arch::NamedBarrier::sync(NumMmaThreads, 4 + 1 - warp_group_idx /*id*/);
// if (blockIdx.x == 0 && blockIdx.z == 0 && (threadIdx.x == 128 || threadIdx.x == 256)) { printf("After barrier sync 4, tidx = %d\n", threadIdx.x); }
// if (blockIdx.x == 0 && blockIdx.z == 0 && (threadIdx.x == 128)) { print_tensor(sPt); printf("\n"); }
// if (blockIdx.x == 0 && blockIdx.z == 0 && (threadIdx.x == 128)) { print_tensor(sdOt); printf("\n"); }
if constexpr (!dKV_swapAB) {
Tensor tdVrP = threadMmadKV.partition_fragment_A(sPt);
Tensor tdVrdO = threadMmadKV.partition_fragment_B(sdOt);
flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiledMmadKV, tdVrP, tdVrdO(_, _, _, smem_pipe_read_do.index()), tdVrdV);
} else {
Tensor tdVrP = threadMmadKV.partition_fragment_B(sPt);
Tensor tdVrdO = threadMmadKV.partition_fragment_A(sdOt);
flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiledMmadKV, tdVrdO(_, _, _, smem_pipe_read_do.index()), tdVrP, tdVrdV);
}
++smem_pipe_read_do;
// warpgroup_wait<0>();
// Tensor dV_tmp = make_tensor(tdVrdV.data(), flash::convert_layout_acc_rowcol(tdVrdV.layout()));
// if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(dV_tmp); printf("\n"); }
cutlass::arch::fence_view_async_shared();
// if (blockIdx.x == 0 && blockIdx.z == 0 && (threadIdx.x == 128 || threadIdx.x == 256)) { printf("Before barrier sync 6, tidx = %d\n", threadIdx.x); }
cutlass::arch::NamedBarrier::sync(NumMmaThreads, 6 + 1 - warp_group_idx /*id*/);
// if (blockIdx.x == 0 && blockIdx.z == 0 && (threadIdx.x == 128 || threadIdx.x == 256)) { printf("After barrier sync 6, tidx = %d\n", threadIdx.x); }
if constexpr (!dKV_swapAB) {
Tensor tdKrdS = threadMmadKV.partition_fragment_A(sdSt);
Tensor tdKrQ = threadMmadKV.partition_fragment_B(sQt);
flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiledMmadKV, tdKrdS, tdKrQ(_, _, _, smem_pipe_read_q.index()), tdKrdK);
} else {
Tensor tdKrdS = threadMmadKV.partition_fragment_B(sdSt);
Tensor tdKrQ = threadMmadKV.partition_fragment_A(sQt);
flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiledMmadKV, tdKrQ(_, _, _, smem_pipe_read_q.index()), tdKrdS, tdKrdK);
}
++smem_pipe_read_q;
warpgroup_wait<0>();
// Tensor dK_tmp = make_tensor(tdKrdK.data(), flash::convert_layout_acc_rowcol(tdKrdK.layout()));
// if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(dK_tmp); printf("\n"); }
pipeline_do.consumer_release(smem_pipe_release_do);  // release dO
++smem_pipe_release_do;
pipeline_q.consumer_release(smem_pipe_release_q);  // release Q
++smem_pipe_release_q;
// warpgroup_wait<0>();
// Tensor dQ_tmp = make_tensor(tdQrdQ.data(), flash::convert_layout_acc_rowcol(tdQrdQ.layout()));
// if (cute::thread0()) { print_tensor(dQ_tmp); printf("\n"); }
cutlass::arch::NamedBarrier::sync(NumMmaThreads, 8 /*id*/);
}
}
// Epilogue
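// Write dK and dV for this n_block to gmem: stage each in smem, fence, then one elected
// thread issues the TMA store. dK is scaled by softmax_scale before conversion; dV is not.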
Tensor sdK = make_tensor(make_smem_ptr(shared_storage.smem_dk.data()), typename Ktraits::SmemLayoutdK{});
Tensor sdV = make_tensor(make_smem_ptr(shared_storage.smem_dv.data()), typename Ktraits::SmemLayoutdV{});
Tensor sdKt = make_tensor(make_smem_ptr(shared_storage.smem_dk.data()), typename Ktraits::SmemLayoutdKt{});
Tensor sdVt = make_tensor(make_smem_ptr(shared_storage.smem_dv.data()), typename Ktraits::SmemLayoutdVt{});
auto smem_tiled_copy_dKV = make_tiled_copy_C(typename Ktraits::SmemCopyAtomdKV{}, tiledMmadKV);
auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(threadIdx.x - NumCopyThreads);
int n_block = blockIdx.x;
bidb = blockIdx.z; // The block index for the batch.
bidh = blockIdx.y; // The block index for the head.
Tensor mdK = tma_store_dK.get_tma_tensor(make_shape(params.seqlen_k, params.d, params.h, params.b));
Tensor mdV = tma_store_dV.get_tma_tensor(make_shape(params.seqlen_k, params.d, params.h, params.b));
Tensor gdK = local_tile(mdK(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
Tensor gdV = local_tile(mdV(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
auto block_tma_dK = tma_store_dK.get_slice(_0{});
auto block_tma_dV = tma_store_dV.get_slice(_0{});
Tensor tdKgdK = block_tma_dK.partition_D(gdK); // (TMA, TMA_M, TMA_K)
Tensor tdKsdK = block_tma_dK.partition_S(sdK); // (TMA, TMA_M, TMA_K)
Tensor tdVgdV = block_tma_dV.partition_D(gdV); // (TMA, TMA_M, TMA_K)
Tensor tdVsdV = block_tma_dV.partition_S(sdV); // (TMA, TMA_M, TMA_K)
// Very slightly faster to do the smem write and TMA write for dV first, then do the same for dK,
// instead of doing both at the same time.
Tensor tdVrdV_out = convert_type<Element>(tdVrdV);
#pragma unroll
for (int i = 0; i < size(tdKrdK); ++i) { tdKrdK(i) *= params.scale_softmax; }
Tensor tdKrdK_out = convert_type<Element>(tdKrdK);
Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(tdVrdV_out); // ((Atom,AtomNum), MMA_M, MMA_N)
Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(tdKrdK_out); // ((Atom,AtomNum), MMA_M, MMA_N)
// Can't use __syncthreads() in WS code
auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(NumMmaThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
synchronize();
if constexpr (!dKV_swapAB) {
Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(sdV); // ((Atom,AtomNum),PIPE_M,PIPE_N)
cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV);
} else {
Tensor taccdVsdVt = smem_thr_copy_dKV.partition_D(sdVt); // ((Atom,AtomNum),PIPE_M,PIPE_N)
cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdVt);
}
cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA
synchronize();
lane_predicate = cute::elect_one_sync();
warp_idx = cutlass::canonical_warp_idx_sync();
if (warp_idx == NumCopyThreads / cutlass::NumThreadsPerWarp && lane_predicate) {
cute::copy(tma_store_dV, tdVsdV, tdVgdV);
tma_store_arrive();
}
if constexpr (!dKV_swapAB) {
Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(sdK); // ((Atom,AtomNum),PIPE_M,PIPE_N)
cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK);
} else {
Tensor taccdKsdKt = smem_thr_copy_dKV.partition_D(sdKt); // ((Atom,AtomNum),PIPE_M,PIPE_N)
cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdKt);
}
cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA
synchronize();
if (warp_idx == NumCopyThreads / cutlass::NumThreadsPerWarp && lane_predicate) {
cute::copy(tma_store_dK, tdKsdK, tdKgdK);
tma_store_arrive();
}
tma_store_wait<0>();
}
}
} // namespace flash
/******************************************************************************
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
******************************************************************************/
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include "cute/tensor.hpp"
#include "cutlass/cluster_launch.hpp"
#include "static_switch.h"
#include "flash.h"
#include "flash_bwd_preprocess_kernel.h"
#include "flash_bwd_kernel.h"
#include "kernel_traits.h"
template<bool Clear_dQaccum=true, typename Kernel_traits>
__global__ void flash_bwd_dot_do_o_kernel(const Flash_bwd_params params) {
flash::compute_dot_do_o<Clear_dQaccum, Kernel_traits>(params);
}
// template<typename Kernel_traits>
// __global__ void flash_bwd_convert_dq_kernel(const Flash_bwd_params params, const int nsplits) {
// flash::convert_dQ<Kernel_traits>(params, nsplits);
// }
template<typename Kernel_traits>
__global__ void flash_bwd_convert_dkv_kernel(const Flash_bwd_params params) {
flash::convert_dKV<Kernel_traits>(params);
}
template<typename Kernel_traits, bool Is_causal>
void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
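// The backward pass runs as three kernels:
// (1) flash_bwd_dot_do_o_kernel: computes softmax_d = rowsum(dO * O) and zeroes dQaccum,
// (2) the main backward kernel: one CTA per n_block computes dK/dV and accumulates dQ in fp32 into dQaccum,
// (3) convert_dQ: reads dQaccum back via TMA and converts it to the output dtype.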
int num_m_block = cute::ceil_div(params.seqlen_q, Kernel_traits::kBlockM);
dim3 grid_m(num_m_block, params.b, params.h);
flash_bwd_dot_do_o_kernel<true, Kernel_traits><<<grid_m, Kernel_traits::kNThreadsNonWS, 0, stream>>>(params);
// If we use both TMA_STORE (for n_block=0) and TMA_REDUCE_ADD (for n_block>0), we don't need to clear dQaccum
// flash_bwd_dot_do_o_kernel<false, Kernel_traits><<<grid_m, Kernel_traits::kNThreadsNonWS, 0, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
using Element = typename Kernel_traits::Element;
using ElementAccum = typename Kernel_traits::ElementAccum;
using TileShape_MNK = typename Kernel_traits::TileShape_MNK;
using ClusterShape = typename Kernel_traits::ClusterShape_MNK;
Tensor mQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.q_ptr)),
make_shape(params.seqlen_q, params.d, params.h, params.b),
make_stride(params.q_row_stride, _1{}, params.q_head_stride, params.q_batch_stride));
auto tma_load_Q = make_tma_copy(
typename Kernel_traits::GmemTiledCopyQdO{},
mQ,
typename Kernel_traits::SmemLayoutQ{}(_, _, _0{}),
// typename Kernel_traits::SmemLayoutQ{},
select<0, 2>(TileShape_MNK{}),
size<1>(ClusterShape{})); // mcast along N mode for this M load, if any
Tensor mdO = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.do_ptr)),
make_shape(params.seqlen_q, params.d, params.h, params.b),
make_stride(params.do_row_stride, _1{}, params.do_head_stride, params.do_batch_stride));
auto tma_load_dO = make_tma_copy(
typename Kernel_traits::GmemTiledCopyQdO{},
mdO,
typename Kernel_traits::SmemLayoutdO{}(_, _, _0{}),
// typename Kernel_traits::SmemLayoutdO{},
select<0, 2>(TileShape_MNK{}),
size<1>(ClusterShape{})); // mcast along N mode for this M load, if any
Tensor mK = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.k_ptr)),
make_shape(params.seqlen_k, params.d, params.h, params.b),
make_stride(params.k_row_stride, _1{}, params.k_head_stride, params.k_batch_stride));
auto tma_load_K = make_tma_copy(
typename Kernel_traits::GmemTiledCopyKV{},
mK,
typename Kernel_traits::SmemLayoutK{},
// typename Kernel_traits::SmemLayoutK{}(_, _, _0{}),
select<1, 2>(TileShape_MNK{}),
_1{}); // no mcast for K
Tensor mV = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.v_ptr)),
make_shape(params.seqlen_k, params.d, params.h, params.b),
make_stride(params.v_row_stride, _1{}, params.v_head_stride, params.v_batch_stride));
auto tma_load_V = make_tma_copy(
typename Kernel_traits::GmemTiledCopyKV{},
mV,
typename Kernel_traits::SmemLayoutV{},
// typename Kernel_traits::SmemLayoutV{}(_, _, _0{}),
select<1, 2>(TileShape_MNK{}),
_1{}); // no mcast for V
Tensor mdK = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.dk_ptr)),
make_shape(params.seqlen_k, params.d, params.h, params.b),
make_stride(params.dk_row_stride, _1{}, params.dk_head_stride, params.dk_batch_stride));
auto tma_store_dK = make_tma_copy(
typename Kernel_traits::GmemTiledCopydKV{},
mdK,
typename Kernel_traits::SmemLayoutdK{},
select<1, 2>(TileShape_MNK{}),
_1{}); // no mcast for output
Tensor mdV = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.dv_ptr)),
make_shape(params.seqlen_k, params.d, params.h, params.b),
make_stride(params.dv_row_stride, _1{}, params.dv_head_stride, params.dv_batch_stride));
auto tma_store_dV = make_tma_copy(
typename Kernel_traits::GmemTiledCopydKV{},
mdV,
typename Kernel_traits::SmemLayoutdV{},
select<1, 2>(TileShape_MNK{}),
_1{}); // no mcast for output
Tensor mdQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.dq_ptr)),
make_shape(params.seqlen_q, params.d, params.h, params.b),
make_stride(params.dq_row_stride, _1{}, params.dq_head_stride, params.dq_batch_stride));
Tensor mdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.dq_accum_ptr)),
make_shape(params.seqlen_q, params.d, params.h, params.b),
make_stride(params.d * params.h, _1{}, params.d, params.d * params.h * params.seqlen_q_rounded));
auto tma_store_dQaccum = make_tma_copy(
// typename Kernel_traits::GmemTiledCopydKV{},
typename cute::SM90_TMA_STORE{},
// mdQ,
mdQaccum,
// typename Kernel_traits::SmemLayoutdQTMA{},
typename Kernel_traits::SmemLayoutdQaccTMA{},
select<0, 2>(TileShape_MNK{}),
_1{}); // no mcast for output
auto tma_reduce_add_dQaccum = make_tma_copy(
// typename Kernel_traits::GmemTiledCopydKV{},
typename cute::SM90_TMA_REDUCE_ADD{},
// mdQ,
mdQaccum,
// typename Kernel_traits::SmemLayoutdQTMA{},
typename Kernel_traits::SmemLayoutdQaccTMA{},
select<0, 2>(TileShape_MNK{}),
_1{}); // no mcast for output
// print(typename Kernel_traits::SmemLayoutVt{}); printf("\n"); print(typename Kernel_traits::SmemLayoutVt_tmp{});
// print(typename Kernel_traits::TiledMmaSdP{}); printf("\n");
// print(typename Kernel_traits::TiledMmadKV{}); printf("\n");
// print(typename Kernel_traits::TiledMmadQ{}); printf("\n");
// print(typename Kernel_traits::SmemLayoutAtomK{}); printf("\n");
// print(typename Kernel_traits::SmemLayoutK{}); printf("\n");
// print(typename Kernel_traits::SmemLayoutKt{}); printf("\n");
// Get the ptr to kernel function.
void *kernel;
if constexpr (!Kernel_traits::Is_WS) {
kernel = (void *)flash::compute_dqkv<Kernel_traits, Is_causal, decltype(tma_load_Q), decltype(tma_load_dO),
decltype(tma_load_K), decltype(tma_load_V), decltype(tma_store_dK), decltype(tma_store_dV)>;
} else {
kernel = (void *)flash::compute_dqkv_ws<Kernel_traits, Is_causal, decltype(tma_load_Q), decltype(tma_load_dO),
decltype(tma_load_K), decltype(tma_load_V), decltype(tma_store_dK), decltype(tma_store_dV), decltype(tma_store_dQaccum), decltype(tma_reduce_add_dQaccum)>;
}
// void *kernel = (void *)flash::compute_dqkv_seqqpar<Kernel_traits, Is_causal, decltype(tma_load_Q), decltype(tma_load_dO),
// decltype(tma_load_K), decltype(tma_load_V), decltype(tma_store_dQaccum), decltype(tma_store_dK), decltype(tma_store_dV)>;
auto shared_storage = typename Kernel_traits::SharedStorage{};
int smem_size = sizeof(typename Kernel_traits::SharedStorage);
int smem_size_q = sizeof(decltype(shared_storage.smem_q));
int smem_size_do = sizeof(decltype(shared_storage.smem_do));
int smem_size_k = sizeof(decltype(shared_storage.smem_k));
int smem_size_v = sizeof(decltype(shared_storage.smem_v));
// int smem_size_p = sizeof(decltype(shared_storage.smem_p));
int smem_size_ds = sizeof(decltype(shared_storage.smem_ds));
// printf("smem_size = %d, q = %d, do = %d, k = %d, v = %d, p = %d, ds = %d\n", smem_size, smem_size_q, smem_size_do, smem_size_k, smem_size_v, smem_size_p, smem_size_ds);
// printf("smem_size = %d, q = %d, do = %d, k = %d, v = %d, ds = %d\n", smem_size, smem_size_q, smem_size_do, smem_size_k, smem_size_v, smem_size_ds);
if (smem_size >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
static constexpr int ctaSize = Kernel_traits::kNWarps * 32;
int num_blocks_n = cutlass::ceil_div(params.seqlen_k, Kernel_traits::kBlockN);
num_blocks_n = cutlass::ceil_div(num_blocks_n, size<1>(ClusterShape{})) * size<1>(ClusterShape{});
dim3 grid_dims(num_blocks_n, params.h, params.b);
// int num_blocks_m = cutlass::ceil_div(params.seqlen_q, Kernel_traits::kBlockM);
// num_blocks_m = cutlass::ceil_div(num_blocks_m, size<0>(ClusterShape{})) * size<0>(ClusterShape{});
// dim3 grid_dims(num_blocks_m, params.h, params.b);
dim3 block_dims(ctaSize);
dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream};
if constexpr (!Kernel_traits::Is_WS) {
cutlass::launch_kernel_on_cluster(launch_params, kernel, params, tma_load_Q, tma_load_dO,
tma_load_K, tma_load_V, tma_store_dK, tma_store_dV);
} else {
cutlass::launch_kernel_on_cluster(launch_params, kernel, params, tma_load_Q, tma_load_dO,
tma_load_K, tma_load_V, tma_store_dK, tma_store_dV, tma_store_dQaccum, tma_reduce_add_dQaccum);
}
// cutlass::launch_kernel_on_cluster(launch_params, kernel, params, tma_load_Q, tma_load_dO,
// tma_load_K, tma_load_V, tma_store_dQaccum, tma_store_dK, tma_store_dV);
C10_CUDA_KERNEL_LAUNCH_CHECK();
auto tma_load_dQaccum = make_tma_copy(
typename cute::SM90_TMA_LOAD{},
mdQaccum,
typename Kernel_traits::SmemLayoutdQaccTMA{},
select<0, 2>(TileShape_MNK{}),
_1{}); // no mcast for output
// auto kernel_dq = &flash_bwd_convert_dq_kernel<Kernel_traits>;
auto kernel_dq = &flash::convert_dQ<Kernel_traits, decltype(tma_load_dQaccum)>;
if (Kernel_traits::kSmemdQSize * 2 + 8 >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize * 2 + 8));
}
kernel_dq<<<grid_m, Kernel_traits::kNThreadsdQ, Kernel_traits::kSmemdQSize * 2 + 8, stream>>>(params, tma_load_dQaccum);
C10_CUDA_KERNEL_LAUNCH_CHECK();
// auto kernel_dkv = &flash_bwd_convert_dkv_kernel<Kernel_traits>;
// if (Kernel_traits::kSmemdKVSize >= 48 * 1024) {
// C10_CUDA_CHECK(cudaFuncSetAttribute(
// kernel_dkv, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdKVSize));
// }
// int num_n_block = cute::ceil_div(params.seqlen_k, Kernel_traits::kBlockN);
// dim3 grid_n(num_n_block, params.b, params.h);
// kernel_dkv<<<grid_n, Kernel_traits::kNThreads, Kernel_traits::kSmemdKVSize, stream>>>(params);
// C10_CUDA_KERNEL_LAUNCH_CHECK();
}
template<typename T>
void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 64;
// BOOL_SWITCH(params.is_causal, Is_causal, [&] {
// run_flash_bwd<T, Headdim, Is_causal>(params, stream);
// });
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, false, false, false, 2, 2, 2, 1, T>, false>(params, stream);
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 12, true, false, false, 1, 2, 2, 1, T>, false>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 96, 128, 12, true, false, true, 1, 2, 2, 1, T>, false>(params, stream);
}
template<typename T>
void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 128;
// BOOL_SWITCH(params.is_causal, Is_causal, [&] {
// run_flash_bwd<T, Headdim, Is_causal>(params, stream);
// });
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, false, 2, 1, 2, 1, T>, false>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, false, false, false, 1, 2, 1, 1, T>, false>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 96, 8, false, true, false, 2, 1, 2, 1, T>, false>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 96, 8, false, true, true, 2, 1, 1, 1, T>, false>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, true, false, true, 1, 2, 1, 1, T>, false>(params, stream);
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 12, true, false, true, 1, 2, 1, 1, T>, false>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 12, true, false, false, 1, 2, 1, 1, T>, false>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 12, false, false, false, 1, 2, 1, 1, T>, false>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 80, 128, 12, true, false, true, 1, 2, 1, 1, T>, false>(params, stream);
// run_flash_bwd<Flash_bwd_seqqpar_kernel_traits<Headdim, 128, 64, 8, false, true, false, 2, 1, 2, 1, T>, false>(params, stream);
// run_flash_bwd<Flash_bwd_seqqpar_kernel_traits<Headdim, 96, 128, 8, true, false, true, 1, 2, 1, 1, T>, false>(params, stream);
}
template<typename T>
void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
// constexpr static int Headdim = 256;
// BOOL_SWITCH(params.is_causal, Is_causal, [&] {
// run_flash_bwd<T, Headdim, Is_causal>(params, stream);
// });
}
/******************************************************************************
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
******************************************************************************/
#pragma once
#include <cute/tensor.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
#include "cutlass/pipeline/pipeline.hpp"
#include "flash.h"
#include "block_info.h"
#include "kernel_traits.h"
#include "utils.h"
namespace flash {
using namespace cute;
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int THREADS_PER_ROW, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
inline __device__ void dot_do_o(Tensor<Engine0, Layout0> const &do_, Tensor<Engine0, Layout0> const &o,
Tensor<Engine1, Layout1> &dP_sum, const int gdP_col_stride, const float scale) {
static_assert(Layout0::rank == 3, "Only support 3D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(do_.layout() == o.layout());
// Reshape do_ and o from (8, kBlockM / 32, kHeadDim / 64) to (kBlockM / 32, 8 * kHeadDim / 64)
// The last coordinate is the "page".
Tensor do_reshaped = make_tensor(do_.data(), make_layout(get<1>(do_.layout()),
make_layout(get<0>(do_.layout()),
get<2>(do_.layout()))));
Tensor o_reshaped = make_tensor(o.data(), do_reshaped.layout());
Tensor do_fp32 = flash::convert_type<float>(do_reshaped);
Tensor o_fp32 = flash::convert_type<float>(o_reshaped);
#pragma unroll
for (int mi = 0; mi < size<0>(do_reshaped); ++mi) {
float dP_sum_cur = do_fp32(mi, 0) * o_fp32(mi, 0);
#pragma unroll
for (int ni = 1; ni < size<1>(do_reshaped); ni++) {
dP_sum_cur += do_fp32(mi, ni) * o_fp32(mi, ni);
}
flash::SumOp<float> sum_op;
dP_sum_cur = flash::Allreduce<THREADS_PER_ROW>::run(dP_sum_cur, sum_op) * scale;
if (threadIdx.x % THREADS_PER_ROW == 0) {
dP_sum(mi * gdP_col_stride + threadIdx.x / THREADS_PER_ROW) = dP_sum_cur;
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Just compute dot(do, o) and write the result (softmax_d) to global memory as a separate kernel.
// This is used in the case where we want to parallelize the backward across seqlen_k.
template<bool Clear_dQaccum=true, typename Kernel_traits, typename Params>
inline __device__ void compute_dot_do_o(const Params &params) {
using Element = typename Kernel_traits::Element;
using ElementAccum = typename Kernel_traits::ElementAccum;
using index_t = typename Kernel_traits::index_t;
const int m_block = blockIdx.x;
// The block index for the batch.
const int bidb = blockIdx.y;
// The block index for the head.
const int bidh = blockIdx.z;
// The thread index.
const int tidx = threadIdx.x;
constexpr int kBlockM = Kernel_traits::kBlockM;
constexpr int kHeadDim = Kernel_traits::kHeadDim;
const BlockInfo binfo(params, bidb);
if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb)
+ m_block * kBlockM * params.do_row_stride + bidh * params.do_head_stride;
const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
+ m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb)
+ (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded;
const index_t row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM;
Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.do_ptr) + row_offset_do),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.do_row_stride, _1{}));
Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.o_row_stride, _1{}));
Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dq_accum_ptr) + row_offset_dq_accum),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.h * params.d_rounded, _1{}));
Tensor dP_sum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dsoftmax_sum) + row_offset_dpsum),
Shape<Int<kBlockM>>{}, Stride<_1>{});
typename Kernel_traits::GmemTiledCopydO gmem_tiled_copy_dO;
auto gmem_thr_copy_dO = gmem_tiled_copy_dO.get_thread_slice(tidx);
// TODO: careful, we're zeroing out dQaccum with type float4, but when
// we do atomicAdds, we use type float. The layouts are different. Check this.
typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dQaccum;
auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx);
Tensor tdOgdO = gmem_thr_copy_dO.partition_S(gdO);
Tensor tdOgO = gmem_thr_copy_dO.partition_S(gO);
Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum);
Tensor cdO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor tdOcdO = gmem_thr_copy_dO.partition_S(cdO);
// Allocate predicate tensors for k
Tensor tdOpdO = make_tensor<bool>(make_shape(size<2>(tdOgdO)));
// Set predicates for k bounds
#pragma unroll
for (int k = 0; k < size(tdOpdO); ++k) {tdOpdO(k) = get<1>(tdOcdO(0, 0, k)) < params.d;}
Tensor tdOrdO = make_fragment_like(tdOgdO);
Tensor tdOrO = make_fragment_like(tdOgO);
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true>(
gmem_tiled_copy_dO, tdOgdO, tdOrdO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM
);
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true>(
gmem_tiled_copy_dO, tdOgO, tdOrO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM
);
// Strictly speaking, we need to scale dP up by 1/p_dropout, but instead we don't and only scale the final
// results (dQ and dK) by 1/p_dropout. So we need to keep dP_sum scaled down by p_dropout here,
// so that (dP - dP_sum) is on the same scale.
dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, dP_sum,
// Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);
Kernel_traits::kNThreadsNonWS / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);
if (Clear_dQaccum) {
// We're actually not zeroing out all of dQaccum, but only the part that we're going to
// do atomicAdds on.
Tensor zero = make_fragment_like(tdQgdQaccum);
clear(zero);
cute::copy(gmem_tiled_copy_dQaccum, zero, tdQgdQaccum);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, typename Params>
inline __device__ void clear_dKVaccum(const Params &params) {
using ElementAccum = typename Kernel_traits::ElementAccum;
using index_t = typename Kernel_traits::index_t;
const int n_block = blockIdx.x;
// The block index for the batch.
const int bidb = blockIdx.y;
// The block index for the head.
const int bidh = blockIdx.z;
// The thread index.
const int tidx = threadIdx.x;
constexpr int kBlockN = Kernel_traits::kBlockN;
constexpr int kHeadDim = Kernel_traits::kHeadDim;
const BlockInfo binfo(params, bidb);
if (n_block * kBlockN >= binfo.actual_seqlen_k) return;
const index_t row_offset_dkv_accum = ((bidb * params.h_k + bidh) * params.seqlen_k_rounded + n_block * kBlockN) * params.d_rounded;
Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dk_accum_ptr) + row_offset_dkv_accum),
Shape<Int<kBlockN>, Int<kHeadDim>>{}, Stride<Int<kHeadDim>, _1>{});
Tensor gdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dv_accum_ptr) + row_offset_dkv_accum),
Shape<Int<kBlockN>, Int<kHeadDim>>{}, Stride<Int<kHeadDim>, _1>{});
typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dKVaccum;
auto gmem_thr_copy_dKVaccum = gmem_tiled_copy_dKVaccum.get_thread_slice(tidx);
Tensor tdKgdKaccum = gmem_thr_copy_dKVaccum.partition_D(gdKaccum);
Tensor tdVgdVaccum = gmem_thr_copy_dKVaccum.partition_D(gdVaccum);
Tensor zero = make_fragment_like(tdKgdKaccum);
clear(zero);
cute::copy(gmem_tiled_copy_dKVaccum, zero, tdKgdKaccum);
cute::copy(gmem_tiled_copy_dKVaccum, zero, tdVgdVaccum);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Convert dQ from dQaccum (in float) to fp16/bf16.
// This is used in the case where we want to parallelize the backward across seqlen_k.
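// In that case each n_block accumulates its contribution to dQ in fp32 via atomicAdds into
// dq_accum_ptr; this kernel loads the accumulated tile through TMA into shared memory, rescales it
// by scale_softmax_rp_dropout, converts to the output element type, and writes the final dQ tile.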
// template<typename Kernel_traits, typename Params, typename TiledCopydQaccum>
template<typename Kernel_traits, typename TiledCopydQaccum>
// inline __device__ void convert_dQ(const Params &params,
__global__ void convert_dQ(CUTE_GRID_CONSTANT Flash_bwd_params const params,
CUTE_GRID_CONSTANT TiledCopydQaccum const tma_load_dQaccum) {
using Element = typename Kernel_traits::Element;
using ElementAccum = typename Kernel_traits::ElementAccum;
using index_t = typename Kernel_traits::index_t;
// Shared memory.
extern __shared__ char smem_[];
int lane_predicate = cute::elect_one_sync();
int warp_idx = cutlass::canonical_warp_idx_sync();
// Issue Tma Descriptor Prefetch from a single thread
if (warp_idx == 0 && lane_predicate) {
cute::prefetch_tma_descriptor(tma_load_dQaccum.get_tma_descriptor());
}
const int m_block = blockIdx.x;
// The block index for the batch.
const int bidb = blockIdx.y;
// The block index for the head.
const int bidh = blockIdx.z;
// The thread index.
const int tidx = threadIdx.x;
constexpr int kBlockM = Kernel_traits::kBlockM;
constexpr int kHeadDim = Kernel_traits::kHeadDim;
static constexpr bool dQ_swapAB = Kernel_traits::dQ_swapAB;
Tensor mdQaccum = tma_load_dQaccum.get_tma_tensor(make_shape(params.seqlen_q, params.d, params.h, params.b));
Tensor gdQaccum = local_tile(mdQaccum(_, _, bidh, bidb), Shape<Int<kBlockM>, Int<kHeadDim>>{}, make_coord(m_block, _0{})); // (M, K)
const BlockInfo binfo(params, bidb);
if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb)
+ m_block * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride;
const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb)
+ (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded;
Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dq_ptr) + row_offset_dq),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.dq_row_stride, _1{}));
// Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dq_accum_ptr) + row_offset_dq_accum),
// Shape<Int<kBlockM>, Int<kHeadDim>>{},
// make_stride(params.h * params.d_rounded, _1{}));
Tensor sdQTMA = make_tensor(make_smem_ptr(reinterpret_cast<ElementAccum *>(smem_)),
typename Kernel_traits::SmemLayoutdQaccTMA{});
Tensor sdQaccum = make_tensor(make_smem_ptr(reinterpret_cast<ElementAccum *>(smem_)),
typename Kernel_traits::SmemLayoutdQacc{});
Tensor sdQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
typename Kernel_traits::SmemLayoutdQ{});
Tensor sdQt = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
typename Kernel_traits::SmemLayoutdQt{});
auto &barrier_dQaccum = *reinterpret_cast<cutlass::arch::ClusterTransactionBarrier*>(smem_ + sizeof(ElementAccum) * size(sdQTMA));
auto block_tma_dQ = tma_load_dQaccum.get_slice(_0{});
Tensor tdQgdQaccumTMA = block_tma_dQ.partition_S(gdQaccum); // (TMA, TMA_M, TMA_K)
Tensor tdQsdQaccumTMA = block_tma_dQ.partition_D(sdQTMA); // (TMA, TMA_M, TMA_K)
typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dQ;
auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(tidx);
// typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd gmem_tiled_copy_dQaccum;
// typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dQaccum;
// auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx);
typename Kernel_traits::TiledMmadQ tiled_mma_dq;
auto smem_tiled_copy_dQ = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdQ{}, tiled_mma_dq);
auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(tidx);
Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ); // ((Atom,AtomNum),ATOM_M,ATOM_N)
Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ);
// Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum);
constexpr uint32_t TmaTransactionBytesdQaccum = static_cast<uint32_t>(size<0>(sdQTMA) * size<1>(sdQTMA) * cutlass::sizeof_bits_v<ElementAccum> / 8);
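// TmaTransactionBytesdQaccum above is the byte count of the full (kBlockM x kHeadDim) fp32 tile
// that the single TMA load below is expected to deliver.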
if (warp_idx == 0 && lane_predicate) {
barrier_dQaccum.init(1 /*numThreads*/);
}
__syncthreads();
if (warp_idx == 0 && lane_predicate) {
barrier_dQaccum.arrive_and_expect_tx(TmaTransactionBytesdQaccum);
copy(tma_load_dQaccum.with(reinterpret_cast<cutlass::arch::ClusterTransactionBarrier::ValueType&>(barrier_dQaccum), 0 /*mcast_mask*/), tdQgdQaccumTMA, tdQsdQaccumTMA);
}
barrier_dQaccum.wait(0);
// if (cute::thread0()) { print_tensor(sdQTMA); printf("\n"); }
typename Kernel_traits::RmemTiledCopydQacc rmem_tiled_copy_dQaccum;
auto rmem_thr_copy_dQaccum = rmem_tiled_copy_dQaccum.get_thread_slice(threadIdx.x);
Tensor tdQsdQaccum = rmem_thr_copy_dQaccum.partition_S(sdQaccum);
Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape<Int<!dQ_swapAB ? kBlockM : kHeadDim>, Int<!dQ_swapAB ? kHeadDim : kBlockM>>{}); // MMA, MMA_N, MMA_K
CUTE_STATIC_ASSERT_V(size(acc_dq) == size(tdQsdQaccum));
Tensor tdQrdQaccum = rmem_thr_copy_dQaccum.retile_D(acc_dq);
cute::copy(rmem_tiled_copy_dQaccum, tdQsdQaccum, tdQrdQaccum);
// Tensor dQ_tmp = make_tensor(acc_dq.data(), flash::convert_layout_acc_rowcol(acc_dq.layout()));
// if (blockIdx.x == 0 && threadIdx.x == 0) { print_tensor(dQ_tmp); printf("\n"); }
#pragma unroll
for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax_rp_dropout; }
// Convert acc_dq from fp32 to fp16
Tensor rdQ = flash::convert_type<Element>(acc_dq);
Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ); // ((Atom,AtomNum), MMA_N, MMA_N)
// dQacc and dQ use the same shared memory, so we need to wait for all threads to finish reading smem first
__syncthreads();
if constexpr (!dQ_swapAB) {
Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ); // ((Atom,AtomNum),PIPE_M,PIPE_N)
cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ);
} else {
Tensor taccdQsdQt = smem_thr_copy_dQ.partition_D(sdQt); // ((Atom,AtomNum),PIPE_M,PIPE_N)
cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQt);
}
__syncthreads();
Tensor tdQrdQ = make_tensor<Element>(shape(tdQgdQ));
cute::copy(gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ);
Tensor cdQ = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ);
Tensor tdQpdQ = make_tensor<bool>(make_shape(size<2>(tdQgdQ)));
#pragma unroll
for (int k = 0; k < size(tdQpdQ); ++k) { tdQpdQ(k) = get<1>(tdQcdQ(0, 0, k)) < params.d; }
// Clear_OOB_K must be false since we don't want to write zeros to gmem
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_tiled_copy_dQ, tdQrdQ, tdQgdQ, tdQcdQ, tdQpdQ, binfo.actual_seqlen_q - m_block * kBlockM
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Convert dK and dV from dKaccum and dVaccum (in float) to fp16/bf16.
// This is used in the case where we want to parallelize the backward across seqlen_q.
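// Here dKaccum / dVaccum are read back with ordinary (non-TMA) gmem copies, rescaled (dK by
// scale_softmax_rp_dropout, dV by rp_dropout), converted to the output element type, staged
// through shared memory, and written out as dK / dV.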
template<typename Kernel_traits, typename Params>
inline __device__ void convert_dKV(const Params &params) {
using Element = typename Kernel_traits::Element;
using ElementAccum = typename Kernel_traits::ElementAccum;
using index_t = typename Kernel_traits::index_t;
// Shared memory.
extern __shared__ char smem_[];
const int n_block = blockIdx.x;
// The block index for the batch.
const int bidb = blockIdx.y;
// The block index for the head.
const int bidh = blockIdx.z;
// The thread index.
const int tidx = threadIdx.x;
constexpr int kBlockN = Kernel_traits::kBlockN;
constexpr int kHeadDim = Kernel_traits::kHeadDim;
static constexpr bool dKV_swapAB = Kernel_traits::dKV_swapAB;
const BlockInfo binfo(params, bidb);
if (n_block * kBlockN >= binfo.actual_seqlen_k) return;
const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb)
+ n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride;
const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb)
+ n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride;
const index_t row_offset_dkv_accum = ((bidb * params.h_k + bidh) * params.seqlen_k_rounded
+ n_block * kBlockN) * params.d_rounded;
Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dk_ptr) + row_offset_dk),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.dk_row_stride, _1{}));
Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dv_ptr) + row_offset_dv),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.dv_row_stride, _1{}));
Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dk_accum_ptr) + row_offset_dkv_accum),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
Stride<Int<kHeadDim>, _1>{});
Tensor gdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dv_accum_ptr) + row_offset_dkv_accum),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
Stride<Int<kHeadDim>, _1>{});
Tensor sdK = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
typename Kernel_traits::SmemLayoutdKV{});
Tensor sdKt = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
typename Kernel_traits::SmemLayoutdKVt{});
Tensor sdV = make_tensor(sdK.data() + size(sdK),
typename Kernel_traits::SmemLayoutdKV{}); // (SMEM_N, SMEM_K)
Tensor sdVt = make_tensor(make_smem_ptr(sdK.data() + size(sdK)),
typename Kernel_traits::SmemLayoutdKVt{});
typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dKV;
auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx);
// typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd gmem_tiled_copy_dKVaccum;
typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dKVaccum;
auto gmem_thr_copy_dKVaccum = gmem_tiled_copy_dKVaccum.get_thread_slice(tidx);
typename Kernel_traits::TiledMmadKV tiled_mma_dkv;
auto smem_tiled_copy_dKV = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdKV{}, tiled_mma_dkv);
auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(tidx);
Tensor tdKsdK = gmem_thr_copy_dKV.partition_S(sdK); // ((Atom,AtomNum),ATOM_M,ATOM_N)
Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK);
Tensor tdVsdV = gmem_thr_copy_dKV.partition_S(sdV); // ((Atom,AtomNum),ATOM_M,ATOM_N)
Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV);
Tensor tdKgdKaccum = gmem_thr_copy_dKVaccum.partition_S(gdKaccum);
Tensor tdVgdVaccum = gmem_thr_copy_dKVaccum.partition_S(gdVaccum);
Tensor acc_dk = partition_fragment_C(tiled_mma_dkv, Shape<Int<!dKV_swapAB ? kBlockN : kHeadDim>, Int<!dKV_swapAB ? kHeadDim : kBlockN>>{}); // MMA, MMA_N, MMA_K
Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape<Int<!dKV_swapAB ? kBlockN : kHeadDim>, Int<!dKV_swapAB ? kHeadDim : kBlockN>>{}); // MMA, MMA_N, MMA_K
CUTE_STATIC_ASSERT_V(size(acc_dk) == size(tdKgdKaccum));
CUTE_STATIC_ASSERT_V(size(acc_dv) == size(tdVgdVaccum));
Tensor tdKrdKaccum = make_fragment_like(tdKgdKaccum);
Tensor tdVrdVaccum = make_fragment_like(tdVgdVaccum);
cute::copy(gmem_tiled_copy_dKVaccum, tdKgdKaccum, tdKrdKaccum);
cute::copy(gmem_tiled_copy_dKVaccum, tdVgdVaccum, tdVrdVaccum);
#pragma unroll
for (int i = 0; i < size(acc_dk); ++i) {
acc_dk(i) = tdKrdKaccum(i) * params.scale_softmax_rp_dropout;
}
#pragma unroll
for (int i = 0; i < size(acc_dv); ++i) {
acc_dv(i) = tdVrdVaccum(i) * params.rp_dropout;
}
// Convert acc_dk from fp32 to fp16
Tensor rdK = flash::convert_type<Element>(acc_dk);
Tensor rdV = flash::convert_type<Element>(acc_dv);
Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(rdK); // ((Atom,AtomNum), MMA_N, MMA_N)
Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(rdV); // ((Atom,AtomNum), MMA_N, MMA_N)
if constexpr (!dKV_swapAB) {
Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(sdK); // ((Atom,AtomNum),PIPE_M,PIPE_N)
Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(sdV); // ((Atom,AtomNum),PIPE_M,PIPE_N)
cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK);
cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV);
} else {
Tensor taccdKsdKt = smem_thr_copy_dKV.partition_D(sdKt); // ((Atom,AtomNum),PIPE_M,PIPE_N)
Tensor taccdVsdVt = smem_thr_copy_dKV.partition_D(sdVt); // ((Atom,AtomNum),PIPE_M,PIPE_N)
cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdKt);
cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdVt);
}
__syncthreads();
Tensor tdKrdK = make_tensor<Element>(shape(tdKgdK));
Tensor tdVrdV = make_tensor<Element>(shape(tdVgdV));
cute::copy(gmem_tiled_copy_dKV, tdKsdK, tdKrdK);
cute::copy(gmem_tiled_copy_dKV, tdVsdV, tdVrdV);
// if (cute::thread0()) { print_tensor(tdKrdK); printf("\n"); }
// if (cute::thread0()) { print_tensor(tdVrdV); printf("\n"); }
Tensor cdKV = make_identity_tensor(Shape<Int<kBlockN>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
Tensor tdKVpdKV = make_tensor<bool>(make_shape(size<2>(tdKgdK)));
#pragma unroll
for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; }
// Clear_OOB_K must be false since we don't want to write zeros to gmem
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
);
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
);
}
} // namespace flash
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_fwd_launch_template.h"
template<>
void run_mha_fwd_<cutlass::half_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim128<cutlass::half_t>(params, stream);
}
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_fwd_launch_template.h"
template<>
void run_mha_fwd_<cutlass::half_t, 256>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim256<cutlass::half_t>(params, stream);
}
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_fwd_launch_template.h"
template<>
void run_mha_fwd_<cutlass::half_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim64<cutlass::half_t>(params, stream);
}
/******************************************************************************
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
******************************************************************************/
#pragma once
#include "cute/tensor.hpp"
#include <cutlass/cutlass.h>
#include <cutlass/arch/reg_reconfig.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
#include <cutlass/numeric_conversion.h>
#include "cutlass/pipeline/pipeline.hpp"
#include "flash.h"
#include "utils.h"
#include "softmax.h"
#include "tile_scheduler.hpp"
#include "mainloop_fwd_sm90_tma_gmma_ws.hpp"
#include "epilogue_fwd_sm90_tma.hpp"
namespace flash {
using namespace cute;
template <typename Ktraits, bool Is_causal, typename TileScheduler>
__global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1)
compute_attn_ws(CUTE_GRID_CONSTANT Flash_fwd_params const params,
CUTE_GRID_CONSTANT typename CollectiveMainloopFwd<Ktraits, Is_causal>::Params const mainloop_params,
CUTE_GRID_CONSTANT typename CollectiveEpilogueFwd<Ktraits>::Params const epilogue_params,
CUTE_GRID_CONSTANT typename TileScheduler::Params const scheduler_params
) {
using Element = typename Ktraits::Element;
using ElementAccum = typename Ktraits::ElementAccum;
using SoftType = ElementAccum;
using TileShape_MNK = typename Ktraits::TileShape_MNK;
using ClusterShape = typename Ktraits::ClusterShape_MNK;
static_assert(Ktraits::Is_WS);
static constexpr bool Is_WS = Ktraits::Is_WS;
static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma0{});
static constexpr int NumCopyThreads = !Is_WS ? 0 : cutlass::NumThreadsPerWarpGroup;
static constexpr int kBlockM = Ktraits::kBlockM;
// static constexpr int kBlockN = Ktraits::kBlockN;
// constexpr int kHeadDim = Ktraits::kHeadDim;
using CollectiveMainloop = CollectiveMainloopFwd<Ktraits, Is_causal>;
using CollectiveEpilogue = CollectiveEpilogueFwd<Ktraits>;
using MainloopPipeline = typename Ktraits::MainloopPipeline;
using PipelineParams = typename MainloopPipeline::Params;
using PipelineState = typename MainloopPipeline::PipelineState;
extern __shared__ char shared_memory[];
auto &shared_storage = *reinterpret_cast<typename Ktraits::SharedStorage*>(shared_memory);
int const lane_predicate = cute::elect_one_sync();
int const warp_idx = cutlass::canonical_warp_idx_sync();
// Issue Tma Descriptor Prefetch from a single thread
if (warp_idx == 0 && lane_predicate) {
CollectiveMainloop::prefetch_tma_descriptors(mainloop_params);
CollectiveEpilogue::prefetch_tma_descriptors(epilogue_params);
}
// Thread index within the warp group
int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
PipelineParams pipeline_params;
pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK;
int warp_group_idx = cutlass::canonical_warp_group_idx();
pipeline_params.role = warp_group_idx == 0
? MainloopPipeline::ThreadCategory::Producer
: MainloopPipeline::ThreadCategory::Consumer;
pipeline_params.is_leader = warp_group_thread_idx == 0;
pipeline_params.num_consumers = NumMmaThreads;
if (warp_idx == 0 && lane_predicate) {
shared_storage.barrier_Q.init(1 /*numThreads*/);
shared_storage.barrier_O.init(size(ClusterShape{}) /*numThreads*/);
}
// We're counting on pipeline_k to call cutlass::arch::fence_barrier_init();
MainloopPipeline pipeline_k(shared_storage.pipeline_k, pipeline_params, ClusterShape{});
MainloopPipeline pipeline_v(shared_storage.pipeline_v, pipeline_params, ClusterShape{});
CollectiveMainloop collective_mainloop;
CollectiveEpilogue collective_epilogue;
// We need this to guarantee that the Pipeline init is visible to all producer and consumer blocks in the cluster
if constexpr (size(ClusterShape{}) > 1) {
cute::cluster_arrive_relaxed();
cute::cluster_wait();
} else {
__syncthreads();
}
static_assert(Ktraits::kNWarps == 12 || Ktraits::kNWarps == 16);
if (warp_group_idx == 0) { // Producer
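// The producer warp group only issues TMA loads, so it gives back most of its registers here;
// the consumer warp groups claim a larger register budget below via warpgroup_reg_alloc.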
cutlass::arch::warpgroup_reg_dealloc<Ktraits::kNWarps == 12 ? 24 : 32>();
// cutlass::arch::warpgroup_reg_dealloc<56>();
// StaticPersistentTileScheduler scheduler{params.m_block_divmod, params.head_divmod, params.total_blocks};
// auto work_tile_info = scheduler.get_current_work();
TileScheduler scheduler;
int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
if (warp_idx_in_warpgroup == 0) { // Load Q, K, V
PipelineState smem_pipe_write_k = cutlass::make_producer_start_state<MainloopPipeline>();
PipelineState smem_pipe_write_v = cutlass::make_producer_start_state<MainloopPipeline>();
int work_idx = 0;
// auto get_tile_count = [&] () {
// cutlass::arch::NamedBarrier::sync(NumMmaThreads + 2 * cutlass::NumThreadsPerWarp, 10 /*id*/);
// return shared_storage.tile_count_semaphore;
// };
// while (work_tile_info.is_valid()) {
// for (int tile_count = blockIdx.x; tile_count < params.total_blocks; tile_count = get_tile_count()) {
// for (int tile_count_semaphore = blockIdx.x; tile_count_semaphore < params.total_blocks; tile_count_semaphore = __shfl_sync(0xffffffff, tile_count_semaphore, 0)) {
for (auto work_tile_info = scheduler.get_initial_work(); work_tile_info.is_valid(scheduler_params); work_tile_info = scheduler.get_next_work(scheduler_params, work_tile_info)) {
int tile_count_semaphore = 0;
collective_mainloop.load(params, mainloop_params, scheduler_params, pipeline_k, pipeline_v, smem_pipe_write_k, smem_pipe_write_v,
shared_storage, work_tile_info, work_idx, tile_count_semaphore);
// ++work_idx;
// work_tile_info = scheduler.fetch_next_work();
}
collective_mainloop.load_tail(pipeline_k, pipeline_v, smem_pipe_write_k, smem_pipe_write_v);
}
} else { // Consumer
cutlass::arch::warpgroup_reg_alloc<Ktraits::kNWarps == 12 ? 240 : 160>();
// cutlass::arch::warpgroup_reg_alloc<Ktraits::kNWarps == 12 ? 224 : 160>();
// Initialize matmul objects.
typename Ktraits::TiledMma1 tiled_mma1;
TileScheduler scheduler{};
PipelineState smem_pipe_read_k, smem_pipe_read_v;
// We don't need separate variables smem_pipe_release_k and smem_pipe_release_v
// (like in Cutlass's gemm) because the read and release pipeline states are always the same.
auto get_tile_count = [&] () {
// cutlass::arch::NamedBarrier::sync(NumMmaThreads + 2 * cutlass::NumThreadsPerWarp, 10 /*id*/);
cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, 10 /*id*/);
return shared_storage.tile_count_semaphore;
};
collective_mainloop.mma_init();
int work_idx = 0;
CUTLASS_PRAGMA_NO_UNROLL
// for (int work_idx = 0; work_idx * gridDim.x + blockIdx.x < params.total_blocks; ++work_idx) {
// for (int tile_count_semaphore = blockIdx.x, work_idx = 0; tile_count_semaphore < params.total_blocks; tile_count_semaphore = get_tile_count()) {
for (auto work_tile_info = scheduler.get_initial_work(); work_tile_info.is_valid(scheduler_params); work_tile_info = scheduler.get_next_work(scheduler_params, work_tile_info)) {
// Attention output (GEMM-II) accumulator.
Tensor tOrO = partition_fragment_C(tiled_mma1, select<0, 2>(TileShape_MNK{}));
flash::Softmax<2 * (2 * kBlockM / NumMmaThreads)> softmax;
// int m_block;
// int bidh, bidb;
// // bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(m_block, work_idx * gridDim.x + blockIdx.x));
// bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(m_block, tile_count_semaphore));
// cute::tuple<int32_t, int32_t, int32_t> block_coord = {m_block, bidh, bidb};
auto block_coord = work_tile_info.get_block_coord(scheduler_params);
auto [m_block, bidh, bidb] = block_coord;
int n_block_max = collective_mainloop.get_n_block_max(mainloop_params, m_block);
if (Is_causal && n_block_max <= 0) { // We exit early and write 0 to gO and -inf to gLSE.
// Need sync to avoid the case where the producer issues 2 arrives before the consumer can issue 1 wait
cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, 7 /*id*/);
collective_epilogue.store_zero(epilogue_params, threadIdx.x - NumCopyThreads, block_coord);
continue;
}
collective_mainloop.mma(mainloop_params, pipeline_k, pipeline_v, smem_pipe_read_k, smem_pipe_read_v,
tOrO, softmax, n_block_max, threadIdx.x - NumCopyThreads, work_idx, m_block, shared_storage);
// tOrO, softmax, n_block_max, threadIdx.x - NumCopyThreads + (work_idx >> 30), work_idx, shared_storage);
// tOrO, softmax, n_block_max, threadIdx.x - NumCopyThreads, 0, shared_storage);
collective_epilogue.store(epilogue_params, tOrO, softmax.row_sum, shared_storage, tiled_mma1,
threadIdx.x - NumCopyThreads, block_coord);
++work_idx;
// work_tile_info = scheduler.fetch_next_work();
}
collective_epilogue.store_tail();
}
}
} // namespace flash
/******************************************************************************
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
******************************************************************************/
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include "cute/tensor.hpp"
#include "cutlass/cluster_launch.hpp"
#include "static_switch.h"
#include "flash.h"
#include "tile_scheduler.hpp"
#include "flash_fwd_kernel.h"
#include "kernel_traits.h"
template<typename Kernel_traits, bool Is_causal>
void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
using Element = typename Kernel_traits::Element;
using TileShape_MNK = typename Kernel_traits::TileShape_MNK;
using ClusterShape = typename Kernel_traits::ClusterShape_MNK;
// print(typename Kernel_traits::SmemLayoutVt{}); printf("\n"); print(typename Kernel_traits::SmemLayoutVt_tmp{});
using CollectiveMainloop = flash::CollectiveMainloopFwd<Kernel_traits, Is_causal>;
using CollectiveEpilogue = flash::CollectiveEpilogueFwd<Kernel_traits>;
// using Scheduler = flash::SingleTileScheduler;
using Scheduler = flash::StaticPersistentTileScheduler;
typename CollectiveMainloop::Params mainloop_params =
CollectiveMainloop::to_underlying_arguments({
static_cast<Element const*>(params.q_ptr),
{params.seqlen_q, params.d, params.h, params.b}, // shape_Q
{params.q_row_stride, _1{}, params.q_head_stride, params.q_batch_stride}, // stride_Q
static_cast<Element const*>(params.k_ptr),
{params.seqlen_k, params.d, params.h_k, params.b}, // shape_K
{params.k_row_stride, _1{}, params.k_head_stride, params.k_batch_stride}, // stride_K
static_cast<Element const*>(params.v_ptr),
{params.v_row_stride, _1{}, params.v_head_stride, params.v_batch_stride}, // stride_V
params.scale_softmax_log2
});
typename CollectiveEpilogue::Params epilogue_params =
CollectiveEpilogue::to_underlying_arguments({
static_cast<Element*>(params.o_ptr),
{params.seqlen_q, params.d, params.h, params.b}, // shape_O
{params.o_row_stride, _1{}, params.o_head_stride, params.o_batch_stride}, // stride_O
static_cast<float*>(params.softmax_lse_ptr),
{_1{}, params.seqlen_q, params.h * params.seqlen_q}, // stride_LSE
});
int num_blocks_m = cutlass::ceil_div(params.seqlen_q, Kernel_traits::kBlockM);
num_blocks_m = cutlass::ceil_div(num_blocks_m, size<0>(ClusterShape{})) * size<0>(ClusterShape{});
typename Scheduler::Arguments scheduler_args = {num_blocks_m, params.h, params.b};
typename Scheduler::Params scheduler_params = Scheduler::to_underlying_arguments(scheduler_args);
// Get the ptr to kernel function.
void *kernel;
kernel = (void *)flash::compute_attn_ws<Kernel_traits, Is_causal, Scheduler>;
int smem_size = sizeof(typename Kernel_traits::SharedStorage);
int smem_size_q = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_q));
int smem_size_k = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_k));
int smem_size_v = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_v));
// printf("smem_size = %d, q = %d, k = %d, v = %d\n", smem_size, smem_size_q, smem_size_k, smem_size_v);
if (smem_size >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
static constexpr int ctaSize = Kernel_traits::kNWarps * 32;
params.m_block_divmod = cutlass::FastDivmod(num_blocks_m);
params.total_blocks = num_blocks_m * params.h * params.b;
// dim3 grid_dims(num_blocks_m, params.h, params.b);
// dim3 grid_dims(132);
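// 132 is the grid size handed to the persistent scheduler; it presumably matches the SM count of
// an H100 SXM5.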
dim3 grid_dims = Scheduler::get_grid_dim(scheduler_args, 132);
dim3 block_dims(ctaSize);
dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream};
cutlass::launch_kernel_on_cluster(launch_params, kernel, params, mainloop_params, epilogue_params, scheduler_params);
// kernel<<<grid_dims, block_dims, smem_size, stream>>>(params, tma_load_Q, tma_load_K, tma_load_V, tma_store_O);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
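// The traits below are Flash_fwd_kernel_traits<kHeadDim, kBlockM, kBlockN, kNWarps, kStages,
// Is_Q_in_regs, kClusterM, Element>; the tile sizes appear to be tuned per head dimension (and,
// for hdim128/256, per causal vs. non-causal).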
template<typename T>
void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 64;
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 192, 128, 16, 2, false, 1, T>, Is_causal>(params, stream);
});
}
template<typename T>
void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 128;
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, Is_causal ? 128 : 176, 12, 2, false, !Is_causal ? 2 : 1, T>, Is_causal>(params, stream);
});
}
template<typename T>
void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 256;
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 80, 12, 2, false, !Is_causal ? 2 : 1, T>, Is_causal>(params, stream);
});
}
/******************************************************************************
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
******************************************************************************/
#pragma once
#include "cute/algorithm/copy.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/layout/layout.h"
#include "cutlass/numeric_types.h"
#include "cutlass/pipeline/pipeline.hpp"
using namespace cute;
template <int kStages, class Gemm1Type, class Gemm2Type, class OutputType, class SmemLayoutQ,
class SmemLayoutK, class SmemLayoutV, class SmemLayoutO>
struct SharedStorageQKVO {
cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutQ>> smem_q;
cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutK>> smem_k;
union {
cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutV>> smem_v;
cute::array_aligned<OutputType, cute::cosize_v<SmemLayoutO>> smem_o;
};
struct {
cutlass::arch::ClusterTransactionBarrier barrier_Q;
cutlass::arch::ClusterBarrier barrier_O;
typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;
typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;
int tile_count_semaphore;
};
};
// If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true (note: only Is_Q_in_regs_ is
// exposed as a template parameter here).
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, int kStages_, bool Is_Q_in_regs_=false,
int kClusterM_ = 1, typename elem_type=cutlass::half_t>
struct Flash_fwd_kernel_traits {
using Element = elem_type;
using ElementAccum = float;
using index_t = int64_t;
// The number of threads.
static constexpr int kNWarps = kNWarps_;
static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
static constexpr bool Is_Q_in_regs = Is_Q_in_regs_;
static_assert(kNWarps_ == 4 || kNWarps_ == 8 || kNWarps_ == 12 || kNWarps_ == 16);
static constexpr bool Is_WS = kNWarps_ >= 12;
static_assert(!(Is_WS && Is_Q_in_regs), "Warp-specialization does not support Q in registers");
static constexpr int kBlockM = kBlockM_;
static constexpr int kBlockN = kBlockN_;
static constexpr int kHeadDim = kHeadDim_;
static_assert(kHeadDim % 32 == 0);
using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
static constexpr int kClusterM = kClusterM_;
using ClusterShape_MNK = Shape<Int<kClusterM>, _1, _1>;
static constexpr int kStages = kStages_;
using AtomLayoutMNK = Layout<Shape<Int<kBlockM / 64>, _1, _1>>;
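// AtomLayoutMNK above tiles the 64-row GMMA atom kBlockM / 64 times along M (one warp group per atom).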
using TiledMma0 = decltype(cute::make_tiled_mma(
std::conditional_t<
Is_Q_in_regs,
decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShape_MNK>()),
decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShape_MNK>())
>{},
AtomLayoutMNK{}));
using TiledMma1 = decltype(cute::make_tiled_mma(
cute::GMMA::rs_op_selector<Element, Element, ElementAccum, decltype(select<0, 2, 1>(TileShape_MNK{})),
GMMA::Major::K, GMMA::Major::MN>(),
AtomLayoutMNK{}));
using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{})));
using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutK =
decltype(tile_to_shape(SmemLayoutAtomK{},
make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutV =
decltype(tile_to_shape(SmemLayoutAtomV{},
make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{})));
using SmemCopyAtomQ = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;
using SharedStorage = SharedStorageQKVO<kStages, Element, Element, Element, SmemLayoutQ,
SmemLayoutK, SmemLayoutV, SmemLayoutO>;
using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
using PipelineState = typename cutlass::PipelineState<kStages>;
// using BarrierType = typename MainloopPipeline::ProducerBarrierType;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <bool Has_P_smem, int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS,
class SmemLayoutdK, class SmemLayoutdV>
struct SharedStorageQKVdOdKV;
template <int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS,
class SmemLayoutdK, class SmemLayoutdV>
struct SharedStorageQKVdOdKV<true, kStages, Element, OutputType, SmemLayoutQ, SmemLayoutdO,
SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdK, SmemLayoutdV> {
struct {
cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do;
union {
struct {
cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
};
struct {
cute::array_aligned<Element, cute::cosize_v<SmemLayoutdK>> smem_dk;
cute::array_aligned<Element, cute::cosize_v<SmemLayoutdV>> smem_dv;
};
};
cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
cute::array_aligned<Element, cute::cosize_v<SmemLayoutdS>> smem_ds;
};
struct {
cute::uint64_t tma_load_mbar[8]; // 8 TMA barriers pre-allocated for use.
cutlass::arch::ClusterTransactionBarrier barrier_K;
cutlass::arch::ClusterTransactionBarrier barrier_V;
typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_q;
typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_do;
};
};
template <int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS,
class SmemLayoutdK, class SmemLayoutdV>
struct SharedStorageQKVdOdKV<false, kStages, Element, OutputType, SmemLayoutQ, SmemLayoutdO,
SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdK, SmemLayoutdV> {
struct {
cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do;
union {
struct {
cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
};
struct {
cute::array_aligned<Element, cute::cosize_v<SmemLayoutdK>> smem_dk;
cute::array_aligned<Element, cute::cosize_v<SmemLayoutdV>> smem_dv;
};
};
union { // Put smem_p in a union just so we can still refer to it in the struct, even if it's not used.
cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
cute::array_aligned<Element, cute::cosize_v<SmemLayoutdS>> smem_ds;
};
};
struct {
cute::uint64_t tma_load_mbar[8]; // 8 TMA barriers pre-allocated for use.
cutlass::arch::ClusterTransactionBarrier barrier_K;
cutlass::arch::ClusterTransactionBarrier barrier_V;
typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_q;
typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_do;
};
};
template <bool Has_P_smem, int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS, class SmemLayoutdQacc,
class SmemLayoutdK, class SmemLayoutdV>
struct SharedStorageQKVdOdKVWS;
template <int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS, class SmemLayoutdQacc,
class SmemLayoutdK, class SmemLayoutdV>
struct SharedStorageQKVdOdKVWS<true, kStages, Element, OutputType, SmemLayoutQ, SmemLayoutdO,
SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdQacc, SmemLayoutdK, SmemLayoutdV> {
struct {
cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do;
union {
struct {
cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
};
struct {
cute::array_aligned<Element, cute::cosize_v<SmemLayoutdK>> smem_dk;
cute::array_aligned<Element, cute::cosize_v<SmemLayoutdV>> smem_dv;
};
};
cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
cute::array_aligned<Element, cute::cosize_v<SmemLayoutdS>> smem_ds;
cute::array_aligned<float, cute::cosize_v<SmemLayoutdQacc>> smem_dqacc;
cute::array_aligned<float, 128> smem_lse;
cute::array_aligned<float, 128> smem_dpsum;
};
struct {
cute::uint64_t tma_load_mbar[8]; // 8 TMA barriers pre-allocated for use.
cutlass::arch::ClusterTransactionBarrier barrier_K;
cutlass::arch::ClusterTransactionBarrier barrier_V;
typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_q;
typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_do;
};
};
template <int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS, class SmemLayoutdQacc,
class SmemLayoutdK, class SmemLayoutdV>
struct SharedStorageQKVdOdKVWS<false, kStages, Element, OutputType, SmemLayoutQ, SmemLayoutdO,
SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdQacc, SmemLayoutdK, SmemLayoutdV> {
struct {
cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do;
union {
struct {
cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
};
struct {
cute::array_aligned<Element, cute::cosize_v<SmemLayoutdK>> smem_dk;
cute::array_aligned<Element, cute::cosize_v<SmemLayoutdV>> smem_dv;
};
};
union { // Put smem_p in a union just so we can still refer to it in the struct, even if it's not used.
cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
cute::array_aligned<Element, cute::cosize_v<SmemLayoutdS>> smem_ds;
};
cute::array_aligned<float, cute::cosize_v<SmemLayoutdQacc>> smem_dqacc;
cute::array_aligned<float, 128> smem_lse;
cute::array_aligned<float, 128> smem_dpsum;
};
struct {
cute::uint64_t tma_load_mbar[8]; // 8 TMA barriers pre-allocated for use.
cutlass::arch::ClusterTransactionBarrier barrier_K;
cutlass::arch::ClusterTransactionBarrier barrier_V;
typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_q;
typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_do;
};
};
template <bool Has_P_smem, int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS,
class SmemLayoutdQ>
struct SharedStorageQKVdOdKVSeqqPar;
template <int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS,
class SmemLayoutdQ>
struct SharedStorageQKVdOdKVSeqqPar<true, kStages, Element, OutputType, SmemLayoutQ, SmemLayoutdO,
SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdQ> {
struct {
cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
union {
struct {
cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do;
};
struct {
cute::array_aligned<Element, cute::cosize_v<SmemLayoutdQ>> smem_dq;
};
};
cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
cute::array_aligned<Element, cute::cosize_v<SmemLayoutdS>> smem_ds;
};
struct {
cute::uint64_t tma_load_mbar[8]; // 8 TMA barriers pre-allocated for use.
cutlass::arch::ClusterTransactionBarrier barrier_Q;
cutlass::arch::ClusterTransactionBarrier barrier_dO;
typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;
typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;
};
};
template <int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS,
class SmemLayoutdQ>
struct SharedStorageQKVdOdKVSeqqPar<false, kStages, Element, OutputType, SmemLayoutQ, SmemLayoutdO,
SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdQ> {
struct {
cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
union {
struct {
cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do;
};
struct {
cute::array_aligned<Element, cute::cosize_v<SmemLayoutdQ>> smem_dq;
};
};
union { // Put smem_p in a union just so we can still refer to it in the struct, even if it's not used.
cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
cute::array_aligned<Element, cute::cosize_v<SmemLayoutdS>> smem_ds;
};
};
struct {
cute::uint64_t tma_load_mbar[8]; // 8 TMA barriers pre-allocated for use.
cutlass::arch::ClusterTransactionBarrier barrier_Q;
cutlass::arch::ClusterTransactionBarrier barrier_dO;
typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;
typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_,
bool SdP_swapAB_, bool dKV_swapAB_, bool dQ_swapAB_,
int AtomLayoutMSdP=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=1,
int kClusterN_ = 1, typename elem_type=cutlass::half_t>
struct Flash_bwd_kernel_traits {
using Element = elem_type;
using ElementAccum = float;
using index_t = int64_t;
// The number of threads.
static constexpr int kNWarps = kNWarps_;
static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
static constexpr int kNThreadsNonWS = 8 * cutlass::NumThreadsPerWarp;
// static constexpr int kNThreadsdQ = cutlass::NumThreadsPerWarpGroup;
static constexpr int kNThreadsdQ = 2 * cutlass::NumThreadsPerWarpGroup;
static_assert(kNWarps_ == 8 || kNWarps_ == 12);
static constexpr bool Is_WS = kNWarps_ >= 12;
static constexpr int kBlockM = kBlockM_;
static constexpr int kBlockN = kBlockN_;
static constexpr int kHeadDim = kHeadDim_;
static_assert(kHeadDim % 32 == 0);
using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
static constexpr int kClusterN = kClusterN_;
using ClusterShape_MNK = Shape<_1, Int<kClusterN>, _1>;
static constexpr int kStages = 2;
static constexpr bool SdP_swapAB = SdP_swapAB_;
static constexpr bool dKV_swapAB = dKV_swapAB_;
static constexpr bool dQ_swapAB = dQ_swapAB_;
static_assert(!(SdP_swapAB && dKV_swapAB)); // If SdP_swapAB, then we don't swap for dKV
static constexpr bool Mma_dQ_is_RS = AtomLayoutMSdP == 2 && AtomLayoutMdQ == 2 && !SdP_swapAB && !dQ_swapAB; // If dQ_swapAB we can't use RS
using TileShapeAtomSdP = std::conditional_t<
!SdP_swapAB,
Shape<Int<kBlockM>, Int<kBlockN / (2 / AtomLayoutMSdP)>, Int<kHeadDim>>,
Shape<Int<kBlockN / (2 / AtomLayoutMSdP)>, Int<kBlockM>, Int<kHeadDim>>
>;
using AtomLayoutSdP = std::conditional_t<
!SdP_swapAB,
Layout<Shape<Int<AtomLayoutMSdP>, Int<2 / AtomLayoutMSdP>, _1>>,
Layout<Shape<Int<2 / AtomLayoutMSdP>, Int<AtomLayoutMSdP>, _1>>
>;
using TiledMmaSdP = decltype(cute::make_tiled_mma(
cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomSdP>(),
AtomLayoutSdP{}));
using TileShapeAtomdKV = std::conditional_t<
!dKV_swapAB,
Shape<Int<kBlockN>, Int<kHeadDim / (2 / AtomLayoutNdKV)>, Int<kBlockM>>,
Shape<Int<kHeadDim / (2 / AtomLayoutNdKV)>, Int<kBlockN>, Int<kBlockM>>
>;
using AtomLayoutdKV = std::conditional_t<
!dKV_swapAB,
Layout<Shape<Int<AtomLayoutNdKV>, Int<2 / AtomLayoutNdKV>, _1>>,
Layout<Shape<Int<2 / AtomLayoutNdKV>, Int<AtomLayoutNdKV>, _1>>
>;
using TiledMmadKV = decltype(cute::make_tiled_mma(
std::conditional_t<
!SdP_swapAB,
decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdKV, GMMA::Major::MN, GMMA::Major::MN>()),
decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShapeAtomdKV, GMMA::Major::K, GMMA::Major::MN>())
>{},
AtomLayoutdKV{}));
using TileShapeAtomdQ = std::conditional_t<
!dQ_swapAB,
Shape<Int<kBlockM>, Int<kHeadDim / (2 / AtomLayoutMdQ)>, Int<kBlockN>>,
Shape<Int<kHeadDim / (2 / AtomLayoutMdQ)>, Int<kBlockM>, Int<kBlockN>>
// Shape<Int<kBlockM>, Int<kHeadDim >, Int<kBlockN>>,
// Shape<Int<kHeadDim>, Int<kBlockM>, Int<kBlockN>>
>;
using AtomLayoutdQ = std::conditional_t<
!dQ_swapAB,
Layout<Shape<Int<AtomLayoutMdQ>, Int<2 / AtomLayoutMdQ>, _1>>,
Layout<Shape<Int<2 / AtomLayoutMdQ>, Int<AtomLayoutMdQ>, _1>>
// Layout<Shape<Int<1>, Int<1>, _1>>,
// Layout<Shape<Int<1>, Int<1>, _1>>
>;
static constexpr GMMA::Major MmadQMajorA = !dQ_swapAB ? GMMA::Major::K : GMMA::Major::MN;
static constexpr GMMA::Major MmadQMajorB = !dQ_swapAB ? GMMA::Major::MN : GMMA::Major::K;
using TiledMmadQ = decltype(cute::make_tiled_mma(
std::conditional_t<
!dQ_swapAB,
std::conditional_t<
Mma_dQ_is_RS,
decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::K, GMMA::Major::MN>()),
decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::K, GMMA::Major::MN>())
>,
decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::MN, GMMA::Major::K>())
>{},
AtomLayoutdQ{}));
using GmemTiledCopyQdO = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
using GmemTiledCopyKV = cute::SM90_TMA_LOAD;
using GmemTiledCopydKV = cute::SM90_TMA_STORE;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
static constexpr bool Has_cp_async = true;
#else
static constexpr bool Has_cp_async = false;
#endif
// For the dot_do_o preprocessing kernel
using Gmem_copy_struct = std::conditional_t<
Has_cp_async,
SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
DefaultCopy
>;
static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
// We use kBlockKSmem instead of kHeadDim here to avoid bank conflicts, though this doesn't seem
// to affect speed in practice.
static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
static_assert(kNThreadsNonWS % kGmemThreadsPerRow == 0, "kNThreadsNonWS must be a multiple of kGmemThreadsPerRow");
using GmemLayoutAtom = Layout<Shape <Int<kNThreadsNonWS / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
Stride<Int<kGmemThreadsPerRow>, _1>>;
using GmemLayoutAtomdQ = Layout<Shape <Int<kNThreadsdQ / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
Stride<Int<kGmemThreadsPerRow>, _1>>;
using GmemTiledCopydO = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
GmemLayoutAtom{},
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
using GmemTiledCopydQ = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
GmemLayoutAtomdQ{},
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
using GmemLayoutAtomdQaccum = std::conditional_t<
kBlockKSmem == 32,
Layout<Shape <Int<kNThreadsdQ / 8>, _8>, // Thread layout, 8 threads per row
Stride< _8, _1>>,
Layout<Shape <Int<kNThreadsdQ / 16>, _16>, // Thread layout, 16 threads per row
Stride< _16, _1>>
>;
using GmemTiledCopydQaccum = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
GmemLayoutAtomdQaccum{},
Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store
using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutQ =
decltype(tile_to_shape(SmemLayoutAtomQ{},
make_shape(shape<0>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
using SmemLayoutdO = SmemLayoutQ;
using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutK = decltype(tile_to_shape(SmemLayoutAtomK{}, select<1, 2>(TileShape_MNK{})));
using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutV = decltype(tile_to_shape(SmemLayoutAtomV{}, select<1, 2>(TileShape_MNK{})));
using SmemLayoutAtomP = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
using SmemLayoutP = decltype(tile_to_shape(SmemLayoutAtomP{}, select<0, 1>(TileShape_MNK{})));
using SmemLayoutAtomdS = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
using SmemLayoutdS = decltype(tile_to_shape(SmemLayoutAtomdS{}, select<0, 1>(TileShape_MNK{})));
// using SmemLayoutAtomdQacc = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, ElementAccum,
// decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
// using SmemLayoutdQacc = decltype(tile_to_shape(SmemLayoutAtomdQacc{}, select<0, 2>(TileShape_MNK{})));
// Note this is the transpose in terms of the view, not in terms of memory.
using SmemLayoutQt =
decltype(cute::composition(SmemLayoutQ{},
make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{}), Int<kStages>{}),
make_stride(Int<kBlockM>{}, _1{}, Int<kBlockM * kHeadDim>{}))));
using SmemLayoutdOt =
decltype(cute::composition(SmemLayoutdO{},
make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{}), Int<kStages>{}),
make_stride(Int<kBlockM>{}, _1{}, Int<kBlockM * kHeadDim>{}))));
using SmemLayoutKt =
decltype(cute::composition(SmemLayoutK{},
make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})),
make_stride(Int<kBlockN>{}, _1{}))));
using SmemLayoutPt =
decltype(cute::composition(SmemLayoutP{},
make_layout(make_shape(get<1>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
make_stride(Int<kBlockM>{}, _1{}))));
using SmemLayoutdSt =
decltype(cute::composition(SmemLayoutdS{},
make_layout(make_shape(get<1>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
make_stride(Int<kBlockM>{}, _1{}))));
// using SmemLayoutdQacct =
// decltype(cute::composition(SmemLayoutdQacc{},
// make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
// make_stride(Int<kBlockM>{}, _1{}))));
using SmemLayoutdK = SmemLayoutK;
using SmemLayoutdV = SmemLayoutV;
using SmemLayoutdKt = SmemLayoutKt;
using SmemLayoutdVt = SmemLayoutKt;
static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
using SmemLayoutAtomdQ = decltype(
// composition(Swizzle<kSwizzle, 3, 3>{},
composition(Swizzle<3, 3, 3>{},
Layout<Shape<Int<kNThreadsdQ / 32>, Int<32>>,
Stride<Int<32>, _1>>{}));
using SmemLayoutdQ = decltype(tile_to_shape(
SmemLayoutAtomdQ{},
make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));
using SmemLayoutdQt =
decltype(cute::composition(SmemLayoutdQ{},
make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
make_stride(Int<kBlockM>{}, _1{}))));
static constexpr int kSmemdQSize = size(SmemLayoutdQ{}) * sizeof(Element);
using SmemLayoutAtomdQaccTMA = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, ElementAccum,
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
using SmemLayoutdQaccTMA = decltype(tile_to_shape(SmemLayoutAtomdQaccTMA{}, select<0, 2>(TileShape_MNK{})));
using SmemLayoutdQacc = SmemLayoutdQ;
using SmemLayoutdQacct = SmemLayoutdQt;
using SmemLayoutdQacc2 = decltype(tile_to_shape(
SmemLayoutAtomdQ{},
make_shape(Int<kBlockM>{}, Int<kHeadDim>{}, _2{})));
// using SmemLayoutdQacc = decltype(tile_to_shape(SmemLayoutAtomdQacc{}, select<0, 2>(TileShape_MNK{})));
// using SmemLayoutdQacct =
// decltype(cute::composition(SmemLayoutdQacc{},
// make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
// make_stride(Int<kBlockM>{}, _1{}))));
using RmemTiledCopydQacc = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
GmemLayoutAtomdQaccum{},
Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store
// using SmemCopyAtomQ = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;
using SmemCopyAtomPdS = Copy_Atom<
std::conditional_t<!SdP_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
Element>;
using SmemCopyAtomdKV = Copy_Atom<
std::conditional_t<!dKV_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
Element>;
using SmemCopyAtomdQ = Copy_Atom<
std::conditional_t<!dQ_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
Element>;
using SharedStorage = std::conditional_t<
!Is_WS,
SharedStorageQKVdOdKV<!SdP_swapAB, kStages, Element, Element, SmemLayoutQ, SmemLayoutdO,
SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdK, SmemLayoutdV>,
SharedStorageQKVdOdKVWS<!SdP_swapAB, kStages, Element, Element, SmemLayoutQ, SmemLayoutdO,
SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdQacc, SmemLayoutdK, SmemLayoutdV>
// SmemLayoutK, SmemLayoutV, SmemLayoutdS, SmemLayoutdQacc2, SmemLayoutdK, SmemLayoutdV>
>;
// using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages * 2>;
// using PipelineState = typename cutlass::PipelineState<kStages * 2>;
using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_,
bool SdP_swapAB_, bool dKV_swapAB_, bool dQ_swapAB_,
int AtomLayoutMSdP=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=1,
int kClusterN_ = 1, typename elem_type=cutlass::half_t>
struct Flash_bwd_seqqpar_kernel_traits {
using Element = elem_type;
using ElementAccum = float;
using index_t = int64_t;
// The number of threads.
static constexpr int kNWarps = kNWarps_;
static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
static_assert(kNWarps_ == 8);
static constexpr int kBlockM = kBlockM_;
static constexpr int kBlockN = kBlockN_;
static constexpr int kHeadDim = kHeadDim_;
static_assert(kHeadDim % 32 == 0);
using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
static constexpr int kClusterN = kClusterN_;
using ClusterShape_MNK = Shape<_1, Int<kClusterN>, _1>;
static constexpr int kStages = 2;
static constexpr bool SdP_swapAB = SdP_swapAB_;
static constexpr bool dKV_swapAB = dKV_swapAB_;
static constexpr bool dQ_swapAB = dQ_swapAB_;
static_assert(!(SdP_swapAB && dKV_swapAB)); // If SdP_swapAB, then we don't swap for dKV
static constexpr bool Mma_dQ_is_RS = AtomLayoutMSdP == 2 && AtomLayoutMdQ == 2 && !SdP_swapAB && !dQ_swapAB; // If dQ_swapAB we can't use RS
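// Our reading: "RS" means the dQ WGMMA takes its A operand (dS) directly from registers instead
// of staging it through shared memory; this is only done when the SdP and dQ tiled MMAs use
// matching atom layouts and neither is swapped, so the dS accumulator layout lines up with the
// A-operand layout the dQ MMA expects.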
using TileShapeAtomSdP = std::conditional_t<
!SdP_swapAB,
Shape<Int<kBlockM>, Int<kBlockN / (2 / AtomLayoutMSdP)>, Int<kHeadDim>>,
Shape<Int<kBlockN / (2 / AtomLayoutMSdP)>, Int<kBlockM>, Int<kHeadDim>>
>;
using AtomLayoutSdP = std::conditional_t<
!SdP_swapAB,
Layout<Shape<Int<AtomLayoutMSdP>, Int<2 / AtomLayoutMSdP>, _1>>,
Layout<Shape<Int<2 / AtomLayoutMSdP>, Int<AtomLayoutMSdP>, _1>>
>;
using TiledMmaSdP = decltype(cute::make_tiled_mma(
cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomSdP>(),
AtomLayoutSdP{}));
using TileShapeAtomdKV = std::conditional_t<
!dKV_swapAB,
Shape<Int<kBlockN>, Int<kHeadDim / (2 / AtomLayoutNdKV)>, Int<kBlockM>>,
Shape<Int<kHeadDim / (2 / AtomLayoutNdKV)>, Int<kBlockN>, Int<kBlockM>>
>;
using AtomLayoutdKV = std::conditional_t<
!dKV_swapAB,
Layout<Shape<Int<AtomLayoutNdKV>, Int<2 / AtomLayoutNdKV>, _1>>,
Layout<Shape<Int<2 / AtomLayoutNdKV>, Int<AtomLayoutNdKV>, _1>>
>;
using TiledMmadKV = decltype(cute::make_tiled_mma(
std::conditional_t<
!SdP_swapAB,
decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdKV, GMMA::Major::MN, GMMA::Major::MN>()),
decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShapeAtomdKV, GMMA::Major::K, GMMA::Major::MN>())
>{},
AtomLayoutdKV{}));
using TileShapeAtomdQ = std::conditional_t<
!dQ_swapAB,
Shape<Int<kBlockM>, Int<kHeadDim / (2 / AtomLayoutMdQ)>, Int<kBlockN>>,
Shape<Int<kHeadDim / (2 / AtomLayoutMdQ)>, Int<kBlockM>, Int<kBlockN>>
>;
using AtomLayoutdQ = std::conditional_t<
!dQ_swapAB,
Layout<Shape<Int<AtomLayoutMdQ>, Int<2 / AtomLayoutMdQ>, _1>>,
Layout<Shape<Int<2 / AtomLayoutMdQ>, Int<AtomLayoutMdQ>, _1>>
>;
static constexpr GMMA::Major MmadQMajorA = !dQ_swapAB ? GMMA::Major::K : GMMA::Major::MN;
static constexpr GMMA::Major MmadQMajorB = !dQ_swapAB ? GMMA::Major::MN : GMMA::Major::K;
using TiledMmadQ = decltype(cute::make_tiled_mma(
std::conditional_t<
!dQ_swapAB,
std::conditional_t<
Mma_dQ_is_RS,
decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::K, GMMA::Major::MN>()),
decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::K, GMMA::Major::MN>())
>,
decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::MN, GMMA::Major::K>())
>{},
AtomLayoutdQ{}));
using GmemTiledCopyQdO = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
using GmemTiledCopyKV = cute::SM90_TMA_LOAD;
using GmemTiledCopydKV = cute::SM90_TMA_STORE;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
static constexpr bool Has_cp_async = true;
#else
static constexpr bool Has_cp_async = false;
#endif
// For the dot_do_o preprocessing kernel
using Gmem_copy_struct = std::conditional_t<
Has_cp_async,
SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
DefaultCopy
>;
static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
// Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem
// to affect speed in practice.
static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
using GmemLayoutAtom = Layout<Shape <Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
Stride<Int<kGmemThreadsPerRow>, _1>>;
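// Example (assuming 2-byte elements): kGmemElemsPerLoad = 16 / 2 = 8 elements per 128-bit access,
// so with kBlockKSmem = 64 we get kGmemThreadsPerRow = 8 and the 256 threads form a 32 x 8 tile.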
using GmemTiledCopydO = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
GmemLayoutAtom{},
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
using GmemTiledCopydQ = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
GmemLayoutAtom{},
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
using GmemLayoutAtomdQaccum = std::conditional_t<
kBlockKSmem == 32,
Layout<Shape <_32, _8>, // Thread layout, 8 threads per row
Stride< _8, _1>>,
Layout<Shape <_16, _16>, // Thread layout, 16 threads per row
Stride< _16, _1>>
>;
using GmemTiledCopydQaccum = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
GmemLayoutAtomdQaccum{},
Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store
using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{})));
using SmemLayoutdO = SmemLayoutQ;
using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutK = decltype(tile_to_shape(SmemLayoutAtomK{},
make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutV = decltype(tile_to_shape(SmemLayoutAtomV{},
make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
using SmemLayoutAtomP = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
using SmemLayoutP = decltype(tile_to_shape(SmemLayoutAtomP{}, select<0, 1>(TileShape_MNK{})));
using SmemLayoutAtomdS = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
using SmemLayoutdS = decltype(tile_to_shape(SmemLayoutAtomdS{}, select<0, 1>(TileShape_MNK{})));
// Note this is the transpose in terms of the view, not in terms of memory.
using SmemLayoutQt =
decltype(cute::composition(SmemLayoutQ{},
make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
make_stride(Int<kBlockM>{}, _1{}))));
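// i.e. SmemLayoutQt(k, m) addresses the same shared-memory element as SmemLayoutQ(m, k);
// no data is moved, only the indexing (and hence the WGMMA descriptor) changes.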
using SmemLayoutdOt =
decltype(cute::composition(SmemLayoutdO{},
make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
make_stride(Int<kBlockM>{}, _1{}))));
using SmemLayoutKt =
decltype(cute::composition(SmemLayoutK{},
make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{}), Int<kStages>{}),
make_stride(Int<kBlockN>{}, _1{}, Int<kBlockN * kHeadDim>{}))));
using SmemLayoutPt =
decltype(cute::composition(SmemLayoutP{},
make_layout(make_shape(get<1>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
make_stride(Int<kBlockM>{}, _1{}))));
using SmemLayoutdSt =
decltype(cute::composition(SmemLayoutdS{},
make_layout(make_shape(get<1>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
make_stride(Int<kBlockM>{}, _1{}))));
using SmemLayoutdK = decltype(tile_to_shape(SmemLayoutAtomK{}, select<1, 2>(TileShape_MNK{})));
using SmemLayoutdV = SmemLayoutdK;
using SmemLayoutdKt = SmemLayoutKt;
using SmemLayoutdVt = SmemLayoutKt;
using SmemLayoutdQTMA = decltype(tile_to_shape(SmemLayoutAtomK{}, select<0, 2>(TileShape_MNK{})));
static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
using SmemLayoutAtomdQ = decltype(
composition(Swizzle<kSwizzle, 3, 3>{},
Layout<Shape<_8, Int<kBlockKSmem>>,
Stride<Int<kBlockKSmem>, _1>>{}));
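// kSwizzle = 2 (kBlockKSmem == 32) or 3 (kBlockKSmem == 64) selects the 64-byte or 128-byte
// swizzle pattern for 2-byte elements, presumably to avoid shared-memory bank conflicts when
// rows of the tile are kBlockKSmem elements wide.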
using SmemLayoutdQ = decltype(tile_to_shape(
SmemLayoutAtomdQ{},
make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));
using SmemLayoutdQt =
decltype(cute::composition(SmemLayoutdQ{},
make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
make_stride(Int<kBlockM>{}, _1{}))));
static constexpr int kSmemdQSize = size(SmemLayoutdQ{}) * sizeof(Element);
using SmemLayoutAtomdKV = decltype(
composition(Swizzle<kSwizzle, 3, 3>{},
Layout<Shape<_8, Int<kBlockKSmem>>,
Stride<Int<kBlockKSmem>, _1>>{}));
using SmemLayoutdKV = decltype(tile_to_shape(
SmemLayoutAtomdKV{},
make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));
using SmemLayoutdKVt =
decltype(cute::composition(SmemLayoutdKV{},
make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})),
make_stride(Int<kBlockN>{}, _1{}))));
static constexpr int kSmemdKVSize = size(SmemLayoutdKV{}) * sizeof(Element) * 2;
// using SmemCopyAtomQ = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;
using SmemCopyAtomPdS = Copy_Atom<
std::conditional_t<!SdP_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
Element>;
using SmemCopyAtomdKV = Copy_Atom<
std::conditional_t<!dKV_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
Element>;
using SmemCopyAtomdQ = Copy_Atom<
std::conditional_t<!dQ_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
Element>;
using SharedStorage = SharedStorageQKVdOdKVSeqqPar<!SdP_swapAB, kStages, Element, Element, SmemLayoutQ, SmemLayoutdO,
SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdQTMA>;
// using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages * 2>;
// using PipelineState = typename cutlass::PipelineState<kStages * 2>;
using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/******************************************************************************
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
******************************************************************************/
#pragma once
#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
#include <cutlass/numeric_conversion.h>
#include "cutlass/pipeline/pipeline.hpp"
#include "cute/tensor.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "utils.h"
namespace flash {
using namespace cute;
template <typename Ktraits, bool Is_causal>
struct CollectiveMainloopFwd {
using Element = typename Ktraits::Element;
using TileShape_MNK = typename Ktraits::TileShape_MNK;
using ClusterShape = typename Ktraits::ClusterShape_MNK;
static constexpr int kStages = Ktraits::kStages;
static constexpr int kHeadDim = Ktraits::kHeadDim;
using GmemTiledCopyQ = cute::SM90_TMA_LOAD;
using GmemTiledCopyKV = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape{})));
using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{})));
using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutK =
decltype(tile_to_shape(SmemLayoutAtomK{},
make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
using SmemLayoutV = SmemLayoutK;
// Note this is the transpose in terms of the view, not in terms of memory.
using SmemLayoutVt =
decltype(cute::composition(SmemLayoutV{},
make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{}), Int<kStages>{}),
make_stride(get<1>(TileShape_MNK{}), _1{}, Int<size(SmemLayoutV{}(_, _, _0{}))>{}))));
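    // The stride of the stage mode equals the size of one (N, K) stage, so each pipeline stage
    // is transposed independently while still addressing the same shared-memory buffer.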
// using SmemLayoutAtomVt = cute::GMMA::Layout_MN_SW128_Atom<Element>;
// using SmemLayoutVt =
// decltype(tile_to_shape(SmemLayoutAtomVt{},
// make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{}), Int<kStages>{}),
// Step<_2, _1, _3>{})); // This gives correct results, without Step it's wrong
// using SmemLayoutAtomVt = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::MN, Element,
// decltype(cute::get<2>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
// using SmemLayoutVt =
// decltype(tile_to_shape(SmemLayoutAtomVt{},
// make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{}), Int<kStages>{})));
// using SmemLayoutAtomVTMA = cute::GMMA::Layout_K_SW128_Atom<Element>;
// using SmemLayoutVTMA =
// decltype(tile_to_shape(SmemLayoutAtomVTMA{},
// make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
using ShapeQKV = cute::Shape<int32_t, int32_t, int32_t, int32_t>; // (seqlen, d, head, batch)
using StrideQKV = cute::Stride<int64_t, _1, int64_t, int64_t>;
using TMA_Q = decltype(make_tma_copy(
GmemTiledCopyQ{},
make_tensor(make_gmem_ptr(static_cast<Element const*>(nullptr)), repeat_like(StrideQKV{}, int32_t(0)), StrideQKV{}),
SmemLayoutQ{},
select<0, 2>(TileShape_MNK{}),
_1{})); // no mcast for Q
using TMA_KV = decltype(make_tma_copy(
GmemTiledCopyKV{},
make_tensor(make_gmem_ptr(static_cast<Element const*>(nullptr)), repeat_like(StrideQKV{}, int32_t(0)), StrideQKV{}),
take<0, 2>(SmemLayoutK{}),
select<1, 2>(TileShape_MNK{}),
size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any
static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma0{});
using MainloopPipeline = typename Ktraits::MainloopPipeline;
using PipelineParams = typename MainloopPipeline::Params;
using PipelineState = typename MainloopPipeline::PipelineState;
// Set the bytes transferred in this TMA transaction (may involve multiple issues)
static constexpr uint32_t TmaTransactionBytesQ = static_cast<uint32_t>(size(SmemLayoutQ{}) * cutlass::sizeof_bits_v<Element> / 8);
static constexpr uint32_t TmaTransactionBytesK = static_cast<uint32_t>(size(take<0, 2>(SmemLayoutK{})) * cutlass::sizeof_bits_v<Element> / 8);
static constexpr bool UseSchedulerBarrier = kHeadDim <= 128;
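    // The named barrier that staggers the MMA warpgroups (scheduler_barrier_sync/arrive below)
    // is only used for head dims up to 128; we read this as a tuning heuristic.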
// Host side kernel arguments
struct Arguments {
Element const* ptr_Q;
ShapeQKV const shape_Q;
StrideQKV const stride_Q;
Element const* ptr_K;
ShapeQKV const shape_K;
StrideQKV const stride_K;
Element const* ptr_V;
StrideQKV const stride_V;
float const softmax_scale_log2;
};
// Device side kernel params
struct Params {
ShapeQKV const shape_Q;
ShapeQKV const shape_K;
TMA_Q tma_load_Q;
TMA_KV tma_load_K, tma_load_V;
float const softmax_scale_log2;
};
static Params
to_underlying_arguments(Arguments const& args) {
Tensor mQ = make_tensor(make_gmem_ptr(args.ptr_Q), args.shape_Q, args.stride_Q);
TMA_Q tma_load_Q = make_tma_copy(
GmemTiledCopyQ{},
mQ,
SmemLayoutQ{},
select<0, 2>(TileShape_MNK{}),
_1{}); // no mcast for Q
Tensor mK = make_tensor(make_gmem_ptr(args.ptr_K), args.shape_K, args.stride_K);
TMA_KV tma_load_K = make_tma_copy(
GmemTiledCopyKV{},
mK,
SmemLayoutK{}(_, _, _0{}),
select<1, 2>(TileShape_MNK{}),
size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), args.shape_K, args.stride_V);
TMA_KV tma_load_V = make_tma_copy(
GmemTiledCopyKV{},
mV,
SmemLayoutV{}(_, _, _0{}),
select<1, 2>(TileShape_MNK{}),
size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
return {args.shape_Q, args.shape_K, tma_load_Q, tma_load_K, tma_load_V, args.softmax_scale_log2};
}
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const& mainloop_params) {
cute::prefetch_tma_descriptor(mainloop_params.tma_load_Q.get_tma_descriptor());
cute::prefetch_tma_descriptor(mainloop_params.tma_load_K.get_tma_descriptor());
cute::prefetch_tma_descriptor(mainloop_params.tma_load_V.get_tma_descriptor());
}
CUTLASS_DEVICE
int get_n_block_max(Params const& mainloop_params, int m_block) {
static constexpr int kBlockM = get<0>(TileShape_MNK{});
static constexpr int kBlockN = get<1>(TileShape_MNK{});
int const seqlen_q = get<0>(mainloop_params.shape_Q);
int const seqlen_k = get<0>(mainloop_params.shape_K);
int n_block_max = cute::ceil_div(seqlen_k, kBlockN);
if constexpr (Is_causal) {
n_block_max = std::min(n_block_max,
cute::ceil_div((m_block + 1) * kBlockM + seqlen_k - seqlen_q, kBlockN));
}
return n_block_max;
}
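    // Example (hypothetical sizes): with kBlockM = kBlockN = 128, seqlen_q = seqlen_k = 1024 and
    // m_block = 2, the non-causal bound is ceil(1024 / 128) = 8 blocks, while the causal bound is
    // min(8, ceil_div(3 * 128 + 1024 - 1024, 128)) = 3, so only K/V blocks 0..2 are loaded.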
template <typename FullParams, typename SchedulerParams, typename SharedStorage, typename WorkTileInfo>
CUTLASS_DEVICE void
load(FullParams const& params,
Params const& mainloop_params,
SchedulerParams const& scheduler_params,
MainloopPipeline pipeline_k,
MainloopPipeline pipeline_v,
PipelineState& smem_pipe_write_k,
PipelineState& smem_pipe_write_v,
SharedStorage &shared_storage,
WorkTileInfo work_tile_info,
int& work_idx,
int& tile_count_semaphore
) {
static constexpr int kBlockM = get<0>(TileShape_MNK{});
static constexpr int kBlockN = get<1>(TileShape_MNK{});
// int const m_block = work_tile_info.M_idx;
// int const bidh = work_tile_info.H_idx;
// int const bidb = work_tile_info.B_idx;
// int m_block;
// int bidh, bidb;
// bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(m_block, tile_count_semaphore));
auto [m_block, bidh, bidb] = work_tile_info.get_block_coord(scheduler_params);
// if (threadIdx.x == 0) { printf("producer, blockIdx.x = %d, bidb = %d, bidh = %d, m_block = %d\n", blockIdx.x, bidb, bidh, m_block); }
int n_block_max = get_n_block_max(mainloop_params, m_block);
if (Is_causal && n_block_max <= 0) {
            // Need to sync to avoid the case where the producer issues two arrives before the consumer can issue one wait
cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, 7 /*id*/);
// if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) {
// tile_count_semaphore = atomicAdd(params.tile_count_semaphore, 1);
// shared_storage.tile_count_semaphore = tile_count_semaphore;
// }
// cutlass::arch::NamedBarrier::arrive(NumMmaThreads + 2 * cutlass::NumThreadsPerWarp, 10 /*id*/);
cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, 10 /*id*/);
return;
}
Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{});
Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutV{});
Tensor mQ = mainloop_params.tma_load_Q.get_tma_tensor(mainloop_params.shape_Q);
Tensor mK = mainloop_params.tma_load_K.get_tma_tensor(mainloop_params.shape_K);
Tensor mV = mainloop_params.tma_load_V.get_tma_tensor(mainloop_params.shape_K);
// Prepare the TMA loads
uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
Tensor gQ = local_tile(mQ(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K)
Tensor gK = local_tile(mK(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _)
Tensor gV = local_tile(mV(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _)
Tensor sQ_x = make_tensor(sQ.data(), make_layout(sQ.layout(), Layout<_1>{}));
Tensor gQ_x = make_tensor(gQ.data(), make_layout(gQ.layout(), Layout<_1>{}));
auto [tQgQ, tQsQ] = tma_partition(mainloop_params.tma_load_Q, _0{}, Layout<_1>{},
group_modes<0, 2>(sQ_x), group_modes<0, 2>(gQ_x)); // (TMA), (TMA)
auto [tKgK, tKsK] = tma_partition(mainloop_params.tma_load_K, block_rank_in_cluster, Layout<ClusterShape>{},
group_modes<0, 2>(sK), group_modes<0, 2>(gK)); // (TMA, k), (TMA, PIPE)
auto [tVgV, tVsV] = tma_partition(mainloop_params.tma_load_V, block_rank_in_cluster, Layout<ClusterShape>{},
group_modes<0, 2>(sV), group_modes<0, 2>(gV)); // (TMA, k), (TMA, PIPE)
uint16_t mcast_mask_kv = 0;
if constexpr (cute::is_same_v<GmemTiledCopyKV, SM90_TMA_LOAD_MULTICAST>) {
auto block_layout = Layout<ClusterShape>{}; // (m,n) -> block_id
for (int m = 0; m < size<0>(block_layout); ++m) {
mcast_mask_kv |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, _0{}));
}
}
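        // Each set bit in mcast_mask_kv selects one CTA (by its rank in the cluster) that receives
        // the multicast: all CTAs along the cluster's M mode, which work on different query blocks
        // but need the same K/V tiles.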
int n_block = n_block_max - 1;
int lane_predicate = cute::elect_one_sync();
if (lane_predicate) {
pipeline_k.producer_acquire(smem_pipe_write_k);
copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv),
tKgK(_, n_block), tKsK(_, smem_pipe_write_k.index()));
++smem_pipe_write_k;
}
// Wait for the MMA warpgroups to say that smem_q is ready
cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, 1 /*id*/);
if (lane_predicate) {
shared_storage.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ);
copy(mainloop_params.tma_load_Q.with(reinterpret_cast<cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.barrier_Q), 0 /*mcast_mask*/), tQgQ, tQsQ);
}
        // Wait for warp 1 to signal that smem_v is ready and V can be copied from gmem
        // Need ClusterBarrier, not just NamedBarrier. Otherwise CTA 0 might finish its TMA store on O
        // first and issue the TMA multicast load on V before CTA 1 has finished its TMA store on O.
shared_storage.barrier_O.wait((work_idx + 1) % 2);
if (lane_predicate) {
// CUTLASS_PRAGMA_NO_UNROLL
#pragma unroll 2
for (; n_block > 0; --n_block) {
pipeline_k.producer_acquire(smem_pipe_write_k);
copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv),
tKgK(_, n_block - 1), tKsK(_, smem_pipe_write_k.index()));
++smem_pipe_write_k;
pipeline_v.producer_acquire(smem_pipe_write_v);
copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv),
tVgV(_, n_block), tVsV(_, smem_pipe_write_v.index()));
++smem_pipe_write_v;
}
}
if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) {
// tile_count_semaphore = atomicAdd(params.tile_count_semaphore, 1);
}
if (lane_predicate) {
pipeline_v.producer_acquire(smem_pipe_write_v);
copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv),
tVgV(_, n_block), tVsV(_, smem_pipe_write_v.index()));
++smem_pipe_write_v;
}
if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) {
// printf("blockIdx.x = %d, tile_count_semaphore: %d\n", blockIdx.x, tile_count_semaphore);
// shared_storage.tile_count_semaphore = tile_count_semaphore;
}
// cutlass::arch::NamedBarrier::arrive(NumMmaThreads + 2 * cutlass::NumThreadsPerWarp, 10 /*id*/);
cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, 10 /*id*/);
++work_idx;
}
/// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
CUTLASS_DEVICE void
load_tail(MainloopPipeline pipeline_k, MainloopPipeline pipeline_v,
PipelineState& smem_pipe_write_k, PipelineState& smem_pipe_write_v) {
int lane_predicate = cute::elect_one_sync();
// Issue the epilogue waits
if (lane_predicate) {
            /* This helps avoid early exit of blocks in the cluster.
             * It waits for all stages to either be released (all consumer UNLOCKs), or, if a stage was
             * never used, it is simply acquired since the phase is still inverted from make_producer_start_state.
             */
pipeline_k.producer_tail(smem_pipe_write_k);
pipeline_v.producer_tail(smem_pipe_write_v);
}
}
CUTLASS_DEVICE void
scheduler_barrier_sync() {
if constexpr (UseSchedulerBarrier) {
cutlass::arch::NamedBarrier::sync(NumMmaThreads, 3 + cutlass::canonical_warp_group_idx() /*id*/);
}
}
CUTLASS_DEVICE void
scheduler_barrier_arrive() {
if constexpr (!UseSchedulerBarrier) { return; }
static_assert(NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup || NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup);
if constexpr (NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup) {
cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 3 + (3 - cutlass::canonical_warp_group_idx()) /*id*/);
} else {
cutlass::arch::NamedBarrier::arrive(NumMmaThreads, cutlass::canonical_warp_group_idx() <= 2 ? 3 + cutlass::canonical_warp_group_idx() + 1 : 3 + cutlass::canonical_warp_group_idx() + 1 - 3 /*id*/);
cutlass::arch::NamedBarrier::arrive(NumMmaThreads, cutlass::canonical_warp_group_idx() <= 1 ? 3 + cutlass::canonical_warp_group_idx() + 2 : 3 + cutlass::canonical_warp_group_idx() + 2 - 3 /*id*/);
}
}
CUTLASS_DEVICE void
mma_init() {
// Tell producer (warp 0) that smem_q is ready
cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, 1 /*id*/);
if constexpr (!UseSchedulerBarrier) { return; }
static_assert(NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup || NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup);
if (cutlass::canonical_warp_group_idx() > 1) {
cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 3 + 1 /*id*/);
}
if constexpr (NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup) {
if (cutlass::canonical_warp_group_idx() > 2) {
cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 3 + 2 /*id*/);
}
}
}
template <typename SharedStorage, typename FrgTensorO, typename Softmax>
CUTLASS_DEVICE void
mma(Params const& mainloop_params,
MainloopPipeline pipeline_k,
MainloopPipeline pipeline_v,
PipelineState& smem_pipe_read_k,
PipelineState& smem_pipe_read_v,
FrgTensorO& tOrO,
Softmax& softmax,
int n_block_count,
int thread_idx,
int work_idx,
int m_block,
SharedStorage& shared_storage
) {
static_assert(is_rmem<FrgTensorO>::value, "O tensor must be rmem resident.");
static constexpr int kBlockM = get<0>(TileShape_MNK{});
static constexpr int kBlockN = get<1>(TileShape_MNK{});
Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{});
Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutVt{});
typename Ktraits::TiledMma0 tiled_mma0;
typename Ktraits::TiledMma1 tiled_mma1;
auto threadMma0 = tiled_mma0.get_thread_slice(thread_idx);
auto threadMma1 = tiled_mma1.get_thread_slice(thread_idx);
// Allocate "fragments/descriptors" for first matmul.
Tensor tSrQ = threadMma0.partition_fragment_A(sQ);
Tensor tSrK = threadMma0.partition_fragment_B(sK);
// Allocate "fragments/descriptors" for second matmul.
// Note: S becomes P.
Tensor tOrV = threadMma1.partition_fragment_B(sVt);
auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
pipeline.consumer_wait(smem_pipe_read, barrier_token);
};
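        // consumer_try_wait returns a token recording whether the barrier phase has already flipped,
        // so consumer_wait only blocks if the data for this stage has not arrived yet.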
tiled_mma1.accumulate_ = GMMA::ScaleOut::Zero;
int const seqlen_q = get<0>(mainloop_params.shape_Q);
int const seqlen_k = get<0>(mainloop_params.shape_K);
int n_block = n_block_count - 1;
cutlass::ConsumerToken barrier_token = static_cast<cutlass::BarrierStatus>(shared_storage.barrier_Q.try_wait(work_idx % 2));
if (barrier_token == cutlass::BarrierStatus::WaitAgain) { shared_storage.barrier_Q.wait(work_idx % 2); }
Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
consumer_wait(pipeline_k, smem_pipe_read_k);
scheduler_barrier_sync();
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
scheduler_barrier_arrive();
if (work_idx != 0) {
int lane_predicate = cute::elect_one_sync();
if (cutlass::canonical_warp_idx_sync() == Ktraits::kNWarps - 1 && lane_predicate) {
tma_store_wait<0>();
#pragma unroll
for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
shared_storage.barrier_O.arrive(cta_id, lane_predicate);
}
}
}
warpgroup_wait<0>();
pipeline_k.consumer_release(smem_pipe_read_k);
++smem_pipe_read_k;
auto col_limit_causal = [&](int row, int n_block) {
return row + 1 + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM;
};
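        // col_limit_causal gives, for a row of this (m_block, n_block) tile, the first masked column
        // relative to the start of the tile. With seqlen_q == seqlen_k it reduces to
        // row + 1 + m_block * kBlockM - n_block * kBlockN, i.e. the usual "column <= global row" rule.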
{
Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{}));
Tensor tScS = threadMma0.partition_C(cS);
#pragma unroll
for (int i = 0; i < size(tSrS); ++i) {
if constexpr (!Is_causal) { // Just masking based on col
if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) { tSrS(i) = -INFINITY; }
} else { // mask based on both row and col
// using std::min is faster than doing col >= limit0 or col >= limit1
// Need to cast get<1>(tScS(i)) to (signed) int since by default it's unsigned, and the
// right hand side can be negative and might be converted to a very large unsigned integer.
if (int(get<1>(tScS(i))) >= std::min(seqlen_k - n_block * kBlockN,
col_limit_causal(int(get<0>(tScS(i))), n_block))) {
tSrS(i) = -INFINITY;
}
}
}
}
softmax.template online_softmax</*Is_first=*/true>(tSrS, mainloop_params.softmax_scale_log2);
Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs<typename Ktraits::TiledMma1>(tSrS.layout()));
Tensor scores_scale = make_fragment_like(softmax.row_max);
clear(scores_scale);
constexpr int n_masking_steps = !Is_causal ? 1 : cute::ceil_div(kBlockM, kBlockN) + 1;
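        // E.g. kBlockM = 128, kBlockN = 64 gives n_masking_steps = 3: the block handled above plus
        // up to two more iterations below apply the causal mask; after that the remaining K/V blocks
        // lie entirely below the diagonal and the unmasked loop takes over.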
// Only go through these if Is_causal, since n_masking_steps = 1 when !Is_causal
#pragma unroll
for (int masking_step = 0; masking_step < n_masking_steps - 1 && n_block > 0; ++masking_step, --n_block) {
Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
consumer_wait(pipeline_k, smem_pipe_read_k);
scheduler_barrier_sync();
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
if (masking_step > 0) { softmax.rescale_o(tOrO, scores_scale); }
consumer_wait(pipeline_v, smem_pipe_read_v);
flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
scheduler_barrier_arrive();
warpgroup_wait<1>();
pipeline_k.consumer_release(smem_pipe_read_k); // release K
Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{}));
Tensor tScS = threadMma0.partition_C(cS);
#pragma unroll
for (int i = 0; i < size(tSrS); ++i) {
if (int(get<1>(tScS(i))) >= col_limit_causal(int(get<0>(tScS(i))), n_block - 1)) {
tSrS(i) = -INFINITY;
}
}
cute::copy(softmax.template max</*Is_first=*/false, /*Check_inf=*/true>(tSrS, mainloop_params.softmax_scale_log2), scores_scale);
softmax.template online_softmax</*Is_first=*/false, /*Check_inf=*/true>(tSrS, mainloop_params.softmax_scale_log2);
warpgroup_wait<0>();
pipeline_v.consumer_release(smem_pipe_read_v); // release V
++smem_pipe_read_k;
++smem_pipe_read_v;
cute::copy(make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs<typename Ktraits::TiledMma1>(tSrS.layout())), tOrP);
}
#pragma unroll 1
for (; n_block > 0; --n_block) {
Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
consumer_wait(pipeline_k, smem_pipe_read_k);
scheduler_barrier_sync();
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
softmax.rescale_o(tOrO, scores_scale);
consumer_wait(pipeline_v, smem_pipe_read_v);
flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
scheduler_barrier_arrive();
warpgroup_wait<1>();
pipeline_k.consumer_release(smem_pipe_read_k); // release K
// auto scores_scale = softmax.template max</*Is_first=*/false>(tSrS);
cute::copy(softmax.template max</*Is_first=*/false>(tSrS, mainloop_params.softmax_scale_log2), scores_scale);
softmax.template online_softmax</*Is_first=*/false>(tSrS, mainloop_params.softmax_scale_log2);
warpgroup_wait<0>();
pipeline_v.consumer_release(smem_pipe_read_v); // release V
++smem_pipe_read_k;
++smem_pipe_read_v;
// softmax.rescale_o(tOrO, scores_scale);
cute::copy(make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs<typename Ktraits::TiledMma1>(tSrS.layout())), tOrP);
}
// Tell warp 0 that smem_q is ready
cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, 1 /*id*/);
softmax.rescale_o(tOrO, scores_scale);
consumer_wait(pipeline_v, smem_pipe_read_v);
flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
cute::copy(softmax.template finalize</*Check_inf=*/Is_causal>(tSrS, mainloop_params.softmax_scale_log2), scores_scale);
warpgroup_wait<0>();
pipeline_v.consumer_release(smem_pipe_read_v); // release V, otherwise producers will hang
++smem_pipe_read_v;
softmax.rescale_o(tOrO, scores_scale);
return;
}
};
} // namespace flash