Commit 751c762c authored by Tri Dao

Don't specialize for hdim 224 to speed up compilation

parent 1c275eb0
@@ -443,7 +443,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
 auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
 const int head_size = round_multiple(head_size_og, 8);
-const int head_size_rounded = round_multiple(head_size, 32);
+const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256;
 const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
 const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
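The same one-line change to head_size_rounded recurs in the hunks below. As a rough illustration of what it does (a standalone sketch, not code from this commit; old_rule and new_rule are names invented here), head sizes above 192 no longer round into a 224 bucket and instead all map to 256:

    #include <cstdio>

    // Round x up to the nearest multiple of m (same helper as in the diff).
    static int round_multiple(int x, int m) { return (x + m - 1) / m * m; }

    // Old rule: always round the head size up to a multiple of 32.
    static int old_rule(int head_size) { return round_multiple(head_size, 32); }

    // New rule: head sizes above 192 all share the single 256 bucket.
    static int new_rule(int head_size) { return head_size <= 192 ? round_multiple(head_size, 32) : 256; }

    int main() {
        const int sizes[] = {64, 128, 192, 200, 224, 256};
        for (int hs : sizes) {
            // e.g. hs = 200: old rule gives 224, new rule gives 256; hs = 224: old 224, new 256.
            std::printf("head_size=%d  old=%d  new=%d\n", hs, old_rule(hs), new_rule(hs));
        }
        return 0;
    }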
@@ -679,7 +679,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
 auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
 const int head_size = round_multiple(head_size_og, 8);
-const int head_size_rounded = round_multiple(head_size, 32);
+const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256;
 const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
 const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
@@ -874,17 +874,18 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
 TORCH_CHECK(batch_size > 0, "batch size must be positive");
 TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
 TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
-if (head_size > 192 && (head_size <= 224 || is_dropout)) {
-TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim 256 with dropout, or head dim 224 with/without dropout requires A100/A800 or H100/H800");
+if (head_size > 192 && is_dropout) {
+TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim > 192 with dropout requires A100/A800 or H100/H800");
 }
 TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
 auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
-const int head_size_rounded = round_multiple(head_size, 32);
+const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256;
 const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
 const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
 TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8");
+if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }
 if (window_size_left >= seqlen_k) { window_size_left = -1; }
 if (window_size_right >= seqlen_k) { window_size_right = -1; }
@@ -1114,13 +1115,14 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
 TORCH_CHECK(batch_size > 0, "batch size must be positive");
 TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
 TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
-if (head_size > 192 && (head_size <= 224 || is_dropout)) {
-TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim 256 with dropout, or head dim 224 with/without dropout requires A100/A800 or H100/H800");
+if (head_size > 192 && is_dropout) {
+TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim > 192 with dropout requires A100/A800 or H100/H800");
 }
 TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
+if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }
 auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
-const int head_size_rounded = round_multiple(head_size, 32);
+const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256;
 const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
 const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
@@ -1415,7 +1417,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
 auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
 const int head_size = round_multiple(head_size_og, 8);
-const int head_size_rounded = round_multiple(head_size, 32);
+const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256;
 const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
 const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
...
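A note on the relaxed backward checks above: with the 224 specialization gone, the only backward case still gated to A100/A800 or H100/H800 is head dim > 192 combined with dropout. A minimal sketch of that predicate, assuming it simply mirrors the updated TORCH_CHECK condition (bwd_needs_sm80_or_sm90 is a hypothetical helper for this example, not part of the API):

    #include <cstdio>

    // Hypothetical helper mirroring the updated check in the backward entry points:
    // head dim > 192 now requires sm80/sm90 (A100/A800, H100/H800) only when dropout is enabled.
    static bool bwd_needs_sm80_or_sm90(int head_size, bool is_dropout) {
        return head_size > 192 && is_dropout;
    }

    int main() {
        std::printf("%d\n", bwd_needs_sm80_or_sm90(224, true));   // 1: still restricted
        std::printf("%d\n", bwd_needs_sm80_or_sm90(224, false));  // 0: previously restricted, now allowed
        std::printf("%d\n", bwd_needs_sm80_or_sm90(256, false));  // 0: unchanged, was already allowed without dropout
        return 0;
    }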
@@ -299,14 +299,6 @@ void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream) {
 });
 }
-template<typename T>
-void run_mha_bwd_hdim224(Flash_bwd_params &params, cudaStream_t stream) {
-constexpr static int Headdim = 224;
-DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
-});
-}
 template<typename T>
 void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
 constexpr static int Headdim = 256;
...
@@ -299,33 +299,6 @@ void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
 });
 }
-template<typename T, bool Is_causal>
-void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
-constexpr static int Headdim = 224;
-int device;
-cudaGetDevice(&device);
-int max_smem_per_block;
-cudaError status_ = cudaDeviceGetAttribute(
-&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
-if (status_ != cudaSuccess) {
-C10_CUDA_CHECK(status_);
-}
-// printf("max_smem_per_block = %d\n", max_smem_per_block);
-DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) { // 112 KB
-run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
-} else {
-run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-}
-// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-// We can't do 128 x 32 with 8 warps because with headdim 224, kBlockKSmem = 32.
-// If we have N = 32, there are only 1024 elements to load at once, where each load
-// is 8 elements. This means we can only use 128 threads and not 256 threads.
-// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
-});
-}
 template<typename T, bool Is_causal>
 void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
 constexpr static int Headdim = 256;
...
@@ -15,7 +15,7 @@ DTYPE_MAP = {
 }
 SM = [80] # Sm80 kernels support up to
-HEAD_DIMENSIONS = [32, 64, 96, 128, 160, 192, 224, 256]
+HEAD_DIMENSIONS = [32, 64, 96, 128, 160, 192, 256]
 IS_CAUSAL = ["false", "true"]
 KERNEL_IMPL_TEMPLATE_FWD = """#include "flash_fwd_launch_template.h"
...
@@ -107,9 +107,6 @@
 } else if (HEADDIM <= 192) { \
 constexpr static int kHeadDim = 192; \
 return __VA_ARGS__(); \
-} else if (HEADDIM <= 224) { \
-constexpr static int kHeadDim = 224; \
-return __VA_ARGS__(); \
 } else if (HEADDIM <= 256) { \
 constexpr static int kHeadDim = 256; \
 return __VA_ARGS__(); \
...
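To make the effect of the pruned switch concrete, here is a small standalone sketch (plain functions invented for this example, not the actual HEADDIM_SWITCH macro or kernel launchers) of how a runtime head dimension now maps onto a compiled bucket: anything in (192, 256] selects the 256 instantiation, so head dim 224 runs through the 256 kernels.

    #include <cstdio>

    // Illustrative only: mirrors the pruned if/else chain in HEADDIM_SWITCH by returning
    // the compile-time head-dim bucket that a given runtime head dim would select.
    static int headdim_bucket(int headdim) {
        if (headdim <= 32) return 32;
        if (headdim <= 64) return 64;
        if (headdim <= 96) return 96;
        if (headdim <= 128) return 128;
        if (headdim <= 160) return 160;
        if (headdim <= 192) return 192;
        return 256;  // the 224 bucket is gone; (192, 256] all use 256
    }

    int main() {
        std::printf("224 -> %d\n", headdim_bucket(224));  // prints "224 -> 256"
        std::printf("192 -> %d\n", headdim_bucket(192));  // prints "192 -> 192"
        return 0;
    }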
@@ -192,8 +192,6 @@ if not SKIP_CUDA_BUILD and not IS_ROCM:
 "csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu",
 "csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu",
 "csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu",
-"csrc/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu",
-"csrc/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu",
 "csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu",
 "csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu",
 "csrc/flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu",
@@ -208,8 +206,6 @@ if not SKIP_CUDA_BUILD and not IS_ROCM:
 "csrc/flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu",
 "csrc/flash_attn/src/flash_fwd_hdim192_fp16_causal_sm80.cu",
 "csrc/flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu",
-"csrc/flash_attn/src/flash_fwd_hdim224_fp16_causal_sm80.cu",
-"csrc/flash_attn/src/flash_fwd_hdim224_bf16_causal_sm80.cu",
 "csrc/flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu",
 "csrc/flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu",
 "csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu",
@@ -224,8 +220,6 @@ if not SKIP_CUDA_BUILD and not IS_ROCM:
 "csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu",
 "csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu",
 "csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu",
-"csrc/flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu",
-"csrc/flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu",
 "csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu",
 "csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu",
 "csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu",
@@ -240,8 +234,6 @@ if not SKIP_CUDA_BUILD and not IS_ROCM:
 "csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu",
 "csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu",
 "csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu",
-"csrc/flash_attn/src/flash_fwd_split_hdim224_fp16_sm80.cu",
-"csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_sm80.cu",
 "csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu",
 "csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu",
 "csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu",
@@ -256,8 +248,6 @@ if not SKIP_CUDA_BUILD and not IS_ROCM:
 "csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu",
 "csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu",
 "csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu",
-"csrc/flash_attn/src/flash_fwd_split_hdim224_fp16_causal_sm80.cu",
-"csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_causal_sm80.cu",
 "csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu",
 "csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu",
 ],
...
@@ -682,7 +682,7 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, determ
 # do_o = (g.float() * out.float()).sum(-1)
 # dv_tmp = torch.einsum('bhts,bthd->bshd', attn_pt[:, :, :64], g[:, :64])
 # dv_tmp1 = torch.einsum('bhts,bthd->bshd', attn_pt[:, :, 64:], g[:, 64:])
-if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
+if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):
 (dqkv,) = torch.autograd.grad(out, qkv, g)
 (dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g)
 (dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g)
@@ -705,7 +705,7 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, determ
 if not alibi:
 assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
-if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
+if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):
 assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()
@@ -829,7 +829,7 @@ def test_flash_attn_varlen_qkvpacked(
 print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")
 g = torch.randn_like(out)
-if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
+if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):
 (dqkv_unpad,) = torch.autograd.grad(out, qkv_unpad, g)
 dqkv = dqkv_pad_fn(dqkv_unpad)
 (dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g)
@@ -853,7 +853,7 @@ def test_flash_attn_varlen_qkvpacked(
 if not alibi:
 assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
-if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
+if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):
 assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()
@@ -866,9 +866,9 @@ def test_flash_attn_varlen_qkvpacked(
 @pytest.mark.parametrize("deterministic", [False, True])
 # @pytest.mark.parametrize("deterministic", [True])
 @pytest.mark.parametrize("alibi", [False, True])
-# @pytest.mark.parametrize("alibi", [True])
+# @pytest.mark.parametrize("alibi", [False])
 @pytest.mark.parametrize("local", [False, True])
-# @pytest.mark.parametrize("local", [True])
+# @pytest.mark.parametrize("local", [False])
 @pytest.mark.parametrize("causal", [False, True])
 # @pytest.mark.parametrize("causal", [True])
 @pytest.mark.parametrize("d", [32, 40, 59, 64, 96, 111, 128, 160, 192, 224, 256])
@@ -894,7 +894,7 @@ def test_flash_attn_varlen_qkvpacked(
 )
 # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
 @pytest.mark.parametrize("dropout_p", [0.0, 0.17])
-# @pytest.mark.parametrize("dropout_p", [0.17])
+# @pytest.mark.parametrize("dropout_p", [0.0])
 @pytest.mark.parametrize("softcap", [0.0, 50.0])
 def test_flash_attn_output(
 seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, softcap
@@ -1066,7 +1066,7 @@ def test_flash_attn_output(
 g = torch.randn_like(out)
 do_o = (g.float() * out.float()).sum(-1)
-if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) and softcap == 0.0:
+if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):
 if kvpacked:
 (
 dq,
@@ -1122,10 +1122,10 @@ def test_flash_attn_output(
 if not alibi:
 assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
-if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)):
-assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
-assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
-assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
+if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):
+assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item()
+assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item()
+assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item()
 @pytest.mark.parametrize("kvpacked", [True, False])
@@ -1382,7 +1382,7 @@ def test_flash_attn_varlen_output(
 print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")
 g = torch.randn_like(out)
-if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)):
+if ((d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90)):
 if kvpacked:
 (
 dq_unpad,
@@ -1441,7 +1441,7 @@ def test_flash_attn_varlen_output(
 if not alibi:
 assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.04)
-if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) and softcap == 0.0:
+if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):
 assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item()
 assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item()
 assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item()
@@ -1519,7 +1519,6 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):
 g = torch.randn_like(out)
 do_o = (g.float() * out.float()).sum(-1)
-if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
 (
 dq,
 dk,
@@ -1552,7 +1551,6 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):
 # of a Pytorch implementation.
 assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5
-if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
 assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5
 assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5
 assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5
@@ -1684,7 +1682,7 @@ def test_flash_attn_varlen_causal(
 g = torch.randn_like(out)
 do_o = (g.float() * out.float()).sum(-1)
-test_backward = (d <= MAX_HEADDIM_SM8x or d > 224 or is_sm80 or is_sm90) and block_table is None
+test_backward = block_table is None
 if test_backward:
 (
 dq_unpad,
@@ -1815,7 +1813,6 @@ def test_flash_attn_splitkv(
 g = torch.randn_like(out)
 do_o = (g.float() * out.float()).sum(-1)
-if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
 (
 dq,
 dk,
@@ -1849,7 +1846,6 @@ def test_flash_attn_splitkv(
 assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5
 mult = 2 if not alibi else 8
-if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
 assert (dq - dq_ref).abs().max().item() <= mult * (dq_pt - dq_ref).abs().max().item() + 2e-4
 assert (dk - dk_ref).abs().max().item() <= mult * (dk_pt - dk_ref).abs().max().item() + 2e-4
 assert (dv - dv_ref).abs().max().item() <= mult * (dv_pt - dv_ref).abs().max().item() + 2e-4
@@ -2208,7 +2204,7 @@ def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, causal, dty
 torch.random.manual_seed(42)
 out0, lse0, _ = flash_attn_func(q, k, v, dropout_p, causal=causal, return_attn_probs=True)
 g = torch.randn_like(out0)
-if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
+if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):
 (
 dq0,
 dk0,
@@ -2223,7 +2219,7 @@ def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, causal, dty
 assert torch.equal(out, out0)
 assert torch.equal(lse, lse0)
-if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
+if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):
 (
 dq,
 dk,
@@ -2430,7 +2426,6 @@ def test_flash_attn_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, loc
 out = flash_attn_func(q, k, v, 0.0, causal=causal, window_size=window_size, deterministic=True)
 g = torch.randn_like(out)
-if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
 dq0, dk0, dv0 = torch.autograd.grad(out, (q, k, v), g, retain_graph=True)
 for _ in range(50):
 dq, dk, dv = torch.autograd.grad(out, (q, k, v), g, retain_graph=True)
@@ -2518,7 +2513,6 @@ def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, caus
 )
 g = torch.randn_like(out)
-if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
 dq0, dk0, dv0 = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
 for _ in range(50):
 dq, dk, dv = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
...