Commit 2406f288 authored by Tri Dao

Enable headdim 256 backward on consumer GPUs (Ampere, Ada)

parent 43950dda
@@ -70,7 +70,7 @@ FlashAttention-2 currently supports:
 GPUs (T4, RTX 2080) is coming soon, please use FlashAttention 1.x for Turing
 GPUs for now.
 2. Datatype fp16 and bf16 (bf16 requires Ampere, Ada, or Hopper GPUs).
-3. All head dimensions up to 256. Head dim > 192 backward requires A100/A800 or H100/H800.
+3. All head dimensions up to 256. ~~Head dim > 192 backward requires A100/A800 or H100/H800~~. Head dim 256 backward now works on consumer GPUs (if there's no dropout) as of flash-attn 2.5.5.
 ## How to use FlashAttention
...
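For context on the README change above, a minimal usage sketch (the shapes and flags are illustrative, not from this commit): with flash-attn 2.5.5 or later, the head dim 256 backward pass should run on consumer Ampere/Ada GPUs (sm86/sm89) as long as dropout is disabled.

```python
import torch
from flash_attn import flash_attn_func

# Illustrative shapes: (batch, seqlen, nheads, headdim) with headdim = 256.
q = torch.randn(2, 1024, 8, 256, device="cuda", dtype=torch.bfloat16, requires_grad=True)
k = torch.randn(2, 1024, 8, 256, device="cuda", dtype=torch.bfloat16, requires_grad=True)
v = torch.randn(2, 1024, 8, 256, device="cuda", dtype=torch.bfloat16, requires_grad=True)

# dropout_p must be 0.0 for the head dim 256 backward on sm86/sm89.
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
out.sum().backward()  # previously tripped the "head dim > 192" TORCH_CHECK on these GPUs
```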
@@ -783,8 +783,8 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
     TORCH_CHECK(batch_size > 0, "batch size must be positive");
     TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
     TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
-    if (head_size > 192) {
-        TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim > 192 requires A100/A800 or H100/H800");
+    if (head_size > 192 && (head_size <= 224 || is_dropout)) {
+        TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim 256 with dropout, or head dim 224 with/without dropout requires A100/A800 or H100/H800");
     }
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
@@ -1020,8 +1020,8 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
     TORCH_CHECK(batch_size > 0, "batch size must be positive");
     TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
     TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
-    if (head_size > 192) {
-        TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim > 192 requires A100/A800 or H100/H800");
+    if (head_size > 192 && (head_size <= 224 || is_dropout)) {
+        TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim 256 with dropout, or head dim 224 with/without dropout requires A100/A800 or H100/H800");
     }
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
...
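The updated checks in mha_bwd and mha_varlen_bwd encode the same support matrix. As a hedged restatement (the helper name below is mine, not part of the library): head dims in (192, 224] still require A100/A800 or H100/H800 for the backward pass, while head dims in (224, 256] require them only when dropout is enabled. The test gating further down mirrors this, with MAX_HEADDIM_SM8x being the 192 limit for the other sm8x GPUs.

```python
def bwd_requires_sm80_or_sm90(head_size: int, is_dropout: bool) -> bool:
    """Mirror of the new C++ condition in mha_bwd / mha_varlen_bwd."""
    return head_size > 192 and (head_size <= 224 or is_dropout)

# Head dim 256 without dropout no longer needs A100/H100 (the point of this commit).
assert not bwd_requires_sm80_or_sm90(256, is_dropout=False)
assert bwd_requires_sm80_or_sm90(256, is_dropout=True)
assert bwd_requires_sm80_or_sm90(224, is_dropout=False)
assert not bwd_requires_sm80_or_sm90(192, is_dropout=False)
```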
@@ -521,7 +521,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         // if (cute::thread(32, 0)) { print(scores); }
         // Compute the exponential value.
         flash::scale_apply_exp2</*scale_max=*/false>(scores, lse, params.scale_softmax_log2);
-        if (Is_dropout) {
+        if constexpr (Is_dropout) {
             int warp_id = tidx / 32;
             int block_row_idx = m_block * (kBlockM / 16) + warp_id % AtomLayoutMS;
             // Need col to be multiples of 32, since we're doing dropout with block of 16 x 32
...
@@ -296,8 +296,12 @@ void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         if (max_smem_per_block >= 176 * 1024) { // H100
             run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
-        } else { // A100, we don't do double buffering to save smem
+        } else if (max_smem_per_block >= 144 * 1024) { // A100, we don't do double buffering to save smem
             run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_dropout>(params, stream);
+        } else { // sm86 and sm89, max smem is 99 KB. Only works without dropout. V in regs and no double buffering.
+            if constexpr (!Is_dropout) {
+                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 32, 8, 4, 1, 2, true, true, T>, false>(params, stream);
+            }
         }
     });
 }
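A hedged Python paraphrase of the dispatch above (the returned tuples are labels for illustration, not the actual Flash_bwd_kernel_traits): the head dim 256 backward configuration is picked from the per-block shared memory budget, and the new sm86/sm89 branch is only instantiated when dropout is off.

```python
def hdim256_bwd_config(max_smem_per_block: int, is_dropout: bool):
    if max_smem_per_block >= 176 * 1024:    # H100-class smem: 64x64 tiles, double buffering
        return ("kBlockM=64", "kBlockN=64", "double_buffer", is_dropout)
    elif max_smem_per_block >= 144 * 1024:  # A100-class smem: 64x64 tiles, no double buffering
        return ("kBlockM=64", "kBlockN=64", "no_double_buffer", is_dropout)
    elif not is_dropout:                    # sm86/sm89 (~99 KB): smaller kBlockN, V kept in registers
        return ("kBlockM=64", "kBlockN=32", "V_in_regs_no_double_buffer", False)
    else:
        return None  # head dim 256 with dropout is not launched on sm86/sm89
```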
@@ -231,9 +231,11 @@ struct Flash_bwd_kernel_traits : public Base {
     // TODO: generalize to other values of kBlockN
     // TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2
     // static constexpr int kPBlockN = kBlockN;
-    static_assert(kBlockN >= 64);
+    // Temporarily disabling this for hdim 256 on sm86 and sm89
+    // static_assert(kBlockN >= 64);
+    static_assert(kBlockN >= 32);
     // TD [2023-03-19]: Idk why kPBlockN = 16 and kSwizzlePdS=3 is the fastest.
-    static constexpr int kPBlockN = 64;
+    static constexpr int kPBlockN = kBlockN >= 64 ? 64 : 32;
     static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64);
     // static constexpr int kSwizzlePdS = kPBlockN == 16 ? 1 : (kPBlockN == 32 ? 2 : 3);
     static constexpr int kSwizzlePdS = 3;
...
@@ -664,7 +664,7 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, determ
     # do_o = (g.float() * out.float()).sum(-1)
     # dv_tmp = torch.einsum('bhts,bthd->bshd', attn_pt[:, :, :64], g[:, :64])
     # dv_tmp1 = torch.einsum('bhts,bthd->bshd', attn_pt[:, :, 64:], g[:, 64:])
-    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
         (dqkv,) = torch.autograd.grad(out, qkv, g)
         (dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g)
         (dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g)
@@ -687,7 +687,7 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, determ
         if not alibi:
             assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
-    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
         assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()
@@ -811,7 +811,7 @@ def test_flash_attn_varlen_qkvpacked(
     print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")
     g = torch.randn_like(out)
-    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
         (dqkv_unpad,) = torch.autograd.grad(out, qkv_unpad, g)
         dqkv = dqkv_pad_fn(dqkv_unpad)
         (dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g)
@@ -835,7 +835,7 @@ def test_flash_attn_varlen_qkvpacked(
         if not alibi:
             assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
-    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
         assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()
@@ -1036,7 +1036,7 @@ def test_flash_attn_output(
     g = torch.randn_like(out)
     do_o = (g.float() * out.float()).sum(-1)
-    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
         if kvpacked:
             (
                 dq,
@@ -1092,7 +1092,7 @@ def test_flash_attn_output(
         if not alibi:
             assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
-    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
         assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
         assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
         assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
@@ -1339,7 +1339,7 @@ def test_flash_attn_varlen_output(
     print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")
     g = torch.randn_like(out)
-    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
         if kvpacked:
             (
                 dq_unpad,
@@ -1398,7 +1398,7 @@ def test_flash_attn_varlen_output(
         if not alibi:
             assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
-    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
         assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item()
         assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item()
         assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item()
@@ -1476,7 +1476,7 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):
     g = torch.randn_like(out)
     do_o = (g.float() * out.float()).sum(-1)
-    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
         (
             dq,
             dk,
@@ -1509,7 +1509,7 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):
     # of a Pytorch implementation.
     assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5
-    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
         assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5
         assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5
         assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5
@@ -1625,7 +1625,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp
     g = torch.randn_like(out)
     do_o = (g.float() * out.float()).sum(-1)
-    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
         (
             dq_unpad,
             dk_unpad,
@@ -1661,7 +1661,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp
     # of a Pytorch implementation.
     assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5
-    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
         assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5
         assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5
         assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5
@@ -1755,7 +1755,7 @@ def test_flash_attn_splitkv(
     g = torch.randn_like(out)
     do_o = (g.float() * out.float()).sum(-1)
-    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
         (
             dq,
             dk,
@@ -1789,7 +1789,7 @@ def test_flash_attn_splitkv(
     assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5
     mult = 2 if not alibi else 8
-    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
         assert (dq - dq_ref).abs().max().item() <= mult * (dq_pt - dq_ref).abs().max().item() + 2e-4
         assert (dk - dk_ref).abs().max().item() <= mult * (dk_pt - dk_ref).abs().max().item() + 2e-4
         assert (dv - dv_ref).abs().max().item() <= mult * (dv_pt - dv_ref).abs().max().item() + 2e-4
@@ -1815,8 +1815,9 @@ def test_flash_attn_splitkv(
 # @pytest.mark.parametrize("rotary_interleaved", [False])
 @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
 # @pytest.mark.parametrize("rotary_fraction", [0.0])
-# @pytest.mark.parametrize("paged_kv_block_size", [None, 256, 512])
-@pytest.mark.parametrize("paged_kv_block_size", [256, 512])
+@pytest.mark.parametrize("paged_kv_block_size", [None, 256])
+# @pytest.mark.parametrize("paged_kv_block_size", [256, 512])
+# @pytest.mark.parametrize("paged_kv_block_size", [256])
 @pytest.mark.parametrize("has_batch_idx", [False, True])
 # @pytest.mark.parametrize("has_batch_idx", [False])
 @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256])
@@ -1900,12 +1901,13 @@ def test_flash_attn_kvcache(
             b=batch_size,
         )
         k_cache = rearrange(
-            k_cache_paged[block_table.flatten()],
+            # pytorch 1.12 doesn't have indexing with int32
+            k_cache_paged[block_table.to(dtype=torch.long).flatten()],
             "(b nblocks) block_size ... -> b (nblocks block_size) ...",
             b=batch_size,
         )[:, :seqlen_k]
         v_cache = rearrange(
-            v_cache_paged[block_table.flatten()],
+            v_cache_paged[block_table.to(dtype=torch.long).flatten()],
             "(b nblocks) block_size ... -> b (nblocks block_size) ...",
             b=batch_size,
         )[:, :seqlen_k]
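The .to(dtype=torch.long) casts above address the inline comment: per that note, PyTorch 1.12 cannot index with int32 tensors, while int64 (long) index tensors work across versions. A tiny standalone illustration with toy tensors (not the test's actual data):

```python
import torch

blocks = torch.arange(12).reshape(4, 3)                 # stand-in for a paged KV cache
block_table = torch.tensor([2, 0], dtype=torch.int32)   # int32, as in the test

# Casting the int32 index tensor to long keeps the gather working on old and new
# PyTorch alike; newer versions also accept int32 directly, per the diff comment.
gathered = blocks[block_table.to(dtype=torch.long)]
print(gathered.shape)  # torch.Size([2, 3])
```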
@@ -1972,8 +1974,12 @@ def test_flash_attn_kvcache(
         cos, sin = None, None
         q_ro, k_ro = q, k
     # k_cache[:, 64:] = -1
-    k_cache_ref = (k_cache if not has_batch_idx else k_cache[cache_batch_idx]).clone()
-    v_cache_ref = (v_cache if not has_batch_idx else v_cache[cache_batch_idx]).clone()
+    k_cache_ref = (
+        k_cache if not has_batch_idx else k_cache[cache_batch_idx.to(dtype=torch.long)]
+    ).clone()
+    v_cache_ref = (
+        v_cache if not has_batch_idx else v_cache[cache_batch_idx.to(dtype=torch.long)]
+    ).clone()
     if new_kv:
         update_mask = torch.logical_and(
             cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + seqlen_new
@@ -2044,16 +2050,20 @@ def test_flash_attn_kvcache(
     # of a Pytorch implementation.
     if new_kv:
         if paged_kv_block_size is None:
-            k_cache_select = k_cache if not has_batch_idx else k_cache[cache_batch_idx]
-            v_cache_select = v_cache if not has_batch_idx else v_cache[cache_batch_idx]
+            k_cache_select = (
+                k_cache if not has_batch_idx else k_cache[cache_batch_idx.to(dtype=torch.long)]
+            )
+            v_cache_select = (
+                v_cache if not has_batch_idx else v_cache[cache_batch_idx.to(dtype=torch.long)]
+            )
         else:
             k_cache_select = rearrange(
-                k_cache_paged[block_table.flatten()],
+                k_cache_paged[block_table.to(dtype=torch.long).flatten()],
                 "(b nblocks) block_size ... -> b (nblocks block_size) ...",
                 b=batch_size,
             )[:, :seqlen_k]
             v_cache_select = rearrange(
-                v_cache_paged[block_table.flatten()],
+                v_cache_paged[block_table.to(dtype=torch.long).flatten()],
                 "(b nblocks) block_size ... -> b (nblocks block_size) ...",
                 b=batch_size,
             )[:, :seqlen_k]
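For orientation, a hedged sketch of the paged-KV call that test_flash_attn_kvcache exercises; the shapes and keyword names follow my reading of the flash_attn_with_kvcache API in this release and should be treated as assumptions rather than documentation. The paged cache is laid out as (num_blocks, block_size, nheads, d) and addressed through an int32 block_table per batch element.

```python
import torch
from flash_attn import flash_attn_with_kvcache

batch, nheads, d = 2, 8, 128
block_size, nblocks_per_seq = 256, 4

q = torch.randn(batch, 1, nheads, d, device="cuda", dtype=torch.float16)
k_cache_paged = torch.randn(
    batch * nblocks_per_seq, block_size, nheads, d, device="cuda", dtype=torch.float16
)
v_cache_paged = torch.randn_like(k_cache_paged)
block_table = torch.arange(
    batch * nblocks_per_seq, device="cuda", dtype=torch.int32
).reshape(batch, nblocks_per_seq)
cache_seqlens = torch.full((batch,), 512, device="cuda", dtype=torch.int32)

out = flash_attn_with_kvcache(
    q, k_cache_paged, v_cache_paged,
    cache_seqlens=cache_seqlens,
    block_table=block_table,
    causal=True,
)
```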
@@ -2104,7 +2114,7 @@ def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, causal, dty
     torch.random.manual_seed(42)
     out0, lse0, _ = flash_attn_func(q, k, v, dropout_p, causal=causal, return_attn_probs=True)
     g = torch.randn_like(out0)
-    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
         (
             dq0,
             dk0,
@@ -2119,7 +2129,7 @@ def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, causal, dty
         assert torch.equal(out, out0)
         assert torch.equal(lse, lse0)
-        if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+        if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
             (
                 dq,
                 dk,
@@ -2326,7 +2336,7 @@ def test_flash_attn_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, loc
     out = flash_attn_func(q, k, v, 0.0, causal=causal, window_size=window_size, deterministic=True)
     g = torch.randn_like(out)
-    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
         dq0, dk0, dv0 = torch.autograd.grad(out, (q, k, v), g, retain_graph=True)
         for _ in range(50):
             dq, dk, dv = torch.autograd.grad(out, (q, k, v), g, retain_graph=True)
@@ -2414,7 +2424,7 @@ def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, caus
     )
     g = torch.randn_like(out)
-    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
         dq, dk, dv = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
         for _ in range(50):
             dq, dk, dv = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
...