Commit 40e534a7 authored by Tri Dao

Implement cache_leftpad

parent 116b05f9
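In short, cache_leftpad / leftpad_k gives each batch element a per-sequence start index into the KV cache: only cache positions in [cache_leftpad[b], cache_seqlens[b]), plus any newly appended tokens, take part in attention, so left-padded caches (sequences right-aligned in a fixed-size buffer) are handled correctly. Below is a small editorial sketch of that masking rule; it is not code from this commit and the helper name is made up.

# Editorial sketch of the cache_leftpad semantics (hypothetical helper, not part of the commit):
# cache position p of batch element b holds a real token iff cache_leftpad[b] <= p < cache_seqlens[b].
import torch

def visible_cache_positions(cache_seqlens, cache_leftpad, seqlen_k):
    pos = torch.arange(seqlen_k).unsqueeze(0)      # (1, seqlen_k)
    start = cache_leftpad.unsqueeze(1)             # (batch_size, 1)
    end = cache_seqlens.unsqueeze(1)               # (batch_size, 1)
    return (pos >= start) & (pos < end)            # (batch_size, seqlen_k) boolean mask

mask = visible_cache_positions(torch.tensor([10, 8]), torch.tensor([3, 0]), seqlen_k=12)
# mask[0] is True at positions 3..9, mask[1] at positions 0..7.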
@@ -532,6 +532,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
                const at::Tensor &cu_seqlens_q,  // b+1
                const at::Tensor &cu_seqlens_k,  // b+1
                c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
+               c10::optional<const at::Tensor> &leftpad_k_, // batch_size
                c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
                c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
                int max_seqlen_q,
@@ -731,6 +732,16 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
                               head_size_rounded, p_dropout, /*num_splits*/0, dprops, opts);
     }
+    if (leftpad_k_.has_value()) {
+        auto leftpad_k = leftpad_k_.value();
+        TORCH_CHECK(!paged_KV, "We don't support Paged KV and leftpad_k running at the same time yet");
+        TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32");
+        CHECK_DEVICE(leftpad_k);
+        CHECK_CONTIGUOUS(leftpad_k);
+        CHECK_SHAPE(leftpad_k, batch_size);
+        params.leftpad_k = static_cast<int *>(leftpad_k.data_ptr());
+    }
     // number of times random will be generated per thread, to offset philox counter in thc random
     // state
     // We use a custom RNG that increases the offset by batch_size * nheads * 32.
@@ -1279,6 +1290,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
                 c10::optional<const at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
                 c10::optional<const at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
                 c10::optional<const at::Tensor> &cache_batch_idx_, // indices to index into the KV cache
+                c10::optional<const at::Tensor> &leftpad_k_, // batch_size
                 c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
                 c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
                 c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
@@ -1469,6 +1481,15 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
         params.cu_seqlens_k = static_cast<int *>(seqlens_k.data_ptr());
     }
     params.is_seqlens_k_cumulative = !(seqlens_k_.has_value());
+    if (leftpad_k_.has_value()) {
+        TORCH_CHECK(!paged_KV, "We don't support Paged KV and leftpad_k running at the same time yet");
+        auto leftpad_k = leftpad_k_.value();
+        TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32");
+        CHECK_DEVICE(leftpad_k);
+        CHECK_CONTIGUOUS(leftpad_k);
+        CHECK_SHAPE(leftpad_k, batch_size);
+        params.leftpad_k = static_cast<int *>(leftpad_k.data_ptr());
+    }
     if (rotary_cos_.has_value()) {
         TORCH_CHECK(k_.has_value(), "If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided");
...
@@ -18,8 +18,9 @@ struct BlockInfo
         , actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
         // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
         // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
-        , seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
-        , actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
+        , leftpad_k(params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb])
+        , seqlen_k_cache((!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb])) - leftpad_k)
+        , actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] - leftpad_k : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
     {
     }
@@ -30,13 +31,14 @@ struct BlockInfo
     template <typename index_t>
     __forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
-        return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride;
+        return sum_s_k == -1 ? bidb * batch_stride + leftpad_k * row_stride : uint32_t(sum_s_k + leftpad_k) * row_stride;
     }
     const int sum_s_q;
     const int sum_s_k;
     const int actual_seqlen_q;
     // We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0.
+    const int leftpad_k;
     const int seqlen_k_cache;
     const int actual_seqlen_k;
 };
...
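For orientation, here is the BlockInfo arithmetic above transcribed into plain Python, as an illustrative mirror of the kvcache path with made-up helper names (not code from the repo): the left pad is subtracted from the cached length, and added back when locating the first K row in memory.

# Illustrative Python mirror of the BlockInfo changes (hypothetical names; kvcache path where
# cu_seqlens_k stores per-batch cache end indices, i.e. is_seqlens_k_cumulative is False).
def block_info_leftpad(cache_seqlens_b, leftpad_b, seqused_k_b=None, seqlen_knew=0):
    leftpad_k = leftpad_b                               # params.leftpad_k[bidb], or 0 if not given
    seqlen_k_cache = cache_seqlens_b - leftpad_k        # cached tokens, excluding the left pad
    if seqused_k_b is not None:
        actual_seqlen_k = seqused_k_b - leftpad_k
    else:
        actual_seqlen_k = seqlen_k_cache + seqlen_knew  # plus any newly appended tokens
    return leftpad_k, seqlen_k_cache, actual_seqlen_k

def k_offset(batch_stride, row_stride, bidb, leftpad_k):
    # Non-varlen case (sum_s_k == -1): start reading K after the padded rows of this batch element.
    return bidb * batch_stride + leftpad_k * row_stride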
@@ -76,6 +76,7 @@ struct Flash_fwd_params : public Qkv_params
     // array of length b+1 holding starting offset of each sequence.
     int * __restrict__ cu_seqlens_q;
     int * __restrict__ cu_seqlens_k;
+    int * __restrict__ leftpad_k;
     // If provided, the actual length of each k sequence.
     int * __restrict__ seqused_k;
...
@@ -690,7 +690,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     // Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to
     // gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe.
     // We want to do this so that all threadblocks can proceed right after they finish writing the KV cache.
-    const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN) * (params.rotary_dim / 2);
+    const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN + (params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb])) * (params.rotary_dim / 2);
     Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
                               Shape<Int<kBlockN>, Int<kHeadDim / 2>>{},
                               make_stride(params.rotary_dim / 2, _1{}));
@@ -711,9 +711,11 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     // if (cute::thread(8, 0)) { print_tensor(gCos); }
     // if (cute::thread(0, 0)) { print_tensor(tRgCos); }
-    const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb)
+    // const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb)
+    const index_t row_offset_knew = bidb * params.knew_batch_stride
         + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride;
-    const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb)
+    // const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb)
+    const index_t row_offset_vnew = bidb * params.vnew_batch_stride
         + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride;
     // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them,
     // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64].
@@ -791,7 +793,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
                                            binfo.actual_seqlen_q - m_block * kBlockM);
     } else {
-        const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2);
+        const index_t row_offset_cossin = (binfo.seqlen_k_cache + (params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb]) + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2);
         // If not causal, all the queries get the same cos/sin, taken at location seqlen_k_cache.
         // We do this by setting the row stride of gCos / gSin to 0.
         Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
...
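Read together, the changed offsets above say that rotary cos/sin rows are indexed by the physical position in the cache, i.e. the logical position plus the per-batch left pad. A hedged Python restatement of those index computations (not the CUDA code itself; helper names are invented):

# Sketch of the cos/sin row offsets above, in Python (illustrative only).
def cossin_offset_new_kv(n_block_max, kBlockN, leftpad_k_b, rotary_dim):
    # Last K block when appending new KV: logical block start shifted by the left pad.
    return ((n_block_max - 1) * kBlockN + leftpad_k_b) * (rotary_dim // 2)

def cossin_offset_q(seqlen_k_cache, leftpad_k_b, m_block, kBlockM, rotary_dim, causal_or_local):
    # Rotating Q: end of the cache in physical coordinates, plus the row-block offset when
    # causal/local so each query row gets its own rotary position.
    pos = seqlen_k_cache + leftpad_k_b + (m_block * kBlockM if causal_or_local else 0)
    return pos * (rotary_dim // 2)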
@@ -81,7 +81,8 @@ def _flash_attn_varlen_forward(
     softcap,
     alibi_slopes,
     return_softmax,
-    block_table,
+    block_table=None,
+    leftpad_k=None,
 ):
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
@@ -93,6 +94,7 @@ def _flash_attn_varlen_forward(
         cu_seqlens_q,
         cu_seqlens_k,
         None,
+        leftpad_k,
         block_table,
         alibi_slopes,
         max_seqlen_q,
@@ -1150,6 +1152,7 @@ def flash_attn_with_kvcache(
     rotary_sin=None,
     cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
     cache_batch_idx: Optional[torch.Tensor] = None,
+    cache_leftpad: Optional[torch.Tensor] = None,
     block_table: Optional[torch.Tensor] = None,
     softmax_scale=None,
     causal=False,
@@ -1217,11 +1220,12 @@ def flash_attn_with_kvcache(
         rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
         cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
             KV cache.
-        block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
         cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
             If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
             If the indices are not distinct, and k and v are provided, the values updated in the cache
             might come from any of the duplicate indices.
+        cache_leftpad: (batch_size,), dtype torch.int32. The index at which the KV cache starts. If None, assume 0.
+        block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
         softmax_scale: float. The scaling of QK^T before applying softmax.
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
@@ -1269,6 +1273,7 @@ def flash_attn_with_kvcache(
         rotary_cos,
         rotary_sin,
         cache_batch_idx,
+        cache_leftpad,
         block_table,
         alibi_slopes,
         None,
...
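A minimal usage sketch of the new argument (shapes and dtypes as documented in the docstring above; the tensors and the import path are placeholders/assumptions):

import torch
from flash_attn import flash_attn_with_kvcache  # assumed import path

batch_size, seqlen_k, nheads, headdim = 2, 16, 4, 64
q = torch.randn(batch_size, 1, nheads, headdim, device="cuda", dtype=torch.float16)
k_cache = torch.randn(batch_size, seqlen_k, nheads, headdim, device="cuda", dtype=torch.float16)
v_cache = torch.randn_like(k_cache)

# Sequence 0's cache occupies positions [3, 10); sequence 1's occupies [0, 8).
cache_leftpad = torch.tensor([3, 0], dtype=torch.int32, device="cuda")
cache_seqlens = torch.tensor([10, 8], dtype=torch.int32, device="cuda")

out = flash_attn_with_kvcache(
    q, k_cache, v_cache,
    cache_seqlens=cache_seqlens,
    cache_leftpad=cache_leftpad,
    causal=True,
)  # (batch_size, 1, nheads, headdim)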
@@ -182,9 +182,14 @@ def construct_local_mask(
     query_padding_mask=None,
     key_padding_mask=None,
     device=None,
+    key_leftpad=None,
 ):
     row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
     col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
+    if key_leftpad is not None:
+        key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1")
+        col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0])
+        col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)
     sk = (
         seqlen_k
         if key_padding_mask is None
@@ -219,6 +224,7 @@ def attention_ref(
     softcap=0.0,
     upcast=True,
     reorder_ops=False,
+    key_leftpad=None,
 ):
     """
     Arguments:
@@ -268,6 +274,7 @@ def attention_ref(
             query_padding_mask,
             key_padding_mask,
             q.device,
+            key_leftpad=key_leftpad,
         )
         scores.masked_fill_(local_mask, float("-inf"))
     if attn_bias is not None:
@@ -306,6 +313,7 @@ def attention_kvpacked_ref(
     softcap=0.0,
     upcast=True,
     reorder_ops=False,
+    key_leftpad=None,
 ):
     return attention_ref(
         q,
@@ -321,6 +329,7 @@ def attention_kvpacked_ref(
         window_size=window_size,
         softcap=softcap,
         reorder_ops=reorder_ops,
+        key_leftpad=key_leftpad,
     )
@@ -1868,9 +1877,11 @@ def test_flash_attn_splitkv(
 # @pytest.mark.parametrize("rotary_fraction", [0.0])
 @pytest.mark.parametrize("paged_kv_block_size", [None, 256])
 # @pytest.mark.parametrize("paged_kv_block_size", [256, 512])
-# @pytest.mark.parametrize("paged_kv_block_size", [256])
+# @pytest.mark.parametrize("paged_kv_block_size", [None])
-@pytest.mark.parametrize("has_batch_idx", [False, True])
+@pytest.mark.parametrize("has_leftpad", [False, True])
-# @pytest.mark.parametrize("has_batch_idx", [False])
+# @pytest.mark.parametrize("has_leftpad", [True])
+# @pytest.mark.parametrize("has_batch_idx", [False, True])
+@pytest.mark.parametrize("has_batch_idx", [False])
 @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256])
 # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
@@ -1898,6 +1909,7 @@ def test_flash_attn_kvcache(
     seqlen_k,
     d,
     has_batch_idx,
+    has_leftpad,
     paged_kv_block_size,
     rotary_fraction,
     rotary_interleaved,
@@ -1916,6 +1928,8 @@ def test_flash_attn_kvcache(
         pytest.skip()
     if has_batch_idx and paged_kv_block_size is not None:
         pytest.skip()
+    if has_leftpad and paged_kv_block_size is not None:
+        pytest.skip()
     device = "cuda"
     # set seed
     torch.random.manual_seed(0)
@@ -1961,9 +1975,19 @@ def test_flash_attn_kvcache(
             dtype=torch.int32,
             device=device,
         )
+    if has_leftpad:
+        cache_leftpad = torch.cat([torch.randint(0, cache_seqlens[i].item(), (1,), dtype=torch.int32, device=device)
+                                   if cache_seqlens[i].item() > 0 else torch.zeros(1, dtype=torch.int32, device=device)
+                                   for i in range(batch_size)])
+    else:
+        cache_leftpad = None
     arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s")
     cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
     key_padding_mask = arange < cache_seqlens_expanded + (seqlen_new if new_kv else 0)
+    if has_leftpad:
+        key_padding_mask = torch.logical_and(
+            key_padding_mask, arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k)
+        )
     if has_batch_idx:
         cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[
             :batch_size
@@ -2038,6 +2062,7 @@ def test_flash_attn_kvcache(
         rotary_sin=sin,
         cache_seqlens=cache_seqlens,
         cache_batch_idx=cache_batch_idx,
+        cache_leftpad=cache_leftpad,
         block_table=block_table,
         causal=causal,
         window_size=window_size,
@@ -2066,6 +2091,7 @@ def test_flash_attn_kvcache(
         None,
         causal=causal,
         window_size=window_size,
+        key_leftpad=cache_leftpad,
     )
     out_pt, _ = attention_ref(
         q_ro,
@@ -2080,6 +2106,7 @@ def test_flash_attn_kvcache(
         window_size=window_size,
         upcast=False,
         reorder_ops=True,
+        key_leftpad=cache_leftpad,
     )
     print(f"Output max diff: {(out - out_ref).abs().max().item()}")
     print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
...
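Note on the construct_local_mask change in the reference code above: key_leftpad shifts key column indices so that window arithmetic runs in logical positions, while padded columns are mapped to a huge index (2**32) and therefore always fall outside any local window. A tiny worked example of that remapping (illustrative only):

import torch
from einops import rearrange, repeat

col_idx = torch.arange(8, dtype=torch.long)
key_leftpad = torch.tensor([3], dtype=torch.long)   # batch of 1, cache starts at index 3

key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1")
col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0])
col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)
# col_idx[0, 0, 0] == [2**32, 2**32, 2**32, 0, 1, 2, 3, 4]: the three padded columns can
# never satisfy a local-window bound, so they are always masked out.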