"src/webui/vscode:/vscode.git/clone" did not exist on "3fe117f0739184714f05340dce950c5407d6d379"
Unverified commit af01244d authored by Grigory Sizov, committed by GitHub

Add split-kv and M<->H swap to varlen forward decoding attention (#754)

* Add split-k, M<->H to varseq path

* skip M<->H when dropout>0, fix LSE
parent d8aacc51
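
The M<->H swap introduced below applies to single-token decoding with grouped KV heads (num_heads > num_heads_k): the query of shape (b, 1, nheads_kv * ngroups, d) is viewed as (b, ngroups, nheads_kv, d), so the query heads within each KV group take the place of the length-1 sequence dimension and the kernel sees a larger effective seqlen_q. A minimal libtorch sketch of the reshape added in mha_varlen_fwd, with hypothetical sizes:

// Sketch only: the M<->H (seqlen_q <-> ngroups) swap on the query, assuming libtorch.
#include <torch/torch.h>
#include <iostream>

int main() {
    const int64_t b = 2, nheads_kv = 4, ngroups = 8, d = 128;  // hypothetical sizes
    // Varlen layout with seqlen_q == 1 per sequence: total_q == b rows of (num_heads, d).
    auto q = torch::randn({b, nheads_kv * ngroups, d});
    // Same reshape as in mha_varlen_fwd when seqlenq_ngroups_swapped is true.
    auto q_swapped = q.reshape({b, nheads_kv, ngroups, d})
                      .transpose(1, 2)
                      .reshape({b * ngroups, nheads_kv, d});
    std::cout << q_swapped.sizes() << std::endl;  // [16, 4, 128]
    return 0;
}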
@@ -42,7 +42,8 @@ void set_params_fprop(Flash_fwd_params &params,
                       float p_dropout,
                       float softmax_scale,
                       int window_size_left,
-                      int window_size_right) {
+                      int window_size_right,
+                      bool seqlenq_ngroups_swapped=false) {

     // Reset the parameters
     memset(&params, 0, sizeof(params));
@@ -69,6 +70,10 @@ void set_params_fprop(Flash_fwd_params &params,
         params.k_batch_stride = k.stride(0);
         params.v_batch_stride = v.stride(0);
         params.o_batch_stride = out.stride(0);
+        if (seqlenq_ngroups_swapped) {
+            params.q_batch_stride *= seqlen_q;
+            params.o_batch_stride *= seqlen_q;
+        }
     }

     params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
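
As an aside (illustration, not part of the diff): when the swap is active, mha_varlen_fwd below passes the query as a fixed-length batch (cu_seqlens_q == nullptr) whose seqlen_q equals ngroups, so q.stride(0) is a per-row stride and the per-sequence batch stride is seqlen_q times larger, which is what the scaling above accounts for. A minimal arithmetic sketch, assuming a contiguous (batch_size * ngroups, nheads_kv, d) query and hypothetical sizes:

// Sketch only: the batch-stride arithmetic behind params.q_batch_stride *= seqlen_q.
#include <cassert>
#include <cstdint>

int main() {
    const int64_t b = 2, nheads_kv = 4, ngroups = 8, d = 128;  // hypothetical sizes
    const int64_t seqlen_q = ngroups;            // after the swap, seqlen_q == ngroups
    const int64_t row_stride = nheads_kv * d;    // q.stride(0) for a contiguous (b*ngroups, nheads_kv, d) tensor
    // One batch element now spans seqlen_q rows, so the batch stride is scaled accordingly.
    const int64_t q_batch_stride = row_stride * seqlen_q;
    assert(q_batch_stride == ngroups * nheads_kv * d);
    return 0;
}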
@@ -251,6 +256,31 @@ inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n
     return 1;
 }

+void set_params_splitkv(Flash_fwd_params &params, const int batch_size,
+    const int num_heads, const int head_size, const int max_seqlen_k, const int max_seqlen_q,
+    const int head_size_rounded, float p_dropout, const int num_splits, cudaDeviceProp *dprops, struct c10::TensorOptions opts) {
+
+    // This needs to match with run_mha_fwd_splitkv_dispatch
+    const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
+    const int num_n_blocks = (max_seqlen_k + block_n - 1) / block_n;
+    // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.
+    // In any case we don't expect seqlen_q to be larger than 64 for inference.
+    const int num_m_blocks = (max_seqlen_q + 64 - 1) / 64;
+    params.num_splits = num_splits;
+    if (p_dropout == 0.0f) {  // SplitKV is not implemented for dropout
+        if (num_splits < 1) {
+            params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 128);
+        }
+        if (params.num_splits > 1) {
+            at::Tensor softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
+            at::Tensor out_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
+            params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
+            params.oaccum_ptr = out_accum.data_ptr();
+        }
+        TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported");
+    }
+}
+
 std::vector<at::Tensor>
 mha_fwd(at::Tensor &q,         // batch_size x seqlen_q x num_heads x head_size
         const at::Tensor &k,         // batch_size x seqlen_k x num_heads_k x head_size
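
For context (an illustration, not part of the commit): num_n_blocks and num_m_blocks above are plain ceiling divisions, and when the heuristic picks num_splits > 1 the kernel writes per-split partial outputs (out_accum) and log-sum-exp values (softmax_lse_accum) in float32 that are combined afterwards. A standalone sketch of the block arithmetic, with hypothetical sizes:

// Sketch only: the ceiling-division block counts used by set_params_splitkv.
#include <cstdio>

int main() {
    const int head_size = 128, max_seqlen_k = 4096, max_seqlen_q = 8;  // hypothetical sizes
    // Must match run_mha_fwd_splitkv_dispatch: kBlockN depends on the head dimension.
    const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
    const int num_n_blocks = (max_seqlen_k + block_n - 1) / block_n;  // ceil(4096 / 128) = 32
    const int num_m_blocks = (max_seqlen_q + 64 - 1) / 64;            // ceil(8 / 64) = 1
    std::printf("block_n=%d num_n_blocks=%d num_m_blocks=%d\n", block_n, num_n_blocks, num_m_blocks);
    return 0;
}

Passing num_splits = 0 (as the call sites below do) leaves the choice to num_splits_heuristic, while a caller-supplied value >= 1 is used as-is.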
@@ -382,23 +412,10 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
                      window_size_left,
                      window_size_right);

-    // This needs to match with run_mha_fwd_splitkv_dispatch
-    const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
-    const int num_n_blocks = (seqlen_k + block_n - 1) / block_n;
-    // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.
-    // In any case we don't expect seqlen_q to be larger than 64 for inference.
-    const int num_m_blocks = (seqlen_q + 64 - 1) / 64;
-    params.num_splits = 1;
-    if (p_dropout == 0.0f) {  // SplitKV is not implemented for dropout
-        params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 128);
-        if (params.num_splits > 1) {
-            at::Tensor softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
-            at::Tensor out_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
-            params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
-            params.oaccum_ptr = out_accum.data_ptr();
-        }
-        TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported");
-    }
+    set_params_splitkv(params, batch_size, num_heads,
+                       head_size, seqlen_k, seqlen_q,
+                       head_size_rounded, p_dropout, /*num_splits*/0, dprops, opts);

     // number of times random will be generated per thread, to offset philox counter in thc random
     // state
@@ -454,7 +471,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
 }

 std::vector<at::Tensor>
-mha_varlen_fwd(const at::Tensor &q,  // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
+mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
                const at::Tensor &k,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
                const at::Tensor &v,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
                c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
@@ -462,18 +479,17 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
                const at::Tensor &cu_seqlens_k,  // b+1
                c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
                c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
-               const int max_seqlen_q,
+               int max_seqlen_q,
                const int max_seqlen_k,
                const float p_dropout,
                const float softmax_scale,
                const bool zero_tensors,
-               const bool is_causal,
+               bool is_causal,
                int window_size_left,
                int window_size_right,
                const bool return_softmax,
                c10::optional<at::Generator> gen_) {
-    if (is_causal) { window_size_right = 0; }

     auto dprops = at::cuda::getCurrentDeviceProperties();
     // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
     bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
@@ -505,12 +521,30 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
     const auto sizes = q.sizes();

-    const int total_q = sizes[0];
     const int batch_size = cu_seqlens_q.numel() - 1;
-    const int num_heads = sizes[1];
+    int num_heads = sizes[1];
     const int head_size_og = sizes[2];
     const int total_k = k.size(0);
     const int num_heads_k = k.size(1);

+    if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }  // causal=true is the same as causal=false in this case
+    if (is_causal) { window_size_right = 0; }
+
+    void *cu_seqlens_q_d = cu_seqlens_q.data_ptr();
+
+    // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
+    // H/t Daniel Haziza
+    const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
+    if (seqlenq_ngroups_swapped) {
+        const int ngroups = num_heads / num_heads_k;
+        q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size_og});
+        max_seqlen_q = ngroups;
+        num_heads = num_heads_k;
+        cu_seqlens_q_d = nullptr;
+    }
+
+    const int total_q = q.sizes()[0];
+
     TORCH_CHECK(batch_size > 0, "batch size must be positive");
     TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
@@ -588,7 +622,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
                      num_heads, num_heads_k,
                      head_size, head_size_rounded,
                      q_padded, k_padded, v_padded, out,
-                     cu_seqlens_q.data_ptr(),
+                     cu_seqlens_q_d,
                      cu_seqlens_k.data_ptr(),
                      seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr,
                      return_softmax ? p.data_ptr() : nullptr,
@@ -596,7 +630,14 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
                      p_dropout,
                      softmax_scale,
                      window_size_left,
-                     window_size_right);
+                     window_size_right,
+                     seqlenq_ngroups_swapped);
+    if (seqlenq_ngroups_swapped) {
+        // Only apply split-k for decoding
+        set_params_splitkv(params, batch_size, num_heads,
+                           head_size, max_seqlen_k, max_seqlen_q,
+                           head_size_rounded, p_dropout, /*num_splits*/0, dprops, opts);
+    }

     // number of times random will be generated per thread, to offset philox counter in thc random
     // state
@@ -642,6 +683,15 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
         if (out_.has_value()) { out_.value().copy_(out); }
     }

+    if (seqlenq_ngroups_swapped) {
+        long size_before[] = {batch_size, max_seqlen_q, num_heads_k, head_size_og};
+        long size_after[] = {batch_size, num_heads_k * max_seqlen_q, head_size_og};
+        out = out.reshape(size_before).transpose(1, 2).reshape(size_after);
+        out_padded = out_padded.reshape(size_before).transpose(1, 2).reshape(size_after);
+        q_padded = q_padded.reshape(size_before).transpose(1, 2).reshape(size_after);
+        softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * max_seqlen_q, 1});
+    }
+
     return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state};
 }
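
An illustrative round-trip check (not part of the commit) that the output reshape above exactly inverts the earlier query swap; a sketch assuming libtorch and hypothetical sizes:

// Sketch only: (b*ngroups, nheads_kv, d) -> (b, nheads_kv*ngroups, d) undoes the forward swap.
#include <torch/torch.h>
#include <cassert>

int main() {
    const int64_t b = 2, nheads_kv = 4, ngroups = 8, d = 128;  // hypothetical sizes
    auto x = torch::randn({b, nheads_kv * ngroups, d});        // varlen layout, seqlen_q == 1
    // Forward swap, as applied to q in mha_varlen_fwd.
    auto swapped = x.reshape({b, nheads_kv, ngroups, d}).transpose(1, 2)
                    .reshape({b * ngroups, nheads_kv, d});
    // Inverse, as applied to out / out_padded / q_padded before returning.
    auto back = swapped.reshape({b, ngroups, nheads_kv, d}).transpose(1, 2)
                       .reshape({b, nheads_kv * ngroups, d});
    assert(torch::equal(x, back));
    return 0;
}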
@@ -1367,23 +1417,10 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
         TORCH_CHECK(cache_batch_idx.scalar_type() == torch::kInt32, "cache_batch_idx must have dtype int32");
         params.cache_batch_idx = reinterpret_cast<int *>(cache_batch_idx.data_ptr());
     }

-    // This needs to match with run_mha_fwd_splitkv_dispatch
-    const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
-    const int num_n_blocks = (seqlen_k + block_n - 1) / block_n;
-    // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.
-    // In any case we don't expect seqlen_q to be larger than 64 for inference.
-    const int num_m_blocks = (seqlen_q + 64 - 1) / 64;
-    params.num_splits = num_splits;
-    if (num_splits < 1) {
-        params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 128);
-    }
-    TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported");
-    if (params.num_splits > 1) {
-        at::Tensor softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
-        at::Tensor out_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
-        params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
-        params.oaccum_ptr = out_accum.data_ptr();
-    }
+    set_params_splitkv(params, batch_size, num_heads,
+                       head_size, seqlen_k, seqlen_q,
+                       head_size_rounded, /*dropout*/0.f, num_splits, dprops, opts);

     if (alibi_slopes_.has_value()) {
         auto alibi_slopes = alibi_slopes_.value();
...