cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
If the indices are not distinct, and k and v are provided, the values updated in the cache
...
...
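The following is a minimal PyTorch sketch of the `cache_batch_idx` semantics described in the docstring above: query batch element `i` attends to (and, if `k`/`v` are given, writes into) cache row `cache_batch_idx[i]`. The shapes and the reference gather are illustrative, not the kernel's actual implementation.

```python
import torch

batch_size, seqlen_k, nheads_k, headdim = 2, 128, 8, 64
batch_size_c = 4  # the cache may hold more sequences than the current batch

k_cache = torch.randn(batch_size_c, seqlen_k, nheads_k, headdim)
v_cache = torch.randn_like(k_cache)

# Default behavior (cache_batch_idx=None) is equivalent to [0, 1, ..., batch_size - 1].
cache_batch_idx = torch.tensor([3, 1], dtype=torch.int32)

# Reference gather: the K/V rows actually used for batch element i.
k_used = k_cache[cache_batch_idx.long()]   # (batch_size, seqlen_k, nheads_k, headdim)
v_used = v_cache[cache_batch_idx.long()]
```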
@@ -279,6 +309,11 @@ Implement ALiBi (Press et al., 2021). Thanks to Sanghun Cho from Kakao Brain for
Implement deterministic backward pass. Thanks to engineers from [Meituan](www.meituan.com) for this contribution.
### 2.5: Paged KV cache.
Support paged KV cache (i.e., [PagedAttention](https://arxiv.org/abs/2309.06180)).
Thanks to @beginlner for this contribution.
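Below is a sketch (not the library's internal code) of how a paged KV cache is addressed, following the layout noted in the kernel comments: `kcache`/`vcache` become `(num_blocks, page_block_size, num_heads_k, head_size)`, and a per-sequence `block_table` maps logical block `j` to a physical block index, so token position `t` of sequence `i` lives at `kcache[block_table[i, t // page_block_size], t % page_block_size]`.

```python
import torch

num_blocks, page_block_size, nheads_k, headdim = 32, 16, 8, 64
batch_size, max_blocks_per_seq = 2, 8

k_cache = torch.randn(num_blocks, page_block_size, nheads_k, headdim)
block_table = torch.randint(0, num_blocks, (batch_size, max_blocks_per_seq), dtype=torch.int32)

def gather_k(i: int, t: int) -> torch.Tensor:
    """Return the cached key vectors for token position t of sequence i."""
    block = block_table[i, t // page_block_size].long()
    return k_cache[block, t % page_block_size]   # (nheads_k, headdim)
```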
## Performance
We present expected speedup (combined forward + backward pass) and memory savings from using FlashAttention against PyTorch standard attention, depending on sequence length, on different GPUs (speedup depends on memory bandwidth - we see more speedup on slower GPU memory).
// number of times random will be generated per thread, to offset philox counter in thc random
...
...
@@ -1194,14 +1195,15 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads x head_size
std::vector<at::Tensor>
mha_fwd_kvcache(at::Tensor &q,                       // batch_size x seqlen_q x num_heads x head_size
                const at::Tensor &kcache,            // batch_size_c x seqlen_k x num_heads_k x head_size
                const at::Tensor &vcache,            // batch_size_c x seqlen_k x num_heads_k x head_size
                const at::Tensor &kcache,            // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
                const at::Tensor &vcache,            // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
                c10::optional<const at::Tensor> &k_, // batch_size x seqlen_knew x num_heads_k x head_size
                c10::optional<const at::Tensor> &v_, // batch_size x seqlen_knew x num_heads_k x head_size
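For orientation, here is a hedged usage sketch of the Python wrapper over `mha_fwd_kvcache` (assumed to be `flash_attn.flash_attn_with_kvcache`; keyword names may differ across versions). Shapes follow the comments above: new keys/values `k`/`v` are written into the cache in place at the positions given by `cache_seqlens`, then attention is computed against the updated cache.

```python
import torch
from flash_attn import flash_attn_with_kvcache

batch_size, seqlen_q, seqlen_knew = 2, 1, 1
seqlen_k_max, nheads, nheads_k, headdim = 256, 8, 8, 64

q = torch.randn(batch_size, seqlen_q, nheads, headdim, dtype=torch.float16, device="cuda")
k_cache = torch.randn(batch_size, seqlen_k_max, nheads_k, headdim, dtype=torch.float16, device="cuda")
v_cache = torch.randn_like(k_cache)
k_new = torch.randn(batch_size, seqlen_knew, nheads_k, headdim, dtype=torch.float16, device="cuda")
v_new = torch.randn_like(k_new)
# Current number of valid tokens already in the cache for each sequence.
cache_seqlens = torch.full((batch_size,), 17, dtype=torch.int32, device="cuda")

out = flash_attn_with_kvcache(q, k_cache, v_cache, k=k_new, v=v_new, cache_seqlens=cache_seqlens)
# out: (batch_size, seqlen_q, nheads, headdim); k_cache/v_cache are updated in place.
```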