Commit 37c6e054 authored by Tri Dao

Implement flash_attn_with_kvcache

parent 4976650f
@@ -102,6 +102,7 @@ void set_params_fprop(Flash_fwd_params &params,
TORCH_CHECK(p_dropout < 1.f);
params.is_causal = is_causal;
params.is_seqlens_k_cumulative = true;
}
void set_params_dgrad(Flash_bwd_params &params,
@@ -175,10 +176,10 @@ void set_params_dgrad(Flash_bwd_params &params,
params.dsoftmax_sum = dsoftmax_sum_d;
}
void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
FP16_SWITCH(!params.is_bf16, [&] {
FWD_HEADDIM_SWITCH(params.d, [&] {
if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0
run_mha_fwd_<elem_type, kHeadDim>(params, stream);
} else {
run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim>(params, stream);
@@ -350,7 +351,7 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
const int num_m_blocks = (seqlen_q + 64 - 1) / 64;
params.num_splits = 1;
if (p_dropout == 0.0f) { // SplitKV is not implemented for dropout
params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 128);
if (params.num_splits > 1) {
at::Tensor softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
at::Tensor out_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
@@ -990,10 +991,198 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
return { dq, dk, dv, softmax_d };
}
std::vector<at::Tensor>
mha_fwd_kvcache(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &kcache, // batch_size x seqlen_k x num_heads_k x head_size
const at::Tensor &vcache, // batch_size x seqlen_k x num_heads_k x head_size
c10::optional<const at::Tensor> &k_, // batch_size x seqlen_q x num_heads_k x head_size
c10::optional<const at::Tensor> &v_, // batch_size x seqlen_q x num_heads_k x head_size
c10::optional<const at::Tensor> &seqlens_k_, // batch_size
c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
const float softmax_scale,
const bool is_causal,
int num_splits
) {
auto dprops = at::cuda::getCurrentDeviceProperties();
// bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
// We will support Turing in the near future
// TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");
auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
"FlashAttention only support fp16 and bf16 data type");
if (q_dtype == torch::kBFloat16) {
TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
}
TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
TORCH_CHECK(vcache.dtype() == q_dtype, "query and value must have the same dtype");
TORCH_CHECK(q.is_cuda(), "Input tensor must be on CUDA device");
TORCH_CHECK(kcache.is_cuda(), "Input tensor must be on CUDA device");
TORCH_CHECK(vcache.is_cuda(), "Input tensor must be on CUDA device");
TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
const auto sizes = q.sizes();
const int batch_size = sizes[0];
const int seqlen_q = sizes[1];
const int num_heads = sizes[2];
const int head_size_og = sizes[3];
const int seqlen_k = kcache.size(1);
const int num_heads_k = kcache.size(2);
TORCH_CHECK(batch_size > 0, "batch size must be positive");
TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
CHECK_SHAPE(kcache, batch_size, seqlen_k, num_heads_k, head_size_og);
CHECK_SHAPE(vcache, batch_size, seqlen_k, num_heads_k, head_size_og);
at::Tensor q_padded, kcache_padded, vcache_padded;
if (head_size_og % 8 != 0) {
q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
kcache_padded = torch::nn::functional::pad(kcache, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
vcache_padded = torch::nn::functional::pad(vcache, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
} else {
q_padded = q;
kcache_padded = kcache;
vcache_padded = vcache;
}
at::Tensor out;
if (out_.has_value()) {
out = out_.value();
TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
TORCH_CHECK(out.is_cuda(), "Output tensor must be on CUDA device");
TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
} else {
out = torch::empty_like(q_padded);
}
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
const int head_size = round_multiple(head_size_og, 8);
const int head_size_rounded = round_multiple(head_size, 32);
const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
auto opts = q.options();
auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
Flash_fwd_params params;
set_params_fprop(params,
batch_size,
seqlen_q, seqlen_k,
seqlen_q_rounded, seqlen_k_rounded,
num_heads, num_heads_k,
head_size, head_size_rounded,
q_padded, kcache_padded, vcache_padded, out,
/*cu_seqlens_q_d=*/nullptr,
/*cu_seqlens_k_d=*/nullptr,
/*p_ptr=*/nullptr,
softmax_lse.data_ptr(),
/*p_dropout=*/0.f,
softmax_scale,
is_causal);
at::Tensor k, v, k_padded, v_padded;
if (k_.has_value()) {
TORCH_CHECK(v_.has_value(), "If key is supplied, value must also be passed in");
TORCH_CHECK(seqlens_k_.has_value(), "If key is supplied, seqlens_k must also be passed in");
TORCH_CHECK(seqlen_q <= seqlen_k, "If key is supplied, it must have seqlen <= the seqlen of the KV cache");
k = k_.value();
v = v_.value();
TORCH_CHECK(k.dtype() == q_dtype, "Key must have the same dtype as query");
TORCH_CHECK(v.dtype() == q_dtype, "Value must have the same dtype as query");
TORCH_CHECK(k.is_cuda(), "Key tensor must be on CUDA device");
TORCH_CHECK(v.is_cuda(), "Value tensor must be on CUDA device");
TORCH_CHECK(k.stride(-1) == 1, "Key tensor must have contiguous last dimension");
TORCH_CHECK(v.stride(-1) == 1, "Value tensor must have contiguous last dimension");
CHECK_SHAPE(k, batch_size, seqlen_q, num_heads_k, head_size_og);
CHECK_SHAPE(v, batch_size, seqlen_q, num_heads_k, head_size_og);
if (head_size_og % 8 != 0) {
k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
} else {
k_padded = k;
v_padded = v;
}
params.knew_ptr = k_padded.data_ptr();
params.vnew_ptr = v_padded.data_ptr();
// All stride are in elements, not bytes.
params.knew_batch_stride = k_padded.stride(0);
params.vnew_batch_stride = v_padded.stride(0);
params.knew_row_stride = k_padded.stride(-3);
params.vnew_row_stride = v_padded.stride(-3);
params.knew_head_stride = k_padded.stride(-2);
params.vnew_head_stride = v_padded.stride(-2);
}
if (seqlens_k_.has_value()) {
auto seqlens_k = seqlens_k_.value();
TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
TORCH_CHECK(seqlens_k.is_cuda(), "seqlens_k must be on CUDA device");
TORCH_CHECK(seqlens_k.is_contiguous(), "seqlens_k must be contiguous");
CHECK_SHAPE(seqlens_k, batch_size);
params.cu_seqlens_k = static_cast<int *>(seqlens_k.data_ptr());
}
params.is_seqlens_k_cumulative = !(seqlens_k_.has_value());
// This needs to match with run_mha_fwd_splitkv_dispatch
const int block_n = is_sm90 || is_sm8x
? (head_size <= 64 ? 256 : (head_size <= 160 ? 128 : 64))
: (head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64));
const int num_n_blocks = (seqlen_k + (params.knew_ptr == nullptr ? 0 : seqlen_q) + block_n - 1) / block_n;
// Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.
// In any case we don't expect seqlen_q to be larger than 64 for inference.
const int num_m_blocks = (seqlen_q + 64 - 1) / 64;
params.num_splits = num_splits;
if (num_splits < 1) {
params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 128);
}
if (params.num_splits > 1) {
at::Tensor softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
at::Tensor out_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
params.oaccum_ptr = out_accum.data_ptr();
}
auto stream = at::cuda::getCurrentCUDAStream().stream();
// Only split kernel supports appending to KV cache
run_mha_fwd(params, stream, /*force_split_kernel=*/k_.has_value());
if (head_size_og % 8 != 0) {
out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
if (out_.has_value()) { out_.value().copy_(out); }
if (k_.has_value()) {
// It's expensive to copy the KV cache here for the case where head size not divisible by 8,
// but we don't expect to get this case in practice. This is just so that the code works for that case.
kcache.copy_(kcache_padded.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}));
vcache.copy_(vcache_padded.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}));
}
}
return {out, softmax_lse};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "FlashAttention";
m.def("fwd", &mha_fwd, "Forward pass");
m.def("varlen_fwd", &mha_varlen_fwd, "Forward pass (variable length)");
m.def("bwd", &mha_bwd, "Backward pass");
m.def("varlen_bwd", &mha_varlen_bwd, "Backward pass (variable length)");
m.def("fwd_kvcache", &mha_fwd_kvcache, "Forward pass, with KV-cache");
}
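For orientation, here is a rough sketch of how the new fwd_kvcache binding is meant to be driven from Python, through the flash_attn_with_kvcache wrapper added further down in this commit. The shapes mirror the checks in mha_fwd_kvcache above; the concrete sizes are made up for illustration.

# Illustrative decode step with the new KV-cache path (sizes are arbitrary).
import torch
from flash_attn import flash_attn_with_kvcache

batch, nheads, headdim = 2, 8, 128
max_seqlen = 4096  # pre-allocated cache length
device, dtype = "cuda", torch.float16

# Pre-allocated KV cache, currently holding cache_seqlens tokens per sequence.
k_cache = torch.zeros(batch, max_seqlen, nheads, headdim, device=device, dtype=dtype)
v_cache = torch.zeros(batch, max_seqlen, nheads, headdim, device=device, dtype=dtype)
cache_seqlens = torch.full((batch,), 100, dtype=torch.int32, device=device)

# One new query token per sequence, plus the new K/V to append in the same kernel call.
q = torch.randn(batch, 1, nheads, headdim, device=device, dtype=dtype)
k = torch.randn(batch, 1, nheads, headdim, device=device, dtype=dtype)
v = torch.randn(batch, 1, nheads, headdim, device=device, dtype=dtype)

# k_cache / v_cache are updated in place at positions cache_seqlens : cache_seqlens + 1.
out = flash_attn_with_kvcache(q, k_cache, v_cache, k, v, cache_seqlens=cache_seqlens, causal=True)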
@@ -14,9 +14,12 @@ struct BlockInfo {
template<typename Params>
__device__ BlockInfo(const Params &params, const int bidb)
: sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb])
, sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb])
, actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
, seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
, actual_seqlen_k(seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_q))
{
}
@@ -33,6 +36,8 @@ struct BlockInfo {
const int sum_s_q;
const int sum_s_k;
const int actual_seqlen_q;
// We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0.
const int seqlen_k_cache;
const int actual_seqlen_k;
};
......
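Restating the new BlockInfo bookkeeping as a small Python sketch (illustrative only, not part of the commit): in the non-cumulative mode, cu_seqlens_k stores the current length of each sequence's KV cache, and appending K_new/V_new extends the window the kernel attends over by seqlen_q.

def seqlens_for_block(bidb, seqlen_k, seqlen_q, cu_seqlens_k, is_seqlens_k_cumulative, has_knew, varlen=True):
    if not varlen or cu_seqlens_k is None:
        seqlen_k_cache = seqlen_k
    elif is_seqlens_k_cumulative:
        # cu_seqlens_k holds cumulative offsets: length = difference of neighbors.
        seqlen_k_cache = cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]
    else:
        # cu_seqlens_k holds the per-sequence cache lengths directly (the KV-cache case).
        seqlen_k_cache = cu_seqlens_k[bidb]
    # If new K/V are being appended, the kernel attends over cache rows plus seqlen_q new rows.
    actual_seqlen_k = seqlen_k_cache + (seqlen_q if has_knew else 0)
    return seqlen_k_cache, actual_seqlen_k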
@@ -80,6 +80,18 @@ struct Flash_fwd_params : public Qkv_params {
int *__restrict__ blockmask;
// The K_new and V_new matrices.
void * __restrict__ knew_ptr;
void * __restrict__ vnew_ptr;
// The stride between rows of the Q, K and V matrices.
index_t knew_batch_stride;
index_t vnew_batch_stride;
index_t knew_row_stride;
index_t vnew_row_stride;
index_t knew_head_stride;
index_t vnew_head_stride;
// The dropout probability (probability of keeping an activation).
float p_dropout;
// uint32_t p_dropout_in_uint;
@@ -99,6 +111,10 @@ struct Flash_fwd_params : public Qkv_params {
bool is_bf16;
bool is_causal;
// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
bool is_seqlens_k_cumulative;
int num_splits; // For split-KV version
};
......
@@ -617,7 +617,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV, typename Params>
inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) {
using Element = typename Kernel_traits::Element;
@@ -635,7 +635,16 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
constexpr int kHeadDim = Kernel_traits::kHeadDim;
constexpr int kNWarps = Kernel_traits::kNWarps;
using GmemTiledCopyO = std::conditional_t<
!Split,
typename Kernel_traits::GmemTiledCopyOaccum,
typename Kernel_traits::GmemTiledCopyO
>;
using ElementO = std::conditional_t<!Split, Element, ElementAccum>;
const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("Is_even_MN = %d, is_cumulativ = %d, seqlen_k_cache = %d, actual_seqlen_k = %d\n", Is_even_MN, params.is_seqlens_k_cumulative, binfo.seqlen_k_cache, binfo.actual_seqlen_k); }
// if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("params.knew_ptr = %p, seqlen_k_cache + seqlen_q = %d\n", params.knew_ptr, binfo.seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_q)); }
if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
const int n_blocks_per_split = ((params.seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits;
@@ -649,19 +658,21 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
// We exit early and write 0 to gOaccum and -inf to gLSEaccum.
// Otherwise we might read OOB elements from gK and gV,
// or get wrong results when we combine gOaccum from different blocks.
const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
+ m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q
+ m_block * kBlockM) * params.d_rounded;
const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(Split ? kHeadDim : params.o_row_stride, _1{}));
Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum),
Shape<Int<kBlockM>>{}, Stride<_1>{});
GmemTiledCopyO gmem_tiled_copy_Oaccum;
auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum);
Tensor tOrOaccum = make_tensor<ElementO>(shape(tOgOaccum));
clear(tOrOaccum);
// Construct identity layout for sO
Tensor cO = make_identity_tensor(make_shape(size<0>(gOaccum), size<1>(gOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
@@ -679,7 +690,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
#pragma unroll
for (int m = 0; m < size<1>(tOgOaccum); ++m) {
const int row = get<0>(tOcO(0, m, 0));
if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSEaccum(row) = Split ? -INFINITY : INFINITY; }
}
return;
}
@@ -695,6 +706,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
+ (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)
+ (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb)
+ ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride;
const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb)
+ ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride;
Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
@@ -702,15 +717,26 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.k_row_stride, _1{}));
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("k_ptr = %p, row_offset_k = %d, gK_ptr = %p\n", params.k_ptr, row_offset_k, gK.data()); }
Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.v_row_stride, _1{}));
// Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them,
// e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64].
// This maps to accessing the first 64 rows of knew_ptr.
Tensor gKnew = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.knew_ptr)
+ row_offset_knew - binfo.seqlen_k_cache * params.knew_row_stride),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.knew_row_stride, _1{}));
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n", params.knew_ptr, row_offset_knew, gKnew.data()); }
Tensor gVnew = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.vnew_ptr)
+ row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.vnew_row_stride, _1{}));
Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
typename Kernel_traits::SmemLayoutQ{});
Tensor sK = make_tensor(sQ.data() + size(sQ), typename Kernel_traits::SmemLayoutKV{});
Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
@@ -721,8 +747,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K)
Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K)
Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K)
Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K)
Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
typename Kernel_traits::TiledMma tiled_mma;
@@ -787,32 +815,19 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
// We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
binfo.actual_seqlen_q - m_block * kBlockM);
int n_block = n_block_max - 1;
// We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
flash::copy_2_sources</*Is_2_sources=*/Append_KV, Is_even_MN, Is_even_K>(
gmem_tiled_copy_QKV, tKgK, tKgKnew, tKsK, tKVcKV, tKVpKV,
binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
);
cute::cp_async_fence();
// flash::cp_async_wait<0>();
// __syncthreads();
// if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tKsK); }
// __syncthreads();
clear(acc_o);
@@ -834,19 +849,37 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
flash::cp_async_wait<0>();
__syncthreads();
if constexpr (Append_KV) {
// if (cute::thread0()) { print(tKgK); }
// if (cute::thread0()) { print(tKsK); }
// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", binfo.seqlen_k_cache, (n_block + 1) * kBlockN); }
if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) {
flash::copy_w_min_idx<Is_even_K>(
tKsK, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
);
}
// __syncthreads();
// if (cute::thread0()) { print(tKgK); }
// __syncthreads();
}
// Advance gV
if (masking_step > 0) {
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
if (Append_KV) { tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); }
flash::copy_2_sources</*Is_2_sources=*/Append_KV, /*Is_even_MN=*/true, Is_even_K>(
gmem_tiled_copy_QKV, tVgV, tVgVnew, tVsV, tKVcKV, tKVpKV, 0, binfo.seqlen_k_cache - n_block * kBlockN
);
} else {
// Clear the smem tiles to account for predicated off loads
flash::copy_2_sources</*Is_2_sources=*/Append_KV, Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
gmem_tiled_copy_QKV, tVgV, tVgVnew, tVsV, tKVcKV, tKVpKV,
binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
);
}
cute::cp_async_fence();
flash::gemm(
acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
smem_thr_copy_Q, smem_thr_copy_K
);
@@ -869,19 +902,39 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
flash::cp_async_wait<0>();
__syncthreads();
// if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tVsV); }
// __syncthreads();
// if (tidx == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("n_block = %d, n_block_min = %d\n", n_block, n_block_min); }
if constexpr (Append_KV) {
// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("n_split_idx = %d, bidh = %d, params.h_h_k_ratio = %d, seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", n_split_idx, bidh, params.h_h_k_ratio, binfo.seqlen_k_cache, (n_block + 1) * kBlockN); }
if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) {
flash::copy_w_min_idx<Is_even_K>(
tVsV, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
);
}
}
if (n_block > n_block_min) {
// Advance gK
// if (tidx == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("tKgKnew = %p\n", tKgKnew.data()); }
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
if (Append_KV) { tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); }
// if (tidx == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("tKgKnew = %p, row_idx_switch = %d\n", tKgKnew.data(), binfo.seqlen_k_cache - (n_block - 1) * kBlockN); }
flash::copy_2_sources</*Is_2_sources=*/Append_KV, /*Is_even_MN=*/true, Is_even_K>(
gmem_tiled_copy_QKV, tKgK, tKgKnew, tKsK, tKVcKV, tKVpKV, 0,
binfo.seqlen_k_cache - (n_block - 1) * kBlockN
);
// This cp_async_fence needs to be in the if block, otherwise the synchronization
// isn't right and we get race conditions.
cute::cp_async_fence();
}
// We have key_padding_mask so we'll need to Check_inf
masking_step == 0
? softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal || !Is_even_MN>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
: softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || !Is_even_MN>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
// if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); }
// Convert scores from fp32 to fp16/bf16
Tensor rP = flash::convert_type<Element>(scores);
@@ -905,22 +958,45 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
clear(acc_s);
flash::cp_async_wait<0>();
__syncthreads();
if constexpr (Append_KV) {
// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("n_split_idx = %d, bidh = %d, params.h_h_k_ratio = %d, seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", n_split_idx, bidh, params.h_h_k_ratio, binfo.seqlen_k_cache, (n_block + 1) * kBlockN); }
if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) {
flash::copy_w_min_idx<Is_even_K>(
tKsK, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
);
}
}
// Advance gV
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
if (Append_KV) { tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); }
flash::copy_2_sources</*Is_2_sources=*/Append_KV, /*Is_even_MN=*/true, Is_even_K>(
gmem_tiled_copy_QKV, tVgV, tVgVnew, tVsV, tKVcKV, tKVpKV, 0, binfo.seqlen_k_cache - n_block * kBlockN
);
cute::cp_async_fence();
flash::gemm(
acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
smem_thr_copy_Q, smem_thr_copy_K
);
flash::cp_async_wait<0>();
__syncthreads();
if constexpr (Append_KV) {
// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", binfo.seqlen_k_cache, (n_block + 1) * kBlockN); }
if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) {
flash::copy_w_min_idx<Is_even_K>(
tVsV, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
);
}
}
if (n_block > n_block_min) {
// Advance gK
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
if (Append_KV) { tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); }
flash::copy_2_sources</*Is_2_sources=*/Append_KV, /*Is_even_MN=*/true, Is_even_K>(
gmem_tiled_copy_QKV, tKgK, tKgKnew, tKsK, tKVcKV, tKVpKV, 0,
binfo.seqlen_k_cache - (n_block - 1) * kBlockN
);
// This cp_async_fence needs to be in the if block, otherwise the synchronization
// isn't right and we get race conditions.
cute::cp_async_fence();
@@ -942,49 +1018,60 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
// if (cute::thread0()) { print(acc_o_rowcol); }
Tensor lse = make_fragment_like(scores_sum);
#pragma unroll
for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
float sum = scores_sum(mi);
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : scores_max(mi) * params.scale_softmax + __logf(sum);
float scale = inv_sum;
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
}
// if (cute::thread0()) { print(lse); }
// if (cute::thread0()) { print(acc_o_rowcol); }
Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast<ElementO *>(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N)
// Partition sO to match the accumulator partitioning
using SmemTiledCopyO = std::conditional_t<
!Split,
typename Kernel_traits::SmemCopyAtomO,
typename Kernel_traits::SmemCopyAtomOaccum
>;
auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma);
auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx);
Tensor rO = flash::convert_type<ElementO>(acc_o);
Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N)
Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum); // ((Atom,AtomNum),PIPE_M,PIPE_N)
// sOaccum is larger than sQ, so we need to syncthreads here
// TODO: allocate enough smem for sOaccum
if constexpr (Split) { __syncthreads(); }
cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum);
const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
+ m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q
+ m_block * kBlockM) * params.d_rounded;
const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(Split ? kHeadDim : params.o_row_stride, _1{}));
Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum),
Shape<Int<kBlockM>>{}, Stride<_1>{});
// if (tidx == 0) { printf("row_offset_o = %d, bidh = %d, gOaccum = %p\n", row_offset_o, bidh, gOaccum.data()); }
GmemTiledCopyO gmem_tiled_copy_Oaccum;
auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); // ((Atom,AtomNum),ATOM_M,ATOM_N)
Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum);
__syncthreads();
Tensor tOrOaccum = make_tensor<ElementO>(shape(tOgOaccum));
cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum);
Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
@@ -1014,6 +1101,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
);
// __syncthreads();
// if (cute::thread0()) { print(tOgOaccum); }
}
////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -1039,16 +1128,16 @@ inline __device__ void compute_attn(const Params &params) {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV, typename Params>
inline __device__ void compute_attn_splitkv(const Params &params) {
const int m_block = blockIdx.x;
// The block index for the batch.
const int bidb = Split ? blockIdx.z / params.h : blockIdx.y;
// The block index for the head.
const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z;
const int n_split_idx = Split ? blockIdx.y : 0;
const int num_n_splits = Split ? gridDim.y : 1;
flash::compute_attn_1rowblock_splitkv<Kernel_traits, Is_causal, Is_even_MN, Is_even_K, Split, Append_KV>(params, bidb, bidh, m_block, n_split_idx, num_n_splits);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
......
@@ -15,9 +15,9 @@ __global__ void flash_fwd_kernel(Flash_fwd_params params) {
flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, Return_softmax>(params);
}
template<typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV>
__global__ void flash_fwd_splitkv_kernel(Flash_fwd_params params) {
flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_even_MN, Is_even_K, Split, Append_KV>(params);
}
template<typename Kernel_traits, int Log_max_splits, bool Is_even_K>
@@ -63,16 +63,22 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
template<typename Kernel_traits>
void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
static_assert(!Kernel_traits::Is_Q_in_regs, "SplitKV implementation does not support Is_Q_in_regs");
static_assert(!Kernel_traits::Share_Q_K_smem, "SplitKV implementation does not support Share_Q_K_smem");
constexpr size_t smem_size = Kernel_traits::kSmemSize;
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
dim3 grid(num_m_block, params.num_splits > 1 ? params.num_splits : params.b, params.num_splits > 1 ? params.b * params.h : params.h);
const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
const bool is_even_K = params.d == Kernel_traits::kHeadDim;
// TODO: do we want to guarantee that seqlen_q <= seqlen_k? That would simplify the kernel a bit.
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
BOOL_SWITCH(params.num_splits > 1, Split, [&] {
BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
// If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
// printf("About to launch, Split = %d, Append_KV = %d, knew_ptr = %p\n", Split, Append_KV, params.knew_ptr);
auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, IsEvenMNConst && !Append_KV, IsEvenKConst, Split, Append_KV>;
// auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
// auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
if (smem_size >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
@@ -83,6 +89,9 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
});
});
});
});
});
if (params.num_splits > 1) {
dim3 grid_combine((params.b * params.h * params.seqlen_q + 16 - 1) / 16);
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
if (params.num_splits <= 2) {
@@ -97,11 +106,12 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
flash_fwd_splitkv_combine_kernel<Kernel_traits, 5, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
} else if (params.num_splits <= 64) {
flash_fwd_splitkv_combine_kernel<Kernel_traits, 6, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
} else if (params.num_splits <= 128) {
flash_fwd_splitkv_combine_kernel<Kernel_traits, 7, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
}
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
}
}
template<typename T, int Headdim>
......
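The combine kernels dispatched above reduce the per-split partial results written to out_accum (num_splits x batch x heads x seqlen_q x headdim) and softmax_lse_accum (num_splits x batch x heads x seqlen_q). A NumPy sketch of that reduction, assuming each split's partial output is already normalized by its own softmax sum, as in the per-split epilogue shown earlier (illustrative only, not the CUDA combine kernel):

import numpy as np

def combine_splits(out_accum, lse_accum):
    # out_accum: (num_splits, batch, nheads, seqlen_q, headdim) partial outputs
    # lse_accum: (num_splits, batch, nheads, seqlen_q) per-split log-sum-exp values
    lse = np.logaddexp.reduce(lse_accum, axis=0)   # global log-sum-exp across splits
    weights = np.exp(lse_accum - lse)              # each split's share of the softmax mass
    out = (weights[..., None] * out_accum).sum(axis=0)
    return out, lse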
@@ -291,7 +291,7 @@ template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bo
typename Engine2, typename Layout2, typename Engine3, typename Layout3>
inline __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
Tensor<Engine3, Layout3> const &predicate_K, const int max_MN=0) {
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
@@ -355,4 +355,71 @@ inline __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const
////////////////////////////////////////////////////////////////////////////////////////////////////
template <bool Is_2_sources=false, bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
typename Engine2, typename Layout2, typename Engine3, typename Layout3>
inline __device__ void copy_2_sources(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S0,
Tensor<Engine0, Layout0> const &S1,
Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
Tensor<Engine3, Layout3> const &predicate_K,
const int max_MN=0, const int row_idx_switch=0) {
CUTE_STATIC_ASSERT_V(rank(S0) == Int<3>{} && rank(S1) == Int<3>{});
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
CUTE_STATIC_ASSERT_V(size<0>(S0) == size<0>(D) && size<0>(S1) == size<0>(D)); // MMA
CUTE_STATIC_ASSERT_V(size<1>(S0) == size<1>(D) && size<1>(S1) == size<1>(D)); // MMA_M
CUTE_STATIC_ASSERT_V(size<2>(S0) == size<2>(D) && size<2>(S1) == size<2>(D)); // MMA_K
// There's no case where !Clear_OOB_K && Clear_OOB_MN
static_assert(!(Clear_OOB_MN && !Clear_OOB_K));
// if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("Is_2_sources = %d, max_MN = %d, row_idx_switch = %d\n", Is_2_sources, max_MN, row_idx_switch); }
// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, Is_2_sources = %d, max_MN = %d, row_idx_switch = %d\n", blockIdx.y, Is_2_sources, max_MN, row_idx_switch); }
#pragma unroll
for (int m = 0; m < size<1>(S0); ++m) {
auto &S = !Is_2_sources || get<0>(identity_MN(0, m, 0)) < row_idx_switch ? S0 : S1;
if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
#pragma unroll
for (int k = 0; k < size<2>(S0); ++k) {
if (Is_even_K || predicate_K(k)) {
cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
} else if (Clear_OOB_K) {
cute::clear(D(_, m, k));
}
}
} else if (Clear_OOB_MN) {
cute::clear(D(_, m, _));
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <bool Is_even_K=true,
typename Engine0, typename Layout0, typename Engine1, typename Layout1,
typename Engine2, typename Layout2, typename Engine3, typename Layout3>
inline __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const &S,
Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
Tensor<Engine3, Layout3> const &predicate_K,
const int max_MN=0, const int min_MN=0) {
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, max_MN = %d, min_MN = %d\n", blockIdx.y, max_MN, min_MN); }
#pragma unroll
for (int m = 0; m < size<1>(S); ++m) {
// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); }
if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("Inner loop, blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); }
#pragma unroll
for (int k = 0; k < size<2>(S); ++k) {
if (Is_even_K || predicate_K(k)) {
cute::copy(S(_, m, k), D(_, m, k));
}
}
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace flash } // namespace flash
\ No newline at end of file
...@@ -7,4 +7,5 @@ from flash_attn.flash_attn_interface import ( ...@@ -7,4 +7,5 @@ from flash_attn.flash_attn_interface import (
flash_attn_varlen_func, flash_attn_varlen_func,
flash_attn_varlen_kvpacked_func, flash_attn_varlen_kvpacked_func,
flash_attn_varlen_qkvpacked_func, flash_attn_varlen_qkvpacked_func,
flash_attn_with_kvcache,
) )
...@@ -5,6 +5,7 @@ from einops import rearrange ...@@ -5,6 +5,7 @@ from einops import rearrange
# isort: off # isort: off
# We need to import the CUDA kernels after importing torch # We need to import the CUDA kernels after importing torch
import flash_attn_2_cuda as flash_attn_cuda import flash_attn_2_cuda as flash_attn_cuda
# isort: on # isort: on
...@@ -790,3 +791,74 @@ def flash_attn_varlen_func( ...@@ -790,3 +791,74 @@ def flash_attn_varlen_func(
causal, causal,
return_attn_probs, return_attn_probs,
) )
def flash_attn_with_kvcache(
q,
k_cache,
v_cache,
k=None,
v=None,
cache_seqlens=None,
softmax_scale=None,
causal=False,
num_splits=0,
):
"""
If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
the previous step, update them with the new keys/values from the current step, and compute
attention with the updated cache, all in one kernel.
If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
For example, the KV cache could be pre-allocated with the max sequence length, and you can use
cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
Does not support backward pass.
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.
If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1 1 1 1 0
1 1 1 1 1
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
0 0
0 0
0 0
1 0
1 1
If a row of the mask is all zero, the corresponding row of the output will be zero.
Arguments:
q: (batch_size, seqlen, nheads, headdim)
k_cache: (batch_size, seqlen_cache, nheads_k, headdim)
v_cache: (batch_size, seqlen_cache, nheads_k, headdim)
k [optional]: (batch_size, seqlen, nheads_k, headdim). If not None, we concatenate k with
k_cache, starting at the indices specified by cache_seqlens.
v [optional]: (batch_size, seqlen, nheads_k, headdim). Similar to k.
cache_seqlens: (batch_size,), dtype torch.int32. The sequence lengths of the KV cache.
softmax_scale: float. The scaling of QK^T before applying softmax.
Defaults to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
num_splits: int. If > 1, split the key/value into this many chunks along the sequence dimension.
If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
to automatically determine the number of splits.
Don't change this unless you know what you are doing.
Return:
out: (batch_size, seqlen, nheads, headdim).
"""
assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
maybe_contiguous = lambda x: x.contiguous() if x is not None and x.stride(-1) != 1 else x
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
out, softmax_lse = flash_attn_cuda.fwd_kvcache(
q, k_cache, v_cache, k, v, cache_seqlens, None, softmax_scale, causal, num_splits
)
return out
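A minimal decoding-loop sketch (not part of this commit) showing how flash_attn_with_kvcache might be used; the shapes, dummy q/k/v tensors, and step count are illustrative assumptions only:
import torch
from flash_attn import flash_attn_with_kvcache

batch_size, max_seqlen, nheads, nheads_k, headdim = 2, 2048, 8, 2, 64
device, dtype = "cuda", torch.float16

# Pre-allocate the KV cache at the maximum sequence length.
k_cache = torch.zeros(batch_size, max_seqlen, nheads_k, headdim, device=device, dtype=dtype)
v_cache = torch.zeros(batch_size, max_seqlen, nheads_k, headdim, device=device, dtype=dtype)
# Current length of each sequence in the cache.
cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)

for step in range(16):
    # One new token per sequence (placeholder tensors; a real model would produce these).
    q = torch.randn(batch_size, 1, nheads, headdim, device=device, dtype=dtype)
    k = torch.randn(batch_size, 1, nheads_k, headdim, device=device, dtype=dtype)
    v = torch.randn(batch_size, 1, nheads_k, headdim, device=device, dtype=dtype)
    # k_cache/v_cache are updated in place at position cache_seqlens, and attention is
    # computed over the updated cache in a single kernel.
    out = flash_attn_with_kvcache(q, k_cache, v_cache, k, v, cache_seqlens, causal=True)
    cache_seqlens += 1
Pre-allocating the cache once and letting the kernel append in place avoids re-concatenating K/V on every decoding step.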
...@@ -348,8 +348,14 @@ def decode_speculative( ...@@ -348,8 +348,14 @@ def decode_speculative(
) )
def sample_tokens( def sample_tokens(
input_ids, model, inference_params, sample_fn, num_tokens=1, cg=False, decoding=True, input_ids,
last_token_logits=False model,
inference_params,
sample_fn,
num_tokens=1,
cg=False,
decoding=True,
last_token_logits=False,
): ):
"""Sample `num_tokens` tokens from the model, given the previous logits. """Sample `num_tokens` tokens from the model, given the previous logits.
Also return the logits of the sampled tokens. Also return the logits of the sampled tokens.
...@@ -374,12 +380,18 @@ def decode_speculative( ...@@ -374,12 +380,18 @@ def decode_speculative(
sequences = [] sequences = []
if decoding: if decoding:
assert seqlen == 1 assert seqlen == 1
position_ids = torch.full( position_ids = repeat(
(batch_size, 1), torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
inference_params.sequence_len_offset, + inference_params.sequence_len_offset,
dtype=torch.long, "s -> b s",
device=input_ids.device, b=batch_size,
) )
# position_ids = torch.full(
# (batch_size, 1),
# inference_params.sequence_len_offset,
# dtype=torch.long,
# device=input_ids.device,
# )
else: else:
position_ids = None position_ids = None
logits = logits_postprocess_fn( logits = logits_postprocess_fn(
...@@ -399,7 +411,11 @@ def decode_speculative( ...@@ -399,7 +411,11 @@ def decode_speculative(
) )
logits = logits_postprocess_fn( logits = logits_postprocess_fn(
logits_forward_fn( logits_forward_fn(
model, rearrange(next_token, "b -> b 1"), position_ids, inference_params, cg=cg model,
rearrange(next_token, "b -> b 1"),
position_ids,
inference_params,
cg=cg,
) )
) )
inference_params.sequence_len_offset += 1 inference_params.sequence_len_offset += 1
...@@ -420,7 +436,7 @@ def decode_speculative( ...@@ -420,7 +436,7 @@ def decode_speculative(
sample_fn=sample_fn, sample_fn=sample_fn,
last_token_logits=True, last_token_logits=True,
inference_params=inference_params_draft, inference_params=inference_params_draft,
cg=cg cg=cg,
) )
if debug: if debug:
......
...@@ -11,6 +11,7 @@ from flash_attn import ( ...@@ -11,6 +11,7 @@ from flash_attn import (
flash_attn_varlen_func, flash_attn_varlen_func,
flash_attn_varlen_kvpacked_func, flash_attn_varlen_kvpacked_func,
flash_attn_varlen_qkvpacked_func, flash_attn_varlen_qkvpacked_func,
flash_attn_with_kvcache,
) )
from flash_attn.bert_padding import pad_input, unpad_input from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.flash_attn_interface import _get_block_size from flash_attn.flash_attn_interface import _get_block_size
...@@ -1465,6 +1466,95 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype): ...@@ -1465,6 +1466,95 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype):
assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 2e-4 assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 2e-4
assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 2e-4 assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 2e-4
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("num_splits", [1, 0])
# @pytest.mark.parametrize("num_splits", [0])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize("mha_type", ["mqa"])
@pytest.mark.parametrize("new_kv", [False, True])
# @pytest.mark.parametrize("new_kv", [False])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [64])
@pytest.mark.parametrize(
"seqlen_q,seqlen_k",
[
(1, 128),
(1, 339),
(3, 1024),
(64, 800),
(64, 256),
(3, 799),
(64, 2048),
(16, 20000),
(1, 128 * 1024),
(16, 128 * 1024),
(128, 128),
],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
def test_flash_attn_kvcache(seqlen_q, seqlen_k, d, causal, new_kv, mha_type, num_splits, dtype):
if seqlen_q > seqlen_k and new_kv:
pytest.skip()
device = "cuda"
# set seed
torch.random.manual_seed(0)
batch_size = 2
nheads = 6
nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
assert nheads % nheads_k == 0
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
if new_kv:
k = torch.randn(batch_size, seqlen_q, nheads_k, d, device=device, dtype=dtype)
v = torch.randn(batch_size, seqlen_q, nheads_k, d, device=device, dtype=dtype)
else:
k, v = None, None
k_cache = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype)
v_cache = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype)
cache_seqlens = torch.randint(
    0, (seqlen_k - seqlen_q + 1) if new_kv else (seqlen_k + 1), (batch_size,),
    dtype=torch.int32, device=device
)
# cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device)
# k_cache[:, 64:] = -1
k_cache_ref = k_cache.clone()
v_cache_ref = v_cache.clone()
arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s")
cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
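# Build the reference cache: when new_kv, the new k/v occupy rows
# [cache_seqlens, cache_seqlens + seqlen_q) of each sequence's cache.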
if new_kv:
update_mask = torch.logical_and(cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + seqlen_q)
k_cache_ref[update_mask] = rearrange(k, "b s ... -> (b s) ...")
v_cache_ref[update_mask] = rearrange(v, "b s ... -> (b s) ...")
k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
out = flash_attn_with_kvcache(q, k_cache, v_cache, k, v, cache_seqlens, causal=causal, num_splits=num_splits)
# out = flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal)
# out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal)
# qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref)
# m = qk.amax(-1, keepdim=True)
# s_tmp = torch.exp((qk - m) / math.sqrt(d))
# o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref)
# lse_ref = torch.logsumexp(qk / math.sqrt(d), -1)
# probs = torch.softmax(qk, dim=-1)
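# Keys at indices below cache_seqlens (plus seqlen_q if new k/v were appended) are valid;
# the rest are treated as padding in the reference implementation.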
key_padding_mask = arange < cache_seqlens_expanded + (seqlen_q if new_kv else 0)
out_ref, _ = attention_ref(q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=causal)
out_pt, _ = attention_ref(q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=causal,
upcast=False, reorder_ops=True)
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
# Check that FlashAttention's numerical error is at most three times the numerical error
# of a Pytorch implementation.
assert (out - out_ref).abs().max().item() <= 3 * (out_pt - out_ref).abs().max().item() + 1e-5
if new_kv:
assert torch.equal(k_cache, k_cache_ref)
assert torch.equal(v_cache, v_cache_ref)
# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
@pytest.mark.parametrize("dtype", [torch.float16]) @pytest.mark.parametrize("dtype", [torch.float16])
......