Unverified Commit ce3e7280 authored by Jeremy Reizenstein's avatar Jeremy Reizenstein Committed by GitHub
Browse files

Allow varlen_fwd to take optional seqused_k (#647)


Co-authored-by: bottler <bottler@users.noreply.github.com>
parent 23b77c81
...@@ -36,6 +36,7 @@ void set_params_fprop(Flash_fwd_params &params, ...@@ -36,6 +36,7 @@ void set_params_fprop(Flash_fwd_params &params,
at::Tensor out, at::Tensor out,
void *cu_seqlens_q_d, void *cu_seqlens_q_d,
void *cu_seqlens_k_d, void *cu_seqlens_k_d,
void *seqused_k,
void *p_d, void *p_d,
void *softmax_lse_d, void *softmax_lse_d,
float p_dropout, float p_dropout,
...@@ -72,6 +73,7 @@ void set_params_fprop(Flash_fwd_params &params, ...@@ -72,6 +73,7 @@ void set_params_fprop(Flash_fwd_params &params,
params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d); params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d); params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);
params.seqused_k = static_cast<int *>(seqused_k);
// P = softmax(QK^T) // P = softmax(QK^T)
params.p_ptr = p_d; params.p_ptr = p_d;
...@@ -156,6 +158,7 @@ void set_params_dgrad(Flash_bwd_params &params, ...@@ -156,6 +158,7 @@ void set_params_dgrad(Flash_bwd_params &params,
cu_seqlens_q_d, cu_seqlens_q_d,
cu_seqlens_k_d, cu_seqlens_k_d,
nullptr, nullptr,
nullptr,
softmax_lse_d, softmax_lse_d,
p_dropout, p_dropout,
softmax_scale, softmax_scale,
...@@ -363,6 +366,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size ...@@ -363,6 +366,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
q_padded, k_padded, v_padded, out, q_padded, k_padded, v_padded, out,
/*cu_seqlens_q_d=*/nullptr, /*cu_seqlens_q_d=*/nullptr,
/*cu_seqlens_k_d=*/nullptr, /*cu_seqlens_k_d=*/nullptr,
/*seqused_k=*/nullptr,
return_softmax ? p.data_ptr() : nullptr, return_softmax ? p.data_ptr() : nullptr,
softmax_lse.data_ptr(), softmax_lse.data_ptr(),
p_dropout, p_dropout,
...@@ -436,6 +440,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q ...@@ -436,6 +440,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens_q, // b+1 const at::Tensor &cu_seqlens_q, // b+1
const at::Tensor &cu_seqlens_k, // b+1 const at::Tensor &cu_seqlens_k, // b+1
c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
const int max_seqlen_q, const int max_seqlen_q,
const int max_seqlen_k, const int max_seqlen_k,
const float p_dropout, const float p_dropout,
...@@ -494,6 +499,13 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q ...@@ -494,6 +499,13 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
CHECK_SHAPE(v, total_k, num_heads_k, head_size_og); CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
CHECK_SHAPE(cu_seqlens_q, batch_size + 1); CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
CHECK_SHAPE(cu_seqlens_k, batch_size + 1); CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
if (seqused_k.has_value()){
auto seqused_k_ = seqused_k.value();
TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32");
TORCH_CHECK(seqused_k_.is_cuda(), "seqused_k must be on CUDA device");
TORCH_CHECK(seqused_k_.is_contiguous(), "seqused_k must be contiguous");
CHECK_SHAPE(seqused_k_, batch_size);
}
at::Tensor q_padded, k_padded, v_padded; at::Tensor q_padded, k_padded, v_padded;
if (head_size_og % 8 != 0) { if (head_size_og % 8 != 0) {
...@@ -554,6 +566,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q ...@@ -554,6 +566,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
q_padded, k_padded, v_padded, out, q_padded, k_padded, v_padded, out,
cu_seqlens_q.data_ptr(), cu_seqlens_q.data_ptr(),
cu_seqlens_k.data_ptr(), cu_seqlens_k.data_ptr(),
seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr,
return_softmax ? p.data_ptr() : nullptr, return_softmax ? p.data_ptr() : nullptr,
softmax_lse.data_ptr(), softmax_lse.data_ptr(),
p_dropout, p_dropout,
...@@ -1167,6 +1180,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he ...@@ -1167,6 +1180,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
q_padded, kcache_padded, vcache_padded, out, q_padded, kcache_padded, vcache_padded, out,
/*cu_seqlens_q_d=*/nullptr, /*cu_seqlens_q_d=*/nullptr,
/*cu_seqlens_k_d=*/nullptr, /*cu_seqlens_k_d=*/nullptr,
/*seqused_k=*/nullptr,
/*p_ptr=*/nullptr, /*p_ptr=*/nullptr,
softmax_lse.data_ptr(), softmax_lse.data_ptr(),
/*p_dropout=*/0.f, /*p_dropout=*/0.f,
......
...@@ -19,7 +19,7 @@ struct BlockInfo { ...@@ -19,7 +19,7 @@ struct BlockInfo {
// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
, seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb])) , seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
, actual_seqlen_k(seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)) , actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
{ {
} }
......
...@@ -77,6 +77,9 @@ struct Flash_fwd_params : public Qkv_params { ...@@ -77,6 +77,9 @@ struct Flash_fwd_params : public Qkv_params {
int * __restrict__ cu_seqlens_q; int * __restrict__ cu_seqlens_q;
int * __restrict__ cu_seqlens_k; int * __restrict__ cu_seqlens_k;
// If provided, the actual length of each k sequence.
int * __restrict__ seqused_k;
int *__restrict__ blockmask; int *__restrict__ blockmask;
// The K_new and V_new matrices. // The K_new and V_new matrices.
......
...@@ -83,6 +83,7 @@ def _flash_attn_varlen_forward( ...@@ -83,6 +83,7 @@ def _flash_attn_varlen_forward(
None, None,
cu_seqlens_q, cu_seqlens_q,
cu_seqlens_k, cu_seqlens_k,
None,
max_seqlen_q, max_seqlen_q,
max_seqlen_k, max_seqlen_k,
dropout_p, dropout_p,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment