Unverified commit af01244d authored by Grigory Sizov, committed by GitHub

Add split-kv and M<->H swap to varlen forward decoding attention (#754)

* Add split-k, M<->H to varseq path

* skip M<->H when dropout>0, fix LSE
parent d8aacc51
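Note (illustration, not part of the commit): the M<->H swap folds the query-head groups of MQA/GQA into the query-length dimension for single-token decoding, so the split-KV kernel gets larger M tiles to parallelize over instead of many one-row query heads. A minimal sketch with a hypothetical helper name, using the same ATen calls the diff adds inside mha_varlen_fwd:

#include <ATen/ATen.h>

// Hypothetical helper, for illustration only: queries of shape (batch, nheads_kv * ngroups, d)
// with seqlen_q == 1 are viewed as (batch * ngroups, nheads_kv, d), i.e. seqlen_q becomes ngroups.
at::Tensor swap_q_for_decode(const at::Tensor &q, int batch_size, int num_heads_k,
                             int ngroups, int head_size) {
    return q.reshape({batch_size, num_heads_k, ngroups, head_size})
            .transpose(1, 2)
            .reshape({batch_size * ngroups, num_heads_k, head_size});
}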
@@ -42,7 +42,8 @@ void set_params_fprop(Flash_fwd_params &params,
float p_dropout,
float softmax_scale,
int window_size_left,
int window_size_right) {
int window_size_right,
bool seqlenq_ngroups_swapped=false) {
// Reset the parameters
memset(&params, 0, sizeof(params));
@@ -69,6 +70,10 @@ void set_params_fprop(Flash_fwd_params &params,
params.k_batch_stride = k.stride(0);
params.v_batch_stride = v.stride(0);
params.o_batch_stride = out.stride(0);
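// With the seqlen_q <-> ngroups swap active, q/out arrive flattened to (b * seqlen_q) rows and
// stride(0) is a per-row stride; scaling by seqlen_q recovers the per-batch stride the kernel uses.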
if (seqlenq_ngroups_swapped) {
params.q_batch_stride *= seqlen_q;
params.o_batch_stride *= seqlen_q;
}
}
params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
@@ -251,6 +256,31 @@ inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n
return 1;
}
void set_params_splitkv(Flash_fwd_params &params, const int batch_size,
const int num_heads, const int head_size, const int max_seqlen_k, const int max_seqlen_q,
const int head_size_rounded, float p_dropout, const int num_splits, cudaDeviceProp *dprops, struct c10::TensorOptions opts) {
// This needs to match with run_mha_fwd_splitkv_dispatch
const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
const int num_n_blocks = (max_seqlen_k + block_n - 1) / block_n;
// Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.
// In any case we don't expect seqlen_q to be larger than 64 for inference.
const int num_m_blocks = (max_seqlen_q + 64 - 1) / 64;
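// batch_size * num_heads * num_m_blocks is the number of independent tiles already available;
// the heuristic below only adds splits when that grid underfills the SMs.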
params.num_splits = num_splits;
if (p_dropout == 0.0f) { // SplitKV is not implemented for dropout
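// num_splits < 1 means "pick automatically"; an explicit value of 1 or more is used as given.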
if (num_splits < 1) {
params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 128);
}
if (params.num_splits > 1) {
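// Per-split accumulators: each split writes partial log-sum-exp values and partial outputs
// in fp32, which are later reduced across splits into softmax_lse and out.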
at::Tensor softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
at::Tensor out_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
params.oaccum_ptr = out_accum.data_ptr();
}
TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported");
}
}
std::vector<at::Tensor>
mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
@@ -382,23 +412,10 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
window_size_left,
window_size_right);
// This needs to match with run_mha_fwd_splitkv_dispatch
const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
const int num_n_blocks = (seqlen_k + block_n - 1) / block_n;
// Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.
// In any case we don't expect seqlen_q to be larger than 64 for inference.
const int num_m_blocks = (seqlen_q + 64 - 1) / 64;
params.num_splits = 1;
if (p_dropout == 0.0f) { // SplitKV is not implemented for dropout
params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 128);
if (params.num_splits > 1) {
at::Tensor softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
at::Tensor out_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
params.oaccum_ptr = out_accum.data_ptr();
}
TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported");
}
set_params_splitkv(params, batch_size, num_heads,
head_size, seqlen_k, seqlen_q,
head_size_rounded, p_dropout, /*num_splits*/0, dprops, opts);
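// Passing num_splits = 0 lets the heuristic inside set_params_splitkv choose the split count.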
// number of times random will be generated per thread, to offset philox counter in thc random
// state
@@ -454,7 +471,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
}
std::vector<at::Tensor>
mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
@@ -462,18 +479,17 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
const at::Tensor &cu_seqlens_k, // b+1
c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
const int max_seqlen_q,
int max_seqlen_q,
const int max_seqlen_k,
const float p_dropout,
const float softmax_scale,
const bool zero_tensors,
const bool is_causal,
bool is_causal,
int window_size_left,
int window_size_right,
const bool return_softmax,
c10::optional<at::Generator> gen_) {
if (is_causal) { window_size_right = 0; }
auto dprops = at::cuda::getCurrentDeviceProperties();
// bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
@@ -505,12 +521,30 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
const auto sizes = q.sizes();
const int total_q = sizes[0];
const int batch_size = cu_seqlens_q.numel() - 1;
const int num_heads = sizes[1];
int num_heads = sizes[1];
const int head_size_og = sizes[2];
const int total_k = k.size(0);
const int num_heads_k = k.size(1);
if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } // causal=true is the same as causal=false in this case
if (is_causal) { window_size_right = 0; }
void *cu_seqlens_q_d = cu_seqlens_q.data_ptr();
// Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
// H/t Daniel Haziza
const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
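// The swap only kicks in for plain decoding: a single query token, more query heads than KV heads
// (MQA/GQA), no windowing, no dropout, no ALiBi, and a head dim that needs no padding.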
if (seqlenq_ngroups_swapped) {
const int ngroups = num_heads / num_heads_k;
q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size_og});
max_seqlen_q = ngroups;
num_heads = num_heads_k;
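// With cu_seqlens_q cleared, q/out are addressed through batch strides; set_params_fprop scales
// those strides by seqlen_q (= ngroups) when seqlenq_ngroups_swapped is set.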
cu_seqlens_q_d = nullptr;
}
const int total_q = q.sizes()[0];
TORCH_CHECK(batch_size > 0, "batch size must be positive");
TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
@@ -588,7 +622,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
num_heads, num_heads_k,
head_size, head_size_rounded,
q_padded, k_padded, v_padded, out,
cu_seqlens_q.data_ptr(),
cu_seqlens_q_d,
cu_seqlens_k.data_ptr(),
seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr,
return_softmax ? p.data_ptr() : nullptr,
@@ -596,7 +630,14 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
p_dropout,
softmax_scale,
window_size_left,
window_size_right);
window_size_right,
seqlenq_ngroups_swapped);
if (seqlenq_ngroups_swapped) {
// Only apply split-k for decoding
set_params_splitkv(params, batch_size, num_heads,
head_size, max_seqlen_k, max_seqlen_q,
head_size_rounded, p_dropout, /*num_splits*/0, dprops, opts);
}
// number of times random will be generated per thread, to offset philox counter in thc random
// state
@@ -642,6 +683,15 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
if (out_.has_value()) { out_.value().copy_(out); }
}
if (seqlenq_ngroups_swapped) {
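// Undo the swap: fold ngroups back into the head dimension so out/q come back as
// (total_q = batch_size, num_heads_k * ngroups, head_size) and softmax_lse as one value per (batch, head).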
long size_before[] = {batch_size, max_seqlen_q, num_heads_k, head_size_og};
long size_after[] = {batch_size, num_heads_k * max_seqlen_q, head_size_og};
out = out.reshape(size_before).transpose(1, 2).reshape(size_after);
out_padded = out_padded.reshape(size_before).transpose(1, 2).reshape(size_after);
q_padded = q_padded.reshape(size_before).transpose(1, 2).reshape(size_after);
softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * max_seqlen_q, 1});
}
return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state};
}
@@ -1367,23 +1417,10 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
TORCH_CHECK(cache_batch_idx.scalar_type() == torch::kInt32, "cache_batch_idx must have dtype int32");
params.cache_batch_idx = reinterpret_cast<int *>(cache_batch_idx.data_ptr());
}
// This needs to match with run_mha_fwd_splitkv_dispatch
const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
const int num_n_blocks = (seqlen_k + block_n - 1) / block_n;
// Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.
// In any case we don't expect seqlen_q to be larger than 64 for inference.
const int num_m_blocks = (seqlen_q + 64 - 1) / 64;
params.num_splits = num_splits;
if (num_splits < 1) {
params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 128);
}
TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported");
if (params.num_splits > 1) {
at::Tensor softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
at::Tensor out_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
params.oaccum_ptr = out_accum.data_ptr();
}
set_params_splitkv(params, batch_size, num_heads,
head_size, seqlen_k, seqlen_q,
head_size_rounded, /*dropout*/0.f, num_splits, dprops, opts);
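// mha_fwd_kvcache has no dropout argument, so 0.f is passed and the split-KV path stays eligible.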
if (alibi_slopes_.has_value()) {
auto alibi_slopes = alibi_slopes_.value();
......