"src/vscode:/vscode.git/clone" did not exist on "0b6365374b967ebc9b6bea50dd1e23dc38bf3fc3"
Commit 54e80a38 authored by Tri Dao

Implement paged KV cache


Co-authored-by: ljss <450993438@qq.com>
parent bdcae547
...@@ -148,6 +148,7 @@ def flash_attn_with_kvcache(
rotary_sin=None,
cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
cache_batch_idx: Optional[torch.Tensor] = None,
block_table: Optional[torch.Tensor] = None,
softmax_scale=None,
causal=False,
window_size=(-1, -1),  # -1 means infinite context window
...@@ -160,6 +161,10 @@ def flash_attn_with_kvcache(
the previous step, and update them with the new keys/values from the current step, and do
attention with the updated cache, all in 1 kernel.
If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
For example, the KV cache could be pre-allocated with the max sequence length, and you can use
cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
...@@ -169,12 +174,36 @@ def flash_attn_with_kvcache(
See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.
If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1 1 1 1 0
1 1 1 1 1
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
0 0
0 0
0 0
1 0
1 1
If the row of the mask is all zero, the output will be zero.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
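To make the bottom-right causal alignment and the local window bounds concrete, here is an illustrative PyTorch sketch (not part of this commit) that builds the same mask the text above describes:

```python
# Illustrative sketch (not part of this commit): build the mask described above.
import torch

def build_mask(seqlen_q, seqlen_k, window_size=(-1, -1), causal=True):
    row = torch.arange(seqlen_q).unsqueeze(1)   # query positions
    col = torch.arange(seqlen_k).unsqueeze(0)   # key positions
    shift = seqlen_k - seqlen_q                 # bottom-right alignment of the causal mask
    keep = torch.ones(seqlen_q, seqlen_k, dtype=torch.bool)
    if causal:
        keep &= col <= row + shift
    if window_size[0] >= 0:                     # left half of the sliding window
        keep &= col >= row + shift - window_size[0]
    if window_size[1] >= 0:                     # right half of the sliding window
        keep &= col <= row + shift + window_size[1]
    return keep.int()

print(build_mask(2, 5))  # reproduces the 2 x 5 causal example above
```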
Note: Does not support backward pass.
Arguments:
q: (batch_size, seqlen, nheads, headdim)
k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
page_block_size must be a multiple of 256.
v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
k with k_cache, starting at the indices specified by cache_seqlens.
v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
...@@ -183,6 +212,7 @@ def flash_attn_with_kvcache(
rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
KV cache.
block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
If the indices are not distinct, and k and v are provided, the values updated in the cache
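For reference, here is a hedged usage sketch of the paged path (shapes and values are made up for illustration): the caller owns a pool of fixed-size KV blocks plus a per-sequence block_table, and passes both to flash_attn_with_kvcache.

```python
# Illustrative decode-step call with a paged KV cache (values are made up).
import torch
from flash_attn import flash_attn_with_kvcache

batch, nheads, nheads_k, d = 2, 8, 2, 128
page_block_size, max_blocks_per_seq = 256, 4
num_blocks = batch * max_blocks_per_seq  # size of the shared block pool

q = torch.randn(batch, 1, nheads, d, device="cuda", dtype=torch.float16)
k = torch.randn(batch, 1, nheads_k, d, device="cuda", dtype=torch.float16)
v = torch.randn_like(k)
k_cache = torch.zeros(num_blocks, page_block_size, nheads_k, d, device="cuda", dtype=torch.float16)
v_cache = torch.zeros_like(k_cache)

# Each row maps a sequence's logical blocks to physical blocks in the pool.
block_table = torch.arange(num_blocks, dtype=torch.int32, device="cuda").reshape(batch, max_blocks_per_seq)
cache_seqlens = torch.full((batch,), 100, dtype=torch.int32, device="cuda")

out = flash_attn_with_kvcache(
    q, k_cache, v_cache, k, v,
    cache_seqlens=cache_seqlens, block_table=block_table, causal=True,
)
```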
...@@ -279,6 +309,11 @@ Implement ALiBi (Press et al., 2021). Thanks to Sanghun Cho from Kakao Brain for
Implement deterministic backward pass. Thanks to engineers from [Meituan](www.meituan.com) for this contribution.
### 2.5: Paged KV cache.
Support paged KV cache (i.e., [PagedAttention](https://arxiv.org/abs/2309.06180)).
Thanks to @beginlner for this contribution.
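Conceptually, the block table maps each sequence's logical token positions onto physical blocks of a shared KV pool; a minimal sketch of that indexing (names assumed for illustration, not code from this repo):

```python
# Sketch: where token `pos` of sequence `b` lives in the paged K cache.
def paged_kv_location(block_table, pos, b, page_block_size):
    logical_block = pos // page_block_size          # which entry of block_table[b]
    offset = pos % page_block_size                  # position inside that block
    physical_block = block_table[b][logical_block]  # index into k_cache/v_cache dim 0
    return physical_block, offset

# k_cache[physical_block, offset, head, :] then holds the key vector for that token.
```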
## Performance
We present expected speedup (combined forward + backward pass) and memory savings from using FlashAttention against PyTorch standard attention, depending on sequence length, on different GPUs (speedup depends on memory bandwidth - we see more speedup on slower GPU memory).
...
...@@ -257,8 +257,9 @@ inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n
}
void set_params_splitkv(Flash_fwd_params &params, const int batch_size,
const int num_heads, const int head_size, const int max_seqlen_k, const int max_seqlen_q,
const int head_size_rounded, const float p_dropout,
const int num_splits, cudaDeviceProp *dprops, struct c10::TensorOptions opts) {
// This needs to match with run_mha_fwd_splitkv_dispatch
const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
...@@ -635,8 +636,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
if (seqlenq_ngroups_swapped) {
// Only apply split-k for decoding
set_params_splitkv(params, batch_size, num_heads,
head_size, max_seqlen_k, max_seqlen_q,
head_size_rounded, p_dropout, /*num_splits*/0, dprops, opts);
}
// number of times random will be generated per thread, to offset philox counter in thc random
...@@ -1194,14 +1195,15 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
std::vector<at::Tensor>
mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &kcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
const at::Tensor &vcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
c10::optional<const at::Tensor> &k_, // batch_size x seqlen_knew x num_heads_k x head_size
c10::optional<const at::Tensor> &v_, // batch_size x seqlen_knew x num_heads_k x head_size
c10::optional<const at::Tensor> &seqlens_k_, // batch_size
c10::optional<const at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
c10::optional<const at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
c10::optional<const at::Tensor> &cache_batch_idx_, // indices to index into the KV cache
c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
const float softmax_scale,
...@@ -1235,15 +1237,30 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
at::Tensor block_table;
const bool paged_KV = block_table_.has_value();
if (paged_KV) {
TORCH_CHECK(!cache_batch_idx_.has_value(), "Paged KVcache does not support cache_batch_idx");
block_table = block_table_.value();
CHECK_DEVICE(block_table);
TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
}
const auto sizes = q.sizes();
const int batch_size = sizes[0];
int seqlen_q = sizes[1];
int num_heads = sizes[2];
const int head_size_og = sizes[3];
const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
const int num_blocks = !paged_KV ? 0 : kcache.size(0);
const int page_block_size = !paged_KV ? 1 : kcache.size(1);
TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256");
const int seqlen_k = !paged_KV ? kcache.size(1) : max_num_blocks_per_seq * page_block_size;
const int num_heads_k = kcache.size(2);
const int batch_size_c = !paged_KV ? kcache.size(0) : batch_size;
TORCH_CHECK(batch_size > 0, "batch size must be positive");
TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
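The size bookkeeping above can be summarized with a small sketch of the same arithmetic on hypothetical numbers (Python used purely to illustrate the C++ logic; the values are assumptions):

```python
# Mirror of the paged-cache size bookkeeping above, with made-up numbers.
num_blocks, page_block_size = 24, 256      # kcache.size(0), kcache.size(1) when paged
batch_size, max_num_blocks_per_seq = 3, 8  # q.size(0), block_table.size(1)

assert page_block_size % 256 == 0                    # enforced by the TORCH_CHECK above
seqlen_k = max_num_blocks_per_seq * page_block_size  # 2048: max tokens addressable per sequence
batch_size_c = batch_size                            # cache "batch" is defined by block_table rows
```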
...@@ -1266,8 +1283,14 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
if (window_size_right >= seqlen_k) { window_size_right = -1; }
CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
if (!paged_KV) {
CHECK_SHAPE(kcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
CHECK_SHAPE(vcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
} else {
CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_og);
CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_og);
CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
}
at::Tensor q_padded, kcache_padded, vcache_padded;
if (head_size_og % 8 != 0) {
...@@ -1406,6 +1429,12 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
head_size, seqlen_k, seqlen_q,
head_size_rounded, /*dropout*/0.f, num_splits, dprops, opts);
if (paged_KV) {
params.block_table = block_table.data_ptr<int>();
params.block_table_batch_stride = block_table.stride(0);
}
params.page_block_size = page_block_size;
if (alibi_slopes_.has_value()) {
auto alibi_slopes = alibi_slopes_.value();
TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
...@@ -1419,8 +1448,9 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
}
auto stream = at::cuda::getCurrentCUDAStream().stream();
// Only split kernel supports appending to KV cache, or indexing to the cache with cache_batch_idx,
// or paged KV cache
run_mha_fwd(params, stream, /*force_split_kernel=*/k_.has_value() || cache_batch_idx_.has_value() || paged_KV);
if (head_size_og % 8 != 0) {
out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
...
...@@ -101,6 +101,11 @@ struct Flash_fwd_params : public Qkv_params {
// The indices to index into the KV cache.
int * __restrict__ cache_batch_idx;
// Paged KV cache
int * __restrict__ block_table;
index_t block_table_batch_stride;
int page_block_size;
// The dropout probability (probability of keeping an activation).
float p_dropout;
// uint32_t p_dropout_in_uint;
...
...@@ -559,10 +559,17 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
+ m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
// We move K and V to the last block.
const int bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb];
const int *block_table = params.block_table == nullptr ? nullptr : params.block_table + bidb * params.block_table_batch_stride;
const int block_table_idx = block_table == nullptr ? 0 : (n_block_max - 1) * kBlockN / params.page_block_size;
const int block_table_offset = block_table == nullptr ? 0 : (n_block_max - 1) * kBlockN - block_table_idx * params.page_block_size;
const index_t row_offset_k = block_table == nullptr
? binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache)
+ (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride
: block_table[block_table_idx] * params.k_batch_stride + block_table_offset * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
const index_t row_offset_v = block_table == nullptr
? binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache)
+ (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride
: block_table[block_table_idx] * params.v_batch_stride + block_table_offset * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
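To unpack the paged offset computation just above: a global key row index is split into (block table entry, in-block offset), and the physical block index then replaces the batch index in the usual stride arithmetic. A host-side Python sketch of the same arithmetic (function and argument names are assumptions for illustration):

```python
# Sketch of the paged row-offset arithmetic used above (per sequence b and KV head).
def paged_row_offset(block_table_row, row, page_block_size,
                     k_batch_stride, k_row_stride, k_head_stride, kv_head):
    block_table_idx = row // page_block_size                 # which block-table entry
    block_table_offset = row - block_table_idx * page_block_size  # row inside that block
    physical_block = block_table_row[block_table_idx]        # index into dim 0 of the paged cache
    return (physical_block * k_batch_stride
            + block_table_offset * k_row_stride
            + kv_head * k_head_stride)

# With a contiguous (num_blocks, page_block_size, nheads_k, d) cache:
# k_batch_stride = page_block_size * nheads_k * d, k_row_stride = nheads_k * d, k_head_stride = d.
```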
Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
...@@ -695,11 +702,12 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K)
const int n_block_copy_min = std::max(n_block_min, binfo.seqlen_k_cache / kBlockN);
auto tKgK_data = tKgK.data();
auto tVgV_data = tVgV.data();
for (int n_block = n_block_max - 1; n_block >= n_block_copy_min; n_block--) {
flash::copy_w_min_idx<Is_even_K>(
tVgVnew, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
);
tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride));
if (params.rotary_dim == 0) {
flash::copy_w_min_idx<Is_even_K>(
...@@ -725,15 +733,27 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
}
}
tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride));
if (block_table == nullptr) {
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
} else {
if (n_block > n_block_copy_min) {
const int block_table_idx_cur = n_block * kBlockN / params.page_block_size;
const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size;
const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size;
const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size;
const int table_diff = block_table[block_table_idx_next] - block_table[block_table_idx_cur];
const int offset_diff = block_table_offset_next - block_table_offset_cur;
tVgV.data() = tVgV.data() + table_diff * params.v_batch_stride + offset_diff * params.v_row_stride;
tKgK.data() = tKgK.data() + table_diff * params.k_batch_stride + offset_diff * params.k_row_stride;
}
}
}
// Need this before we can read in K again, so that we'll see the updated K values.
__syncthreads();
tKgK.data() = tKgK_data;
tVgV.data() = tVgV_data;
}
// Read Q from gmem to smem, optionally apply rotary embedding.
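Both the append loop above and the main attention loops below advance the K/V global-memory pointers the same way once paging is enabled: instead of subtracting a fixed kBlockN * row_stride, the kernel recomputes (block, in-block offset) for the next kBlockN tile and moves the pointer by the difference. A hedged host-side sketch of that arithmetic (names are assumptions):

```python
# Sketch of the "advance by difference" update between consecutive kBlockN tiles.
def advance(ptr, block_table_row, n_block, kBlockN, page_block_size, batch_stride, row_stride):
    idx_cur = n_block * kBlockN // page_block_size
    off_cur = n_block * kBlockN - idx_cur * page_block_size
    idx_next = (n_block - 1) * kBlockN // page_block_size
    off_next = (n_block - 1) * kBlockN - idx_next * page_block_size
    # The two tiles may live in different physical blocks, so both strides can change.
    return (ptr
            + (block_table_row[idx_next] - block_table_row[idx_cur]) * batch_stride
            + (off_next - off_cur) * row_stride)
```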
...@@ -812,7 +832,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
// Advance gV
if (masking_step > 0) {
if (block_table == nullptr) {
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
} else {
const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size;
const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size;
const int block_table_idx_next = n_block * kBlockN / params.page_block_size;
const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size;
tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride;
}
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
} else {
// Clear the smem tiles to account for predicated off loads
...@@ -839,7 +867,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
if (n_block > n_block_min) {
// Advance gK
if (block_table == nullptr) {
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
} else {
const int block_table_idx_cur = n_block * kBlockN / params.page_block_size;
const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size;
const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size;
const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size;
tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride;
}
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
// This cp_async_fence needs to be in the if block, otherwise the synchronization
// isn't right and we get race conditions.
...@@ -874,7 +910,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
flash::cp_async_wait<0>();
__syncthreads();
// Advance gV
if (block_table == nullptr) {
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
} else {
const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size;
const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size;
const int block_table_idx_next = n_block * kBlockN / params.page_block_size;
const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size;
tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride;
}
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
cute::cp_async_fence();
...@@ -887,7 +931,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
__syncthreads();
if (n_block > n_block_min) {
// Advance gK
if (block_table == nullptr) {
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
} else {
const int block_table_idx_cur = n_block * kBlockN / params.page_block_size;
const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size;
const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size;
const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size;
tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride;
}
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
// This cp_async_fence needs to be in the if block, otherwise the synchronization
// isn't right and we get race conditions.
...
...@@ -1084,6 +1084,7 @@ def flash_attn_with_kvcache(
rotary_sin=None,
cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
cache_batch_idx: Optional[torch.Tensor] = None,
block_table: Optional[torch.Tensor] = None,
softmax_scale=None,
causal=False,
window_size=(-1, -1),  # -1 means infinite context window
...@@ -1135,8 +1136,11 @@ def flash_attn_with_kvcache(
Arguments:
q: (batch_size, seqlen, nheads, headdim)
k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
page_block_size must be a multiple of 256.
v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
k with k_cache, starting at the indices specified by cache_seqlens.
v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
...@@ -1145,6 +1149,7 @@ def flash_attn_with_kvcache(
rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
KV cache.
block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
If the indices are not distinct, and k and v are provided, the values updated in the cache
...@@ -1180,6 +1185,7 @@ def flash_attn_with_kvcache(
)
cache_seqlens = maybe_contiguous(cache_seqlens)
cache_batch_idx = maybe_contiguous(cache_batch_idx)
block_table = maybe_contiguous(block_table)
out, softmax_lse = flash_attn_cuda.fwd_kvcache(
q,
k_cache,
...@@ -1190,6 +1196,7 @@ def flash_attn_with_kvcache(
rotary_cos,
rotary_sin,
cache_batch_idx,
block_table,
alibi_slopes,
None,
softmax_scale,
...
...@@ -708,7 +708,9 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, determ
# @pytest.mark.parametrize('seqlen', [128])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize('dropout_p', [0.0])
def test_flash_attn_varlen_qkvpacked(
seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype
):
if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30:
pytest.skip()  # Reference implementation OOM
device = "cuda"
...@@ -1698,7 +1700,9 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp
],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
def test_flash_attn_splitkv(
seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, alibi, deterministic, dtype
):
if swap_sq_sk:
seqlen_q, seqlen_k = seqlen_k, seqlen_q
device = "cuda"
...@@ -1800,7 +1804,7 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, al
@pytest.mark.parametrize("new_kv", [False, True])
# @pytest.mark.parametrize("new_kv", [False])
@pytest.mark.parametrize("alibi", [False, True])
# @pytest.mark.parametrize("alibi", [False])
@pytest.mark.parametrize("local", [False, True]) @pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [False]) # @pytest.mark.parametrize("local", [False])
@pytest.mark.parametrize("causal", [False, True]) @pytest.mark.parametrize("causal", [False, True])
...@@ -1811,10 +1815,12 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, al ...@@ -1811,10 +1815,12 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, al
# @pytest.mark.parametrize("rotary_interleaved", [False]) # @pytest.mark.parametrize("rotary_interleaved", [False])
@pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0]) @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
# @pytest.mark.parametrize("rotary_fraction", [0.0]) # @pytest.mark.parametrize("rotary_fraction", [0.0])
# @pytest.mark.parametrize("paged_kv_block_size", [None, 256, 512])
@pytest.mark.parametrize("paged_kv_block_size", [256, 512])
@pytest.mark.parametrize("has_batch_idx", [False, True]) @pytest.mark.parametrize("has_batch_idx", [False, True])
# @pytest.mark.parametrize("has_batch_idx", [False]) # @pytest.mark.parametrize("has_batch_idx", [False])
@pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 128, 160, 192, 224, 256]) @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80]) # @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [128]) # @pytest.mark.parametrize("d", [128])
...@@ -1840,6 +1846,7 @@ def test_flash_attn_kvcache(
seqlen_k,
d,
has_batch_idx,
paged_kv_block_size,
rotary_fraction,
rotary_interleaved,
seqlen_new_eq_seqlen_q,
...@@ -1855,6 +1862,8 @@ def test_flash_attn_kvcache(
pytest.skip()
if not new_kv and rotary_fraction > 0.0:
pytest.skip()
if has_batch_idx and paged_kv_block_size is not None:
pytest.skip()
device = "cuda" device = "cuda"
# set seed # set seed
torch.random.manual_seed(0) torch.random.manual_seed(0)
...@@ -1873,10 +1882,35 @@ def test_flash_attn_kvcache( ...@@ -1873,10 +1882,35 @@ def test_flash_attn_kvcache(
v = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype) v = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype)
else: else:
k, v = None, None k, v = None, None
k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype) if paged_kv_block_size is None:
v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype) k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)
v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)
block_table = None
else:
num_blocks = math.ceil(seqlen_k / paged_kv_block_size) * batch_size * 3
k_cache_paged = torch.randn(
num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype
)
v_cache_paged = torch.randn(
num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype
)
block_table = rearrange(
torch.randperm(num_blocks, dtype=torch.int32, device=device),
"(b nblocks) -> b nblocks",
b=batch_size,
)
k_cache = rearrange(
k_cache_paged[block_table.flatten()],
"(b nblocks) block_size ... -> b (nblocks block_size) ...",
b=batch_size,
)[:, :seqlen_k]
v_cache = rearrange(
v_cache_paged[block_table.flatten()],
"(b nblocks) block_size ... -> b (nblocks block_size) ...",
b=batch_size,
)[:, :seqlen_k]
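The rearrange above gathers the paged pool back into a contiguous (batch, seqlen, ...) view so the existing reference check can be reused. A tiny hedged illustration of the same gather on toy numbers (not part of the test):

```python
# Toy illustration of gathering a paged cache back into (batch, seqlen, ...) order.
import torch
from einops import rearrange

pool = torch.arange(8).reshape(4, 2)          # 4 physical blocks of 2 "tokens" each
block_table = torch.tensor([[2, 0], [3, 1]])  # logical -> physical blocks, per sequence

flat = pool[block_table.flatten()]                        # (batch * nblocks, block_size)
contiguous = rearrange(flat, "(b n) s -> b (n s)", b=2)   # (batch, nblocks * block_size)
# Sequence 0 reads blocks 2 then 0 -> tokens [4, 5, 0, 1]; sequence 1 -> [6, 7, 2, 3].
```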
cache_seqlens = torch.randint(
0 if new_kv else 1,
# If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough
(seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1)
if new_kv
...@@ -1903,7 +1937,15 @@ def test_flash_attn_kvcache(
alibi_slopes, attn_bias = None, None
# cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device)
if rotary_dim > 0:
angle = (
torch.rand(
seqlen_k if paged_kv_block_size is None else num_blocks * paged_kv_block_size,
rotary_dim // 2,
device=device,
)
* 2
* math.pi
)
cos = torch.cos(angle).to(dtype=dtype)
sin = torch.sin(angle).to(dtype=dtype)
if causal or local:
...@@ -1942,14 +1984,15 @@ def test_flash_attn_kvcache(
v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
out = flash_attn_with_kvcache(
q,
k_cache if paged_kv_block_size is None else k_cache_paged,
v_cache if paged_kv_block_size is None else v_cache_paged,
k,
v,
rotary_cos=cos,
rotary_sin=sin,
cache_seqlens=cache_seqlens,
cache_batch_idx=cache_batch_idx,
block_table=block_table,
causal=causal,
window_size=window_size,
rotary_interleaved=rotary_interleaved,
...@@ -2000,8 +2043,20 @@ def test_flash_attn_kvcache(
# Check that FlashAttention's numerical error is at most twice the numerical error
# of a Pytorch implementation.
if new_kv:
if paged_kv_block_size is None:
k_cache_select = k_cache if not has_batch_idx else k_cache[cache_batch_idx]
v_cache_select = v_cache if not has_batch_idx else v_cache[cache_batch_idx]
else:
k_cache_select = rearrange(
k_cache_paged[block_table.flatten()],
"(b nblocks) block_size ... -> b (nblocks block_size) ...",
b=batch_size,
)[:, :seqlen_k]
v_cache_select = rearrange(
v_cache_paged[block_table.flatten()],
"(b nblocks) block_size ... -> b (nblocks block_size) ...",
b=batch_size,
)[:, :seqlen_k]
assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3)
assert torch.equal(v_cache_select, v_cache_ref)
mult = 3 if not alibi else 5
...@@ -2280,8 +2335,6 @@ def test_flash_attn_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, loc
assert torch.equal(dq, dq0)
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("local", [False, True])
...