Commit 7f67966c authored by Tri Dao

FA3 initial code release

parent b4a9dd6c
@@ -26,6 +26,42 @@ contains a partial list of places where FlashAttention is being used.
FlashAttention and FlashAttention-2 are free to use and modify (see LICENSE).
Please cite and credit FlashAttention if you use it.
## FlashAttention-3 beta release
FlashAttention-3 is optimized for Hopper GPUs (e.g. H100).
Blogpost: https://tridao.me/blog/2024/flash3/
Paper: https://tridao.me/publications/flash3/flash3.pdf
![FlashAttention-3 speedup on H100 80GB SXM5 with FP16](assets/flash3_fp16_fwd.png)
This is a beta release for testing / benchmarking before we integrate it with
the rest of the repo.
Currently released:
- FP16 forward and backward
Coming soon in the next couple of days / next week:
- BF16
- Variable length (FP16, BF16)
- FP8 forward.
Requirements: H100 / H800 GPU, CUDA >= 12.3.
To install:
```sh
cd hopper
python setup.py install
```
To run the test:
```sh
export PYTHONPATH=$PWD
pytest -q -s test_flash_attn.py
```
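Once installed, a minimal usage sketch (assuming `flash_attn_interface.py` from this `hopper` directory is importable, as in the test setup above):
```python
# Minimal FP16 example on an H100; head dim 64/128/256 as currently released.
import torch
from flash_attn_interface import flash_attn_func

q, k, v = [torch.randn(2, 1024, 16, 128, dtype=torch.float16, device="cuda", requires_grad=True)
           for _ in range(3)]
out, lse = flash_attn_func(q, k, v, causal=True)  # out: (2, 1024, 16, 128), lse: (2, 16, 1024)
out.sum().backward()                              # FP16 backward is also released
```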
## Installation and features
Requirements:
__version__ = "3.0.0.b1"
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
#pragma once
namespace flash {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<bool Varlen=true>
struct BlockInfo {
template<typename Params>
__device__ BlockInfo(const Params &params, const int bidb)
: sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb])
, sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb])
, actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
, seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
, actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
{
}
template <typename index_t>
__forceinline__ __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride;
}
template <typename index_t>
__forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride;
}
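// For fixed-length batches (sum_s_q / sum_s_k == -1) each batch element starts at
// bidb * batch_stride. For variable-length batches the sequences are packed along the
// row dimension, so the offset is the cumulative sequence length times row_stride.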
const int sum_s_q;
const int sum_s_k;
const int actual_seqlen_q;
// We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0.
const int seqlen_k_cache;
const int actual_seqlen_k;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace flash
/******************************************************************************
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
******************************************************************************/
#pragma once
#include <cutlass/cutlass.h>
#include "cute/tensor.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "utils.h"
namespace flash {
using namespace cute;
// template <int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename Element_>
template <typename Ktraits>
struct CollectiveEpilogueFwd {
using Element = typename Ktraits::Element;
static constexpr int kBlockM = Ktraits::kBlockM;
static constexpr int kBlockN = Ktraits::kBlockN;
static constexpr int kHeadDim = Ktraits::kHeadDim;
// using Element = Element_;
// static constexpr int kBlockM = kBlockM_;
// static constexpr int kBlockN = kBlockN_;
// static constexpr int kHeadDim = kHeadDim_;
using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
// static constexpr int kNWarps = kNWarps_;
static constexpr int kNWarps = Ktraits::kNWarps;
static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
static constexpr bool Is_WS = kNWarps >= 12;
static constexpr int NumCopyThreads = !Is_WS ? 0 : cutlass::NumThreadsPerWarpGroup;
static constexpr int NumMmaThreads = kNThreads - NumCopyThreads;
using GmemTiledCopyOTMA = cute::SM90_TMA_STORE;
// These are for storing the output tensor without TMA (e.g., for setting output to zero)
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
static constexpr int kGmemThreadsPerRow = kHeadDim / kGmemElemsPerLoad;
static_assert(NumMmaThreads % kGmemThreadsPerRow == 0, "NumMmaThreads must be a multiple of kGmemThreadsPerRow");
using GmemLayoutAtom = Layout<Shape <Int<NumMmaThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
Stride<Int<kGmemThreadsPerRow>, _1>>;
using GmemTiledCopyO = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
GmemLayoutAtom{},
Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{})); // Val layout, 8 or 16 vals per store
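// For example, with fp16 and kHeadDim = 128, kGmemElemsPerLoad = 16 / 2 = 8 elements
// per 128-bit store and kGmemThreadsPerRow = 128 / 8 = 16; with kNWarps = 12
// (warp-specialized), NumMmaThreads = 384 - 128 = 256, so GmemLayoutAtom tiles a
// 16 x 16 grid of threads, each storing 8 contiguous elements.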
using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{})));
using SmemCopyAtomO = Copy_Atom<cute::SM90_U32x4_STSM_N, Element>;
using SharedStorage = cute::array_aligned<Element, cute::cosize_v<SmemLayoutO>>;
using ShapeO = cute::Shape<int32_t, int32_t, int32_t, int32_t>; // (seqlen_q, d, head, batch)
using StrideO = cute::Stride<int64_t, _1, int64_t, int64_t>;
using StrideLSE = cute::Stride<_1, int64_t, int64_t>; // (seqlen_q, head, batch)
using TMA_O = decltype(make_tma_copy(
GmemTiledCopyOTMA{},
make_tensor(make_gmem_ptr(static_cast<Element*>(nullptr)), repeat_like(StrideO{}, int32_t(0)), StrideO{}),
SmemLayoutO{},
select<0, 2>(TileShape_MNK{}),
_1{})); // no mcast for O
// Host side kernel arguments
struct Arguments {
Element* ptr_O;
ShapeO const shape_O;
StrideO const stride_O;
float* ptr_LSE;
StrideLSE const stride_LSE;
};
// Device side kernel params
struct Params {
Element* ptr_O;
ShapeO const shape_O;
StrideO const stride_O;
float* ptr_LSE;
StrideLSE const stride_LSE;
TMA_O tma_store_O;
};
static Params
to_underlying_arguments(Arguments const& args) {
Tensor mO = make_tensor(make_gmem_ptr(args.ptr_O), args.shape_O, args.stride_O);
TMA_O tma_store_O = make_tma_copy(
GmemTiledCopyOTMA{},
mO,
SmemLayoutO{},
select<0, 2>(TileShape_MNK{}),
_1{}); // no mcast for O
return {args.ptr_O, args.shape_O, args.stride_O, args.ptr_LSE, args.stride_LSE, tma_store_O};
}
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const& epilogue_params) {
cute::prefetch_tma_descriptor(epilogue_params.tma_store_O.get_tma_descriptor());
}
template <typename SharedStorage, typename FrgTensorO, typename FrgTensorLSE, typename TiledMma>
CUTLASS_DEVICE void
store(Params const& epilogue_params,
FrgTensorO const& tOrO,
FrgTensorLSE const& lse,
SharedStorage& shared_storage,
TiledMma tiled_mma,
int thread_idx,
cute::tuple<int32_t, int32_t, int32_t> const& block_coord
) {
auto [m_block, bidh, bidb] = block_coord;
Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutO{});
auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma);
auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx);
Tensor tOrO_out = flash::convert_type<Element>(tOrO);
Tensor taccOrO = smem_thr_copy_O.retile_S(tOrO_out); // ((Atom,AtomNum), MMA_M, MMA_N)
Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N)
// Make sure all WGs have finished reading V
cutlass::arch::NamedBarrier::sync(NumMmaThreads, 0 /*id*/);
cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA
cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp,
cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
Tensor mO = epilogue_params.tma_store_O.get_tma_tensor(epilogue_params.shape_O);
Tensor gO = local_tile(mO(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K)
auto block_tma_O = epilogue_params.tma_store_O.get_slice(_0{});
Tensor tOgO = block_tma_O.partition_D(gO); // (TMA, TMA_M, TMA_K)
Tensor tOsO = block_tma_O.partition_S(sO); // (TMA, TMA_M, TMA_K)
auto shape_LSE = select<0, 2, 3>(epilogue_params.shape_O);
Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_params.ptr_LSE), shape_LSE, epilogue_params.stride_LSE);
Tensor gLSE = local_tile(mLSE(_, bidh, bidb), Shape<Int<kBlockM>>{}, make_coord(m_block));
Tensor caccO = cute::make_identity_tensor(select<0, 2>(TileShape_MNK{}));
auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
Tensor taccOcO = thread_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K)
static_assert(decltype(size<0, 0>(taccOcO))::value == 2);
static_assert(decltype(size<0, 1>(taccOcO))::value == 2);
// taccOcO has shape ((2, 2, V), MMA_M, MMA_K), we only take the row indices.
Tensor taccOcO_row = taccOcO(make_coord(_0{}, _, _0{}), _, _0{});
CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M
if (get<1>(taccOcO_row(_0{})) == 0) {
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) {
const int row = get<0>(taccOcO_row(mi));
if (row < get<0>(shape_LSE) - m_block * kBlockM) { gLSE(row) = lse(mi); }
}
}
if (cutlass::canonical_warp_idx_sync() == kNWarps - 1) {
cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp,
cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
int const lane_predicate = cute::elect_one_sync();
if (lane_predicate) {
cute::copy(epilogue_params.tma_store_O, tOsO, tOgO);
tma_store_arrive();
}
}
}
CUTLASS_DEVICE void
store_tail() {
tma_store_wait<0>();
}
// Write 0 to output and -inf to LSE
CUTLASS_DEVICE void
store_zero(
Params const& epilogue_params,
int thread_idx,
cute::tuple<int32_t, int32_t, int32_t> const& block_coord
) {
auto [m_block, bidh, bidb] = block_coord;
Tensor mO = make_tensor(make_gmem_ptr(epilogue_params.ptr_O), epilogue_params.shape_O, epilogue_params.stride_O);
Tensor gO = local_tile(mO(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K)
auto shape_LSE = select<0, 2, 3>(epilogue_params.shape_O);
Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_params.ptr_LSE), shape_LSE, epilogue_params.stride_LSE);
Tensor gLSE = local_tile(mLSE(_, bidh, bidb), Shape<Int<kBlockM>>{}, make_coord(m_block));
GmemTiledCopyO gmem_tiled_copy_O;
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
Tensor tOrO = make_fragment_like(tOgO);
clear(tOrO);
// Construct identity layout for sO
Tensor cO = cute::make_identity_tensor(select<0, 2>(TileShape_MNK{})); // (BLK_M,BLK_K) -> (blk_m,blk_k)
// Repeat the partitioning with identity layouts
Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
#pragma unroll
for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(epilogue_params.shape_O); }
// Clear_OOB_K must be false since we don't want to write zeros to gmem
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, get<0>(epilogue_params.shape_O) - m_block * kBlockM
);
static_assert(kBlockM <= NumMmaThreads);
if (thread_idx < get<0>(shape_LSE) - m_block * kBlockM) { gLSE(thread_idx) = INFINITY; }
}
};
} // namespace flash
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
#pragma once
#include <cuda.h>
#include <vector>
#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
#endif
#include <ATen/cuda/CUDAGraphsUtils.cuh> // For at::cuda::philox::unpack
#include "cutlass/fast_math.h" // For cutlass::FastDivmod
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Qkv_params {
using index_t = int64_t;
// The QKV matrices.
void *__restrict__ q_ptr;
void *__restrict__ k_ptr;
void *__restrict__ v_ptr;
// The stride between rows of the Q, K and V matrices.
index_t q_batch_stride;
index_t k_batch_stride;
index_t v_batch_stride;
index_t q_row_stride;
index_t k_row_stride;
index_t v_row_stride;
index_t q_head_stride;
index_t k_head_stride;
index_t v_head_stride;
// The number of heads.
int h, h_k;
// In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be
// different from nheads (query).
int h_h_k_ratio; // precompute h / h_k
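// E.g., 32 query heads with 8 KV heads gives h_h_k_ratio = 4; query head i uses
// KV head i / h_h_k_ratio (heads 0-3 map to KV head 0, and so on).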
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Flash_fwd_params : public Qkv_params {
// The O matrix (output).
void * __restrict__ o_ptr;
void * __restrict__ oaccum_ptr;
// The stride between rows of O.
index_t o_batch_stride;
index_t o_row_stride;
index_t o_head_stride;
// The pointer to the P matrix.
void * __restrict__ p_ptr;
// The pointer to the softmax sum.
void * __restrict__ softmax_lse_ptr;
void * __restrict__ softmax_lseaccum_ptr;
// The dimensions.
int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim;
cutlass::FastDivmod head_divmod, m_block_divmod;
int total_blocks;
// The scaling factors for the kernel.
float scale_softmax;
float scale_softmax_log2;
uint32_t scale_softmax_log2_half2;
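// Note: the Python interface defaults the softmax scale to 1/sqrt(headdim);
// scale_softmax_log2 is conventionally scale_softmax * log2(e) so the kernel can use
// exp2, and scale_softmax_log2_half2 presumably packs that value into two fp16 lanes.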
// array of length b+1 holding starting offset of each sequence.
int * __restrict__ cu_seqlens_q;
int * __restrict__ cu_seqlens_k;
// If provided, the actual length of each k sequence.
int * __restrict__ seqused_k;
int *__restrict__ blockmask;
// The K_new and V_new matrices.
void * __restrict__ knew_ptr;
void * __restrict__ vnew_ptr;
// The stride between rows of the Q, K and V matrices.
index_t knew_batch_stride;
index_t vnew_batch_stride;
index_t knew_row_stride;
index_t vnew_row_stride;
index_t knew_head_stride;
index_t vnew_head_stride;
// The cos and sin matrices for rotary embedding.
void * __restrict__ rotary_cos_ptr;
void * __restrict__ rotary_sin_ptr;
// The indices to index into the KV cache.
int * __restrict__ cache_batch_idx;
// Paged KV cache
int * __restrict__ block_table;
index_t block_table_batch_stride;
int page_block_size;
// The dropout probability (probability of keeping an activation).
float p_dropout;
// uint32_t p_dropout_in_uint;
// uint16_t p_dropout_in_uint16_t;
uint8_t p_dropout_in_uint8_t;
// Scale factor of 1 / (1 - p_dropout).
float rp_dropout;
float scale_softmax_rp_dropout;
// Local window size
int window_size_left, window_size_right;
// Random state.
at::PhiloxCudaState philox_args;
// Pointer to the RNG seed (idx 0) and offset (idx 1).
uint64_t * rng_state;
bool is_bf16;
bool is_e4m3;
bool is_causal;
// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
bool is_seqlens_k_cumulative;
bool is_rotary_interleaved;
int num_splits; // For split-KV version
void * __restrict__ alibi_slopes_ptr;
index_t alibi_slopes_batch_stride;
int * __restrict__ tile_count_semaphore;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Flash_bwd_params : public Flash_fwd_params {
// The dO and dQKV matrices.
void *__restrict__ do_ptr;
void *__restrict__ dq_ptr;
void *__restrict__ dk_ptr;
void *__restrict__ dv_ptr;
// To accumulate dQ
void *__restrict__ dq_accum_ptr;
void *__restrict__ dk_accum_ptr;
void *__restrict__ dv_accum_ptr;
// To accumulate dK and dV in case we're splitting the bwd along the seqlen_q dimension:
// void *__restrict__ dk_accum_ptr;
// void *__restrict__ dv_accum_ptr;
// The stride between rows of the dO, dQ, dK and dV matrices.
// TD [2022-04-16]: We're using 32-bit indexing to save registers.
// The code probably won't work for arrays larger than 2GB.
index_t do_batch_stride;
index_t do_row_stride;
index_t do_head_stride;
index_t dq_batch_stride;
index_t dk_batch_stride;
index_t dv_batch_stride;
index_t dq_row_stride;
index_t dk_row_stride;
index_t dv_row_stride;
index_t dq_head_stride;
index_t dk_head_stride;
index_t dv_head_stride;
// The pointer to the softmax d sum.
void *__restrict__ dsoftmax_sum;
int *__restrict__ dq_semaphore;
bool deterministic;
index_t dq_accum_split_stride;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, int Headdim> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
# Copyright (c) 2023, Tri Dao.
from typing import Optional, Union
import torch
import torch.nn as nn
# isort: off
# We need to import the CUDA kernels after importing torch
import flashattn_hopper_cuda
# isort: on
def _flash_attn_forward(q, k, v, softmax_scale, causal):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
out, q, k, v, out_padded, softmax_lse, S_dmask = flashattn_hopper_cuda.fwd(
q,
k,
v,
None,
softmax_scale,
causal,
)
return out, q, k, v, out_padded, softmax_lse, S_dmask
def _flash_attn_backward(
dout,
q,
k,
v,
out,
softmax_lse,
dq,
dk,
dv,
softmax_scale,
causal
):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
# dq, dk, dv are allocated by us so they should already be contiguous
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
dq, dk, dv, softmax_d = flashattn_hopper_cuda.bwd(
dout,
q,
k,
v,
out,
softmax_lse,
dq,
dk,
dv,
softmax_scale,
causal,
)
return dq, dk, dv, softmax_d
class FlashAttnFunc(torch.autograd.Function):
@staticmethod
def forward(
ctx,
q,
k,
v,
softmax_scale,
causal,
):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_forward(
q,
k,
v,
softmax_scale,
causal
)
ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
ctx.softmax_scale = softmax_scale
ctx.causal = causal
return out, softmax_lse
@staticmethod
def backward(ctx, dout, *args):
q, k, v, out, softmax_lse = ctx.saved_tensors
dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
_flash_attn_backward(
dout,
q,
k,
v,
out,
softmax_lse,
dq,
dk,
dv,
ctx.softmax_scale,
ctx.causal,
)
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dk = dk[..., : dout.shape[-1]]
dv = dv[..., : dout.shape[-1]]
return dq, dk, dv, None, None
def flash_attn_func(
q,
k,
v,
softmax_scale=None,
causal=False,
):
"""dropout_p should be set to 0.0 during evaluation
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.
If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1 1 1 1 0
1 1 1 1 1
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
0 0
0 0
0 0
1 0
1 1
If the row of the mask is all zero, the output will be zero.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
Arguments:
q: (batch_size, seqlen, nheads, headdim)
k: (batch_size, seqlen, nheads_k, headdim)
v: (batch_size, seqlen, nheads_k, headdim)
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
which is slightly slower and uses more memory. The forward pass is always deterministic.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
Return:
out: (batch_size, seqlen, nheads, headdim).
softmax_lse: (batch_size, nheads, seqlen). The logsumexp of each row of the matrix
QK^T * scaling (e.g., log of the softmax normalization factor).
"""
return FlashAttnFunc.apply(
q,
k,
v,
softmax_scale,
causal,
)
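# Reference sketch (fp32, plain PyTorch) of what flash_attn_func computes; useful for
# checking outputs. Assumes the tensor layouts documented above; not part of the
# released interface, and the helper name below is illustrative.
def _attention_reference(q, k, v, softmax_scale, causal):
    # q: (batch, seqlen_q, nheads, headdim); k, v: (batch, seqlen_k, nheads_k, headdim),
    # with nheads divisible by nheads_k (MQA/GQA).
    batch, seqlen_q, nheads, _ = q.shape
    seqlen_k, nheads_k = k.shape[1], k.shape[2]
    k = k.repeat_interleave(nheads // nheads_k, dim=2)
    v = v.repeat_interleave(nheads // nheads_k, dim=2)
    scores = torch.einsum("bqhd,bkhd->bhqk", q.float(), k.float()) * softmax_scale
    if causal:
        # Causal mask aligned to the bottom-right corner of the attention matrix.
        row = torch.arange(seqlen_q, device=q.device)[:, None]
        col = torch.arange(seqlen_k, device=q.device)[None, :]
        scores.masked_fill_(col > row + seqlen_k - seqlen_q, float("-inf"))
    lse = torch.logsumexp(scores, dim=-1)  # (batch, nheads, seqlen_q), matches softmax_lse
    out = torch.einsum("bhqk,bkhd->bqhd", scores.softmax(dim=-1), v.float())
    # Note: rows that are fully masked out give NaN here, whereas the kernel writes zeros.
    return out.to(q.dtype), lse
# Example check against the fused kernel (fp16 tolerances are illustrative):
# out, lse = flash_attn_func(q, k, v, causal=True)
# out_ref, lse_ref = _attention_reference(q, k, v, q.shape[-1] ** (-0.5), causal=True)
# torch.testing.assert_close(out, out_ref, atol=2e-3, rtol=2e-3)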
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_bwd_launch_template.h"
template<>
void run_mha_bwd_<cutlass::half_t, 128>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim128<cutlass::half_t>(params, stream);
}
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_bwd_launch_template.h"
template<>
void run_mha_bwd_<cutlass::half_t, 256>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim256<cutlass::half_t>(params, stream);
}
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_bwd_launch_template.h"
template<>
void run_mha_bwd_<cutlass::half_t, 64>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim64<cutlass::half_t>(params, stream);
}
/******************************************************************************
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
******************************************************************************/
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include "cute/tensor.hpp"
#include "cutlass/cluster_launch.hpp"
#include "static_switch.h"
#include "flash.h"
#include "flash_bwd_preprocess_kernel.h"
#include "flash_bwd_kernel.h"
#include "kernel_traits.h"
template<bool Clear_dQaccum=true, typename Kernel_traits>
__global__ void flash_bwd_dot_do_o_kernel(const Flash_bwd_params params) {
flash::compute_dot_do_o<Clear_dQaccum, Kernel_traits>(params);
}
// template<typename Kernel_traits>
// __global__ void flash_bwd_convert_dq_kernel(const Flash_bwd_params params, const int nsplits) {
// flash::convert_dQ<Kernel_traits>(params, nsplits);
// }
template<typename Kernel_traits>
__global__ void flash_bwd_convert_dkv_kernel(const Flash_bwd_params params) {
flash::convert_dKV<Kernel_traits>(params);
}
template<typename Kernel_traits, bool Is_causal>
void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
int num_m_block = cute::ceil_div(params.seqlen_q, Kernel_traits::kBlockM);
dim3 grid_m(num_m_block, params.b, params.h);
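// One block per (seqlen_q tile, batch, head): this preprocessing kernel computes the
// row-wise dot product of dO and O and, with Clear_dQaccum = true, zeroes dQaccum.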
flash_bwd_dot_do_o_kernel<true, Kernel_traits><<<grid_m, Kernel_traits::kNThreadsNonWS, 0, stream>>>(params);
// If we use both TMA_STORE (for n_block=0) and TMA_REDUCE_ADD (for n_block>0), we don't need to clear dQaccum
// flash_bwd_dot_do_o_kernel<false, Kernel_traits><<<grid_m, Kernel_traits::kNThreadsNonWS, 0, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
using Element = typename Kernel_traits::Element;
using ElementAccum = typename Kernel_traits::ElementAccum;
using TileShape_MNK = typename Kernel_traits::TileShape_MNK;
using ClusterShape = typename Kernel_traits::ClusterShape_MNK;
Tensor mQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.q_ptr)),
make_shape(params.seqlen_q, params.d, params.h, params.b),
make_stride(params.q_row_stride, _1{}, params.q_head_stride, params.q_batch_stride));
auto tma_load_Q = make_tma_copy(
typename Kernel_traits::GmemTiledCopyQdO{},
mQ,
typename Kernel_traits::SmemLayoutQ{}(_, _, _0{}),
// typename Kernel_traits::SmemLayoutQ{},
select<0, 2>(TileShape_MNK{}),
size<1>(ClusterShape{})); // mcast along N mode for this M load, if any
Tensor mdO = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.do_ptr)),
make_shape(params.seqlen_q, params.d, params.h, params.b),
make_stride(params.do_row_stride, _1{}, params.do_head_stride, params.do_batch_stride));
auto tma_load_dO = make_tma_copy(
typename Kernel_traits::GmemTiledCopyQdO{},
mdO,
typename Kernel_traits::SmemLayoutdO{}(_, _, _0{}),
// typename Kernel_traits::SmemLayoutdO{},
select<0, 2>(TileShape_MNK{}),
size<1>(ClusterShape{})); // mcast along N mode for this M load, if any
Tensor mK = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.k_ptr)),
make_shape(params.seqlen_k, params.d, params.h, params.b),
make_stride(params.k_row_stride, _1{}, params.k_head_stride, params.k_batch_stride));
auto tma_load_K = make_tma_copy(
typename Kernel_traits::GmemTiledCopyKV{},
mK,
typename Kernel_traits::SmemLayoutK{},
// typename Kernel_traits::SmemLayoutK{}(_, _, _0{}),
select<1, 2>(TileShape_MNK{}),
_1{}); // no mcast for K
Tensor mV = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.v_ptr)),
make_shape(params.seqlen_k, params.d, params.h, params.b),
make_stride(params.v_row_stride, _1{}, params.v_head_stride, params.v_batch_stride));
auto tma_load_V = make_tma_copy(
typename Kernel_traits::GmemTiledCopyKV{},
mV,
typename Kernel_traits::SmemLayoutV{},
// typename Kernel_traits::SmemLayoutV{}(_, _, _0{}),
select<1, 2>(TileShape_MNK{}),
_1{}); // no mcast for V
Tensor mdK = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.dk_ptr)),
make_shape(params.seqlen_k, params.d, params.h, params.b),
make_stride(params.dk_row_stride, _1{}, params.dk_head_stride, params.dk_batch_stride));
auto tma_store_dK = make_tma_copy(
typename Kernel_traits::GmemTiledCopydKV{},
mdK,
typename Kernel_traits::SmemLayoutdK{},
select<1, 2>(TileShape_MNK{}),
_1{}); // no mcast for output
Tensor mdV = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.dv_ptr)),
make_shape(params.seqlen_k, params.d, params.h, params.b),
make_stride(params.dv_row_stride, _1{}, params.dv_head_stride, params.dv_batch_stride));
auto tma_store_dV = make_tma_copy(
typename Kernel_traits::GmemTiledCopydKV{},
mdV,
typename Kernel_traits::SmemLayoutdV{},
select<1, 2>(TileShape_MNK{}),
_1{}); // no mcast for output
Tensor mdQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.dq_ptr)),
make_shape(params.seqlen_q, params.d, params.h, params.b),
make_stride(params.dq_row_stride, _1{}, params.dq_head_stride, params.dq_batch_stride));
Tensor mdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.dq_accum_ptr)),
make_shape(params.seqlen_q, params.d, params.h, params.b),
make_stride(params.d * params.h, _1{}, params.d, params.d * params.h * params.seqlen_q_rounded));
auto tma_store_dQaccum = make_tma_copy(
// typename Kernel_traits::GmemTiledCopydKV{},
typename cute::SM90_TMA_STORE{},
// mdQ,
mdQaccum,
// typename Kernel_traits::SmemLayoutdQTMA{},
typename Kernel_traits::SmemLayoutdQaccTMA{},
select<0, 2>(TileShape_MNK{}),
_1{}); // no mcast for output
auto tma_reduce_add_dQaccum = make_tma_copy(
// typename Kernel_traits::GmemTiledCopydKV{},
typename cute::SM90_TMA_REDUCE_ADD{},
// mdQ,
mdQaccum,
// typename Kernel_traits::SmemLayoutdQTMA{},
typename Kernel_traits::SmemLayoutdQaccTMA{},
select<0, 2>(TileShape_MNK{}),
_1{}); // no mcast for output
// print(typename Kernel_traits::SmemLayoutVt{}); printf("\n"); print(typename Kernel_traits::SmemLayoutVt_tmp{});
// print(typename Kernel_traits::TiledMmaSdP{}); printf("\n");
// print(typename Kernel_traits::TiledMmadKV{}); printf("\n");
// print(typename Kernel_traits::TiledMmadQ{}); printf("\n");
// print(typename Kernel_traits::SmemLayoutAtomK{}); printf("\n");
// print(typename Kernel_traits::SmemLayoutK{}); printf("\n");
// print(typename Kernel_traits::SmemLayoutKt{}); printf("\n");
// Get the ptr to kernel function.
void *kernel;
if constexpr (!Kernel_traits::Is_WS) {
kernel = (void *)flash::compute_dqkv<Kernel_traits, Is_causal, decltype(tma_load_Q), decltype(tma_load_dO),
decltype(tma_load_K), decltype(tma_load_V), decltype(tma_store_dK), decltype(tma_store_dV)>;
} else {
kernel = (void *)flash::compute_dqkv_ws<Kernel_traits, Is_causal, decltype(tma_load_Q), decltype(tma_load_dO),
decltype(tma_load_K), decltype(tma_load_V), decltype(tma_store_dK), decltype(tma_store_dV), decltype(tma_store_dQaccum), decltype(tma_reduce_add_dQaccum)>;
}
// void *kernel = (void *)flash::compute_dqkv_seqqpar<Kernel_traits, Is_causal, decltype(tma_load_Q), decltype(tma_load_dO),
// decltype(tma_load_K), decltype(tma_load_V), decltype(tma_store_dQaccum), decltype(tma_store_dK), decltype(tma_store_dV)>;
auto shared_storage = typename Kernel_traits::SharedStorage{};
int smem_size = sizeof(typename Kernel_traits::SharedStorage);
int smem_size_q = sizeof(decltype(shared_storage.smem_q));
int smem_size_do = sizeof(decltype(shared_storage.smem_do));
int smem_size_k = sizeof(decltype(shared_storage.smem_k));
int smem_size_v = sizeof(decltype(shared_storage.smem_v));
// int smem_size_p = sizeof(decltype(shared_storage.smem_p));
int smem_size_ds = sizeof(decltype(shared_storage.smem_ds));
// printf("smem_size = %d, q = %d, do = %d, k = %d, v = %d, p = %d, ds = %d\n", smem_size, smem_size_q, smem_size_do, smem_size_k, smem_size_v, smem_size_p, smem_size_ds);
// printf("smem_size = %d, q = %d, do = %d, k = %d, v = %d, ds = %d\n", smem_size, smem_size_q, smem_size_do, smem_size_k, smem_size_v, smem_size_ds);
if (smem_size >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
static constexpr int ctaSize = Kernel_traits::kNWarps * 32;
int num_blocks_n = cutlass::ceil_div(params.seqlen_k, Kernel_traits::kBlockN);
num_blocks_n = cutlass::ceil_div(num_blocks_n, size<1>(ClusterShape{})) * size<1>(ClusterShape{});
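// Round the number of N blocks up to a multiple of the cluster size along N so that
// every launched cluster is fully populated.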
dim3 grid_dims(num_blocks_n, params.h, params.b);
// int num_blocks_m = cutlass::ceil_div(params.seqlen_q, Kernel_traits::kBlockM);
// num_blocks_m = cutlass::ceil_div(num_blocks_m, size<0>(ClusterShape{})) * size<0>(ClusterShape{});
// dim3 grid_dims(num_blocks_m, params.h, params.b);
dim3 block_dims(ctaSize);
dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream};
if constexpr (!Kernel_traits::Is_WS) {
cutlass::launch_kernel_on_cluster(launch_params, kernel, params, tma_load_Q, tma_load_dO,
tma_load_K, tma_load_V, tma_store_dK, tma_store_dV);
} else {
cutlass::launch_kernel_on_cluster(launch_params, kernel, params, tma_load_Q, tma_load_dO,
tma_load_K, tma_load_V, tma_store_dK, tma_store_dV, tma_store_dQaccum, tma_reduce_add_dQaccum);
}
// cutlass::launch_kernel_on_cluster(launch_params, kernel, params, tma_load_Q, tma_load_dO,
// tma_load_K, tma_load_V, tma_store_dQaccum, tma_store_dK, tma_store_dV);
C10_CUDA_KERNEL_LAUNCH_CHECK();
auto tma_load_dQaccum = make_tma_copy(
typename cute::SM90_TMA_LOAD{},
mdQaccum,
typename Kernel_traits::SmemLayoutdQaccTMA{},
select<0, 2>(TileShape_MNK{}),
_1{}); // no mcast for output
// auto kernel_dq = &flash_bwd_convert_dq_kernel<Kernel_traits>;
auto kernel_dq = &flash::convert_dQ<Kernel_traits, decltype(tma_load_dQaccum)>;
if (Kernel_traits::kSmemdQSize * 2 + 8 >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize * 2 + 8));
}
kernel_dq<<<grid_m, Kernel_traits::kNThreadsdQ, Kernel_traits::kSmemdQSize * 2 + 8, stream>>>(params, tma_load_dQaccum);
C10_CUDA_KERNEL_LAUNCH_CHECK();
// auto kernel_dkv = &flash_bwd_convert_dkv_kernel<Kernel_traits>;
// if (Kernel_traits::kSmemdKVSize >= 48 * 1024) {
// C10_CUDA_CHECK(cudaFuncSetAttribute(
// kernel_dkv, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdKVSize));
// }
// int num_n_block = cute::ceil_div(params.seqlen_k, Kernel_traits::kBlockN);
// dim3 grid_n(num_n_block, params.b, params.h);
// kernel_dkv<<<grid_n, Kernel_traits::kNThreads, Kernel_traits::kSmemdKVSize, stream>>>(params);
// C10_CUDA_KERNEL_LAUNCH_CHECK();
}
template<typename T>
void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 64;
// BOOL_SWITCH(params.is_causal, Is_causal, [&] {
// run_flash_bwd<T, Headdim, Is_causal>(params, stream);
// });
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, false, false, false, 2, 2, 2, 1, T>, false>(params, stream);
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 12, true, false, false, 1, 2, 2, 1, T>, false>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 96, 128, 12, true, false, true, 1, 2, 2, 1, T>, false>(params, stream);
}
template<typename T>
void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 128;
// BOOL_SWITCH(params.is_causal, Is_causal, [&] {
// run_flash_bwd<T, Headdim, Is_causal>(params, stream);
// });
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, false, 2, 1, 2, 1, T>, false>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, false, false, false, 1, 2, 1, 1, T>, false>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 96, 8, false, true, false, 2, 1, 2, 1, T>, false>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 96, 8, false, true, true, 2, 1, 1, 1, T>, false>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, true, false, true, 1, 2, 1, 1, T>, false>(params, stream);
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 12, true, false, true, 1, 2, 1, 1, T>, false>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 12, true, false, false, 1, 2, 1, 1, T>, false>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 12, false, false, false, 1, 2, 1, 1, T>, false>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 80, 128, 12, true, false, true, 1, 2, 1, 1, T>, false>(params, stream);
// run_flash_bwd<Flash_bwd_seqqpar_kernel_traits<Headdim, 128, 64, 8, false, true, false, 2, 1, 2, 1, T>, false>(params, stream);
// run_flash_bwd<Flash_bwd_seqqpar_kernel_traits<Headdim, 96, 128, 8, true, false, true, 1, 2, 1, 1, T>, false>(params, stream);
}
template<typename T>
void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
// constexpr static int Headdim = 256;
// BOOL_SWITCH(params.is_causal, Is_causal, [&] {
// run_flash_bwd<T, Headdim, Is_causal>(params, stream);
// });
}
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_fwd_launch_template.h"
template<>
void run_mha_fwd_<cutlass::half_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim128<cutlass::half_t>(params, stream);
}
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_fwd_launch_template.h"
template<>
void run_mha_fwd_<cutlass::half_t, 256>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim256<cutlass::half_t>(params, stream);
}
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_fwd_launch_template.h"
template<>
void run_mha_fwd_<cutlass::half_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim64<cutlass::half_t>(params, stream);
}
/******************************************************************************
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
******************************************************************************/
#pragma once
#include "cute/tensor.hpp"
#include <cutlass/cutlass.h>
#include <cutlass/arch/reg_reconfig.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
#include <cutlass/numeric_conversion.h>
#include "cutlass/pipeline/pipeline.hpp"
#include "flash.h"
#include "utils.h"
#include "softmax.h"
#include "tile_scheduler.hpp"
#include "mainloop_fwd_sm90_tma_gmma_ws.hpp"
#include "epilogue_fwd_sm90_tma.hpp"
namespace flash {
using namespace cute;
template <typename Ktraits, bool Is_causal, typename TileScheduler>
__global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1)
compute_attn_ws(CUTE_GRID_CONSTANT Flash_fwd_params const params,
CUTE_GRID_CONSTANT typename CollectiveMainloopFwd<Ktraits, Is_causal>::Params const mainloop_params,
CUTE_GRID_CONSTANT typename CollectiveEpilogueFwd<Ktraits>::Params const epilogue_params,
CUTE_GRID_CONSTANT typename TileScheduler::Params const scheduler_params
) {
using Element = typename Ktraits::Element;
using ElementAccum = typename Ktraits::ElementAccum;
using SoftType = ElementAccum;
using TileShape_MNK = typename Ktraits::TileShape_MNK;
using ClusterShape = typename Ktraits::ClusterShape_MNK;
static_assert(Ktraits::Is_WS);
static constexpr bool Is_WS = Ktraits::Is_WS;
static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma0{});
static constexpr int NumCopyThreads = !Is_WS ? 0 : cutlass::NumThreadsPerWarpGroup;
static constexpr int kBlockM = Ktraits::kBlockM;
// static constexpr int kBlockN = Ktraits::kBlockN;
// constexpr int kHeadDim = Ktraits::kHeadDim;
using CollectiveMainloop = CollectiveMainloopFwd<Ktraits, Is_causal>;
using CollectiveEpilogue = CollectiveEpilogueFwd<Ktraits>;
using MainloopPipeline = typename Ktraits::MainloopPipeline;
using PipelineParams = typename MainloopPipeline::Params;
using PipelineState = typename MainloopPipeline::PipelineState;
extern __shared__ char shared_memory[];
auto &shared_storage = *reinterpret_cast<typename Ktraits::SharedStorage*>(shared_memory);
int const lane_predicate = cute::elect_one_sync();
int const warp_idx = cutlass::canonical_warp_idx_sync();
// Issue Tma Descriptor Prefetch from a single thread
if (warp_idx == 0 && lane_predicate) {
CollectiveMainloop::prefetch_tma_descriptors(mainloop_params);
CollectiveEpilogue::prefetch_tma_descriptors(epilogue_params);
}
// Obtain warp index
int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
PipelineParams pipeline_params;
pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK;
int warp_group_idx = cutlass::canonical_warp_group_idx();
pipeline_params.role = warp_group_idx == 0
? MainloopPipeline::ThreadCategory::Producer
: MainloopPipeline::ThreadCategory::Consumer;
pipeline_params.is_leader = warp_group_thread_idx == 0;
pipeline_params.num_consumers = NumMmaThreads;
if (warp_idx == 0 && lane_predicate) {
shared_storage.barrier_Q.init(1 /*numThreads*/);
shared_storage.barrier_O.init(size(ClusterShape{}) /*numThreads*/);
}
// We're counting on pipeline_k to call cutlass::arch::fence_barrier_init();
MainloopPipeline pipeline_k(shared_storage.pipeline_k, pipeline_params, ClusterShape{});
MainloopPipeline pipeline_v(shared_storage.pipeline_v, pipeline_params, ClusterShape{});
CollectiveMainloop collective_mainloop;
CollectiveEpilogue collective_epilogue;
// We need this to guarantee that the Pipeline init is visible to all producers and consumer blocks in the Cluster
if constexpr (size(ClusterShape{}) > 1) {
cute::cluster_arrive_relaxed();
cute::cluster_wait();
} else {
__syncthreads();
}
static_assert(Ktraits::kNWarps == 12 || Ktraits::kNWarps == 16);
if (warp_group_idx == 0) { // Producer
cutlass::arch::warpgroup_reg_dealloc<Ktraits::kNWarps == 12 ? 24 : 32>();
// cutlass::arch::warpgroup_reg_dealloc<56>();
// StaticPersistentTileScheduler scheduler{params.m_block_divmod, params.head_divmod, params.total_blocks};
// auto work_tile_info = scheduler.get_current_work();
TileScheduler scheduler;
int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
if (warp_idx_in_warpgroup == 0) { // Load Q, K, V
PipelineState smem_pipe_write_k = cutlass::make_producer_start_state<MainloopPipeline>();
PipelineState smem_pipe_write_v = cutlass::make_producer_start_state<MainloopPipeline>();
int work_idx = 0;
// auto get_tile_count = [&] () {
// cutlass::arch::NamedBarrier::sync(NumMmaThreads + 2 * cutlass::NumThreadsPerWarp, 10 /*id*/);
// return shared_storage.tile_count_semaphore;
// };
// while (work_tile_info.is_valid()) {
// for (int tile_count = blockIdx.x; tile_count < params.total_blocks; tile_count = get_tile_count()) {
// for (int tile_count_semaphore = blockIdx.x; tile_count_semaphore < params.total_blocks; tile_count_semaphore = __shfl_sync(0xffffffff, tile_count_semaphore, 0)) {
for (auto work_tile_info = scheduler.get_initial_work(); work_tile_info.is_valid(scheduler_params); work_tile_info = scheduler.get_next_work(scheduler_params, work_tile_info)) {
int tile_count_semaphore = 0;
collective_mainloop.load(params, mainloop_params, scheduler_params, pipeline_k, pipeline_v, smem_pipe_write_k, smem_pipe_write_v,
shared_storage, work_tile_info, work_idx, tile_count_semaphore);
// ++work_idx;
// work_tile_info = scheduler.fetch_next_work();
}
collective_mainloop.load_tail(pipeline_k, pipeline_v, smem_pipe_write_k, smem_pipe_write_v);
}
} else { // Consumer
cutlass::arch::warpgroup_reg_alloc<Ktraits::kNWarps == 12 ? 240 : 160>();
// cutlass::arch::warpgroup_reg_alloc<Ktraits::kNWarps == 12 ? 224 : 160>();
// Initialize matmul objects.
typename Ktraits::TiledMma1 tiled_mma1;
TileScheduler scheduler{};
PipelineState smem_pipe_read_k, smem_pipe_read_v;
// We don't need separate variables smem_pipe_release_k and smem_pipe_release_v
// (like in Cutlass's gemm) because the read and release pipeline states are always the same.
auto get_tile_count = [&] () {
// cutlass::arch::NamedBarrier::sync(NumMmaThreads + 2 * cutlass::NumThreadsPerWarp, 10 /*id*/);
cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, 10 /*id*/);
return shared_storage.tile_count_semaphore;
};
collective_mainloop.mma_init();
int work_idx = 0;
CUTLASS_PRAGMA_NO_UNROLL
// for (int work_idx = 0; work_idx * gridDim.x + blockIdx.x < params.total_blocks; ++work_idx) {
// for (int tile_count_semaphore = blockIdx.x, work_idx = 0; tile_count_semaphore < params.total_blocks; tile_count_semaphore = get_tile_count()) {
for (auto work_tile_info = scheduler.get_initial_work(); work_tile_info.is_valid(scheduler_params); work_tile_info = scheduler.get_next_work(scheduler_params, work_tile_info)) {
// Attention output (GEMM-II) accumulator.
Tensor tOrO = partition_fragment_C(tiled_mma1, select<0, 2>(TileShape_MNK{}));
flash::Softmax<2 * (2 * kBlockM / NumMmaThreads)> softmax;
// int m_block;
// int bidh, bidb;
// // bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(m_block, work_idx * gridDim.x + blockIdx.x));
// bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(m_block, tile_count_semaphore));
// cute::tuple<int32_t, int32_t, int32_t> block_coord = {m_block, bidh, bidb};
auto block_coord = work_tile_info.get_block_coord(scheduler_params);
auto [m_block, bidh, bidb] = block_coord;
int n_block_max = collective_mainloop.get_n_block_max(mainloop_params, m_block);
if (Is_causal && n_block_max <= 0) { // We exit early and write 0 to gO and -inf to gLSE.
// Need sync to avoid the case where the producer issues 2 arrives before the consumer can issue 1 wait
cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, 7 /*id*/);
collective_epilogue.store_zero(epilogue_params, threadIdx.x - NumCopyThreads, block_coord);
continue;
}
collective_mainloop.mma(mainloop_params, pipeline_k, pipeline_v, smem_pipe_read_k, smem_pipe_read_v,
tOrO, softmax, n_block_max, threadIdx.x - NumCopyThreads, work_idx, m_block, shared_storage);
// tOrO, softmax, n_block_max, threadIdx.x - NumCopyThreads + (work_idx >> 30), work_idx, shared_storage);
// tOrO, softmax, n_block_max, threadIdx.x - NumCopyThreads, 0, shared_storage);
collective_epilogue.store(epilogue_params, tOrO, softmax.row_sum, shared_storage, tiled_mma1,
threadIdx.x - NumCopyThreads, block_coord);
++work_idx;
// work_tile_info = scheduler.fetch_next_work();
}
collective_epilogue.store_tail();
}
}
} // namespace flash
/******************************************************************************
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
******************************************************************************/
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include "cute/tensor.hpp"
#include "cutlass/cluster_launch.hpp"
#include "static_switch.h"
#include "flash.h"
#include "tile_scheduler.hpp"
#include "flash_fwd_kernel.h"
#include "kernel_traits.h"
template<typename Kernel_traits, bool Is_causal>
void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
using Element = typename Kernel_traits::Element;
using TileShape_MNK = typename Kernel_traits::TileShape_MNK;
using ClusterShape = typename Kernel_traits::ClusterShape_MNK;
// print(typename Kernel_traits::SmemLayoutVt{}); printf("\n"); print(typename Kernel_traits::SmemLayoutVt_tmp{});
using CollectiveMainloop = flash::CollectiveMainloopFwd<Kernel_traits, Is_causal>;
using CollectiveEpilogue = flash::CollectiveEpilogueFwd<Kernel_traits>;
// using Scheduler = flash::SingleTileScheduler;
using Scheduler = flash::StaticPersistentTileScheduler;
typename CollectiveMainloop::Params mainloop_params =
CollectiveMainloop::to_underlying_arguments({
static_cast<Element const*>(params.q_ptr),
{params.seqlen_q, params.d, params.h, params.b}, // shape_Q
{params.q_row_stride, _1{}, params.q_head_stride, params.q_batch_stride}, // stride_Q
static_cast<Element const*>(params.k_ptr),
{params.seqlen_k, params.d, params.h_k, params.b}, // shape_K
{params.k_row_stride, _1{}, params.k_head_stride, params.k_batch_stride}, // stride_K
static_cast<Element const*>(params.v_ptr),
{params.v_row_stride, _1{}, params.v_head_stride, params.v_batch_stride}, // stride_V
params.scale_softmax_log2
});
typename CollectiveEpilogue::Params epilogue_params =
CollectiveEpilogue::to_underlying_arguments({
static_cast<Element*>(params.o_ptr),
{params.seqlen_q, params.d, params.h, params.b}, // shape_O
{params.o_row_stride, _1{}, params.o_head_stride, params.o_batch_stride}, // stride_O
static_cast<float*>(params.softmax_lse_ptr),
{_1{}, params.seqlen_q, params.h * params.seqlen_q}, // stride_LSE
});
int num_blocks_m = cutlass::ceil_div(params.seqlen_q, Kernel_traits::kBlockM);
num_blocks_m = cutlass::ceil_div(num_blocks_m, size<0>(ClusterShape{})) * size<0>(ClusterShape{});
typename Scheduler::Arguments scheduler_args = {num_blocks_m, params.h, params.b};
typename Scheduler::Params scheduler_params = Scheduler::to_underlying_arguments(scheduler_args);
// Get the ptr to kernel function.
void *kernel;
kernel = (void *)flash::compute_attn_ws<Kernel_traits, Is_causal, Scheduler>;
int smem_size = sizeof(typename Kernel_traits::SharedStorage);
int smem_size_q = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_q));
int smem_size_k = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_k));
int smem_size_v = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_v));
// printf("smem_size = %d, q = %d, k = %d, v = %d\n", smem_size, smem_size_q, smem_size_k, smem_size_v);
if (smem_size >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
static constexpr int ctaSize = Kernel_traits::kNWarps * 32;
params.m_block_divmod = cutlass::FastDivmod(num_blocks_m);
params.total_blocks = num_blocks_m * params.h * params.b;
// dim3 grid_dims(num_blocks_m, params.h, params.b);
// dim3 grid_dims(132);
dim3 grid_dims = Scheduler::get_grid_dim(scheduler_args, 132);
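// 132 matches the SM count of an H100 SXM5 GPU: the persistent tile scheduler launches
// a fixed grid and each CTA loops over work tiles.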
dim3 block_dims(ctaSize);
dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream};
cutlass::launch_kernel_on_cluster(launch_params, kernel, params, mainloop_params, epilogue_params, scheduler_params);
// kernel<<<grid_dims, block_dims, smem_size, stream>>>(params, tma_load_Q, tma_load_K, tma_load_V, tma_store_O);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
template<typename T>
void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 64;
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 192, 128, 16, 2, false, 1, T>, Is_causal>(params, stream);
});
}
template<typename T>
void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 128;
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, Is_causal ? 128 : 176, 12, 2, false, !Is_causal ? 2 : 1, T>, Is_causal>(params, stream);
});
}
template<typename T>
void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 256;
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 80, 12, 2, false, !Is_causal ? 2 : 1, T>, Is_causal>(params, stream);
});
}