Unverified Commit d562aa63 authored by Woosuk Kwon, committed by GitHub

Sync with FA v2.6.0 to support soft capping (#13)

parent 12375706
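
For context, soft capping bounds the attention logits with a scaled tanh before the softmax; the hunks below thread a new softcap parameter (and a SOFTCAP_SWITCH compile-time toggle) through the forward kernels. A minimal PyTorch sketch of the transform, mirroring the reference implementation added to the tests further down (the in-place /=, tanh(), *= sequence there is algebraically softcap * tanh(scores / softcap)):

import torch

def softcap_scores(scores: torch.Tensor, softcap: float) -> torch.Tensor:
    # softcap == 0.0 means the cap is disabled, matching the default in the tests.
    if softcap <= 0:
        return scores
    return softcap * torch.tanh(scores / softcap)

# Large logits saturate near +/- softcap instead of growing without bound.
print(softcap_scores(torch.tensor([0.5, 10.0, 200.0]), softcap=50.0))
# approximately tensor([ 0.5000,  9.8688, 49.9665])

Builds that define FLASHATTENTION_DISABLE_SOFTCAP compile this branch out entirely via the new SOFTCAP_SWITCH macro added below.
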
......@@ -4,4 +4,4 @@
#include "flash_fwd_launch_template.h"
template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 32>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 32, false>(Flash_fwd_params &params, cudaStream_t stream);
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 32, true>(Flash_fwd_params &params, cudaStream_t stream);
......@@ -4,4 +4,4 @@
#include "flash_fwd_launch_template.h"
template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 32>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 32, false>(Flash_fwd_params &params, cudaStream_t stream);
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 64, true>(Flash_fwd_params &params, cudaStream_t stream);
......@@ -4,4 +4,4 @@
#include "flash_fwd_launch_template.h"
template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 64>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 64, false>(Flash_fwd_params &params, cudaStream_t stream);
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 64, true>(Flash_fwd_params &params, cudaStream_t stream);
......@@ -4,4 +4,4 @@
#include "flash_fwd_launch_template.h"
template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 64>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 64, false>(Flash_fwd_params &params, cudaStream_t stream);
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 96, true>(Flash_fwd_params &params, cudaStream_t stream);
......@@ -4,4 +4,4 @@
#include "flash_fwd_launch_template.h"
template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 96>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 96, false>(Flash_fwd_params &params, cudaStream_t stream);
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 96, true>(Flash_fwd_params &params, cudaStream_t stream);
......@@ -4,4 +4,4 @@
#include "flash_fwd_launch_template.h"
template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 96>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 96, false>(Flash_fwd_params &params, cudaStream_t stream);
......@@ -16,17 +16,18 @@ DTYPE_MAP = {
SM = [80] # Sm80 kernels support up to
HEAD_DIMENSIONS = [32, 64, 96, 128, 160, 192, 224, 256]
IS_CAUSAL = ["false", "true"]
KERNEL_IMPL_TEMPLATE_FWD = """#include "flash_fwd_launch_template.h"
template<>
void run_mha_fwd_<{DTYPE}, {HEAD_DIM}>(Flash_fwd_params &params, cudaStream_t stream) {{
run_mha_fwd_hdim{HEAD_DIM}<{DTYPE}>(params, stream);
void run_mha_fwd_<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_fwd_params &params, cudaStream_t stream) {{
run_mha_fwd_hdim{HEAD_DIM}<{DTYPE}, {IS_CAUSAL}>(params, stream);
}}
"""
KERNEL_IMPL_TEMPLATE_FWD_SPLIT = """#include "flash_fwd_launch_template.h"
template void run_mha_fwd_splitkv_dispatch<{DTYPE}, {HEAD_DIM}>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_splitkv_dispatch<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_fwd_params &params, cudaStream_t stream);
"""
KERNEL_IMPL_TEMPLATE_BWD = """#include "flash_bwd_launch_template.h"
......@@ -43,13 +44,14 @@ class Kernel:
sm: int
dtype: str
head_dim: int
is_causal: bool
direction: str
@property
def template(self) -> str:
if self.direction == "fwd":
return KERNEL_IMPL_TEMPLATE_FWD.format(
DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim
DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, IS_CAUSAL=self.is_causal
)
elif self.direction == "bwd":
return KERNEL_IMPL_TEMPLATE_BWD.format(
......@@ -57,18 +59,21 @@ class Kernel:
)
else:
return KERNEL_IMPL_TEMPLATE_FWD_SPLIT.format(
DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim
DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, IS_CAUSAL=self.is_causal
)
@property
def filename(self) -> str:
return f"flash_{self.direction}_hdim{self.head_dim}_{self.dtype}_sm{self.sm}.cu"
return f"flash_{self.direction}_hdim{self.head_dim}_{self.dtype}_{'causal_' if self.is_causal == 'true' else ''}sm{self.sm}.cu"
def get_all_kernels() -> List[Kernel]:
for dtype, head_dim, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, SM):
for direction in ["fwd", "bwd", "fwd_split"]:
yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, direction=direction)
for direction in ["fwd", "fwd_split"]:
for dtype, head_dim, is_causal, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, IS_CAUSAL, SM):
yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, is_causal=is_causal, direction=direction)
for direction in ["bwd"]:
for dtype, head_dim, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, SM):
yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, is_causal="false", direction=direction)
def write_kernel(kernel: Kernel, autogen_dir: Path) -> None:
......
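
As an illustration of what the updated generator now emits (a standalone sketch, not the script itself): forward and split-KV forward kernels are instantiated per (dtype, head_dim, is_causal), causal variants gain a causal_ infix in the filename, and backward kernels keep a single non-causal instantiation.

import itertools

DTYPE_MAP = {"fp16": "cutlass::half_t", "bf16": "cutlass::bfloat16_t"}
SM = [80]
HEAD_DIMENSIONS = [32, 64, 96, 128, 160, 192, 224, 256]
IS_CAUSAL = ["false", "true"]

def generated_filenames():
    # fwd and fwd_split now expand over IS_CAUSAL; bwd stays non-causal only.
    for direction in ["fwd", "fwd_split"]:
        for dtype, head_dim, is_causal, sm in itertools.product(
            DTYPE_MAP, HEAD_DIMENSIONS, IS_CAUSAL, SM
        ):
            causal = "causal_" if is_causal == "true" else ""
            yield f"flash_{direction}_hdim{head_dim}_{dtype}_{causal}sm{sm}.cu"
    for dtype, head_dim, sm in itertools.product(DTYPE_MAP, HEAD_DIMENSIONS, SM):
        yield f"flash_bwd_hdim{head_dim}_{dtype}_sm{sm}.cu"

names = list(generated_filenames())
print(len(names))  # 80 = 64 fwd/fwd_split variants + 16 bwd variants
print("flash_fwd_split_hdim128_bf16_causal_sm80.cu" in names)  # True

The filenames match the new entries added to setup.py later in this diff.
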
......@@ -4,7 +4,7 @@
#pragma once
#include "cute/algorithm/copy.hpp"
#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/layout/layout.h"
......
......@@ -13,7 +13,7 @@ using namespace cute;
template <typename Engine, typename Layout>
__forceinline__ __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_seqlen_k,
const int col_idx_offset_ = 0) {
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
// tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
static_assert(Layout::rank == 2, "Only support 2D Tensor");
const int lane_id = threadIdx.x % 32;
const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
......@@ -39,7 +39,7 @@ __forceinline__ __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor,
const int max_seqlen_k, const int row_idx_offset,
const int max_seqlen_q, const int warp_row_stride,
const int window_size_left, const int window_size_right) {
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
// tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
static_assert(Layout::rank == 2, "Only support 2D Tensor");
const int lane_id = threadIdx.x % 32;
const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
......@@ -85,7 +85,7 @@ __forceinline__ __device__ void apply_mask_causal_w_idx(
Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &idx_rowcol,
const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset)
{
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
// tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 2, "Only support 2D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(tensor) == size<0>(idx_rowcol));
......
......@@ -4,7 +4,7 @@
#pragma once
#include <cute/algorithm/copy.hpp>
#include <cute/tensor.hpp>
#include "utils.h"
......
......@@ -56,6 +56,16 @@
#define EVENK_SWITCH BOOL_SWITCH
#endif
#ifdef FLASHATTENTION_DISABLE_SOFTCAP
#define SOFTCAP_SWITCH(COND, CONST_NAME, ...) \
[&] { \
constexpr static bool CONST_NAME = false; \
return __VA_ARGS__(); \
}()
#else
#define SOFTCAP_SWITCH BOOL_SWITCH
#endif
#ifdef FLASHATTENTION_DISABLE_LOCAL
#define LOCAL_SWITCH(COND, CONST_NAME, ...) \
[&] { \
......
......@@ -14,8 +14,7 @@
#include <cuda_bf16.h>
#endif
#include <cute/algorithm/copy.hpp>
#include <cute/algorithm/gemm.hpp>
#include <cute/tensor.hpp>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
......
......@@ -151,22 +151,22 @@ if not SKIP_CUDA_BUILD:
"csrc/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu",
# "csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu",
# "csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu",
# "csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu",
# "csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu",
# "csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu",
# "csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu",
# "csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu",
# "csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu",
# "csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu",
# "csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu",
# "csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu",
# "csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu",
# "csrc/flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu",
# "csrc/flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu",
# "csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu",
# "csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim128_bf16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim192_fp16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim224_fp16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim224_bf16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu",
......@@ -183,6 +183,22 @@ if not SKIP_CUDA_BUILD:
"csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim224_fp16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu",
],
extra_compile_args={
"cxx": ["-O3", "-std=c++17"] + generator_flag,
......@@ -203,6 +219,7 @@ if not SKIP_CUDA_BUILD:
# "-DFLASHATTENTION_DISABLE_BACKWARD",
"-DFLASHATTENTION_DISABLE_DROPOUT",
# "-DFLASHATTENTION_DISABLE_ALIBI",
# "-DFLASHATTENTION_DISABLE_SOFTCAP",
"-DFLASHATTENTION_DISABLE_UNEVEN_K",
# "-DFLASHATTENTION_DISABLE_LOCAL",
]
......
......@@ -216,6 +216,7 @@ def attention_ref(
dropout_mask=None,
causal=False,
window_size=(-1, -1), # -1 means infinite window size
softcap=0.0,
upcast=True,
reorder_ops=False,
):
......@@ -233,7 +234,7 @@ def attention_ref(
window_size: (int, int), left and right window size
upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
output back to fp16/bf16.
reorder_ops: whether to change the order of operations (scaling k instead of scaling k, etc.)
reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
without changing the math. This is to estimate the numerical error from operation
reordering.
Output:
......@@ -253,6 +254,10 @@ def attention_ref(
scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
else:
scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d))
if softcap > 0:
scores /= softcap
scores = scores.tanh()
scores *= softcap
if key_padding_mask is not None:
scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
if window_size[0] >= 0 or window_size[1] >= 0:
......@@ -298,6 +303,7 @@ def attention_kvpacked_ref(
dropout_mask=None,
causal=False,
window_size=(-1, -1), # -1 means infinite window size
softcap=0.0,
upcast=True,
reorder_ops=False,
):
......@@ -313,6 +319,7 @@ def attention_kvpacked_ref(
upcast=upcast,
causal=causal,
window_size=window_size,
softcap=softcap,
reorder_ops=reorder_ops,
)
......@@ -325,6 +332,7 @@ def attention_qkvpacked_ref(
dropout_mask=None,
causal=False,
window_size=(-1, -1), # -1 means infinite window size
softcap=0.0,
upcast=True,
reorder_ops=False,
):
......@@ -340,6 +348,7 @@ def attention_qkvpacked_ref(
upcast=upcast,
causal=causal,
window_size=window_size,
softcap=softcap,
reorder_ops=reorder_ops,
)
......@@ -877,23 +886,29 @@ def test_flash_attn_varlen_qkvpacked(
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize("dropout_p", [0.17])
@pytest.mark.parametrize("softcap", [0.0, 50.0])
def test_flash_attn_output(
seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked
seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, softcap
):
if (
max(seqlen_q, seqlen_k) >= 2048
and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
):
pytest.skip() # Reference implementation OOM
if softcap > 0.0 and dropout_p > 0.0:
pytest.skip("Softcap and dropout not supported together")
device = "cuda"
# set seed
torch.random.manual_seed(0)
batch_size = 4
nheads = 9
nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
nheads = 6 if softcap == 0.0 else 4 # softcap reference impl takes more memory
nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 2)
assert nheads % nheads_k == 0
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
if softcap > 0:
# Ensure the values of qk are at least within softcap range.
q = q * softcap
if kvpacked:
kv = torch.randn(
batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True
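
A side note on the q = q * softcap scaling above (a rough sanity check, assuming the test's unit-variance inputs): with standard-normal q and k, the scaled logits q·k/sqrt(d) have magnitude around 1, so tanh(s/50) is essentially linear and the cap would never be exercised; scaling q by softcap pushes the logits into tanh's saturating range.

import torch

torch.manual_seed(0)
d, softcap = 64, 50.0
q, k = torch.randn(1024, d), torch.randn(1024, d)
scores = (q @ k.T) / d**0.5
print(scores.abs().mean())                           # ~0.8: tanh(s/50) is ~linear, cap is a no-op
print(((softcap * q) @ k.T / d**0.5).abs().mean())   # ~40: well inside tanh's nonlinear range
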
......@@ -918,6 +933,7 @@ def test_flash_attn_output(
dropout_p,
causal=causal,
window_size=window_size,
softcap=softcap,
alibi_slopes=alibi_slopes,
deterministic=deterministic,
return_attn_probs=True,
......@@ -930,6 +946,7 @@ def test_flash_attn_output(
dropout_p,
causal=causal,
window_size=window_size,
softcap=softcap,
alibi_slopes=alibi_slopes,
deterministic=deterministic,
return_attn_probs=True,
......@@ -984,6 +1001,7 @@ def test_flash_attn_output(
dropout_mask,
causal=causal,
window_size=window_size,
softcap=softcap,
)
out_pt, attn_pt = attention_kvpacked_ref(
q,
......@@ -995,6 +1013,7 @@ def test_flash_attn_output(
dropout_mask,
causal=causal,
window_size=window_size,
softcap=softcap,
upcast=False,
reorder_ops=True,
)
......@@ -1010,6 +1029,7 @@ def test_flash_attn_output(
dropout_mask,
causal=causal,
window_size=window_size,
softcap=softcap,
)
out_pt, attn_pt = attention_ref(
q,
......@@ -1022,6 +1042,7 @@ def test_flash_attn_output(
dropout_mask,
causal=causal,
window_size=window_size,
softcap=softcap,
upcast=False,
reorder_ops=True,
)
......@@ -1036,7 +1057,7 @@ def test_flash_attn_output(
g = torch.randn_like(out)
do_o = (g.float() * out.float()).sum(-1)
if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) and softcap == 0.0:
if kvpacked:
(
dq,
......@@ -1092,7 +1113,7 @@ def test_flash_attn_output(
if not alibi:
assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) and softcap == 0.0:
assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
......@@ -1133,24 +1154,31 @@ def test_flash_attn_output(
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
@pytest.mark.parametrize("softcap", [0.0, 50.0])
# @pytest.mark.parametrize('dropout_p', [0.0])
def test_flash_attn_varlen_output(
seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked
seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, softcap
):
if (
max(seqlen_q, seqlen_k) >= 2048
and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
):
pytest.skip() # Reference implementation OOM
if softcap > 0.0 and dropout_p > 0.0:
pytest.skip("Softcap and dropout not supported together")
device = "cuda"
# set seed
torch.random.manual_seed(0)
batch_size = 4
nheads = 9
nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
nheads = 6 if softcap == 0.0 else 4 # softcap reference impl takes more memory
nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 2)
assert nheads % nheads_k == 0
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
if softcap > 0:
# Ensure the values of qk are at least within softcap range.
q = q * softcap
if kvpacked:
kv = torch.randn(
batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True
......@@ -1198,6 +1226,7 @@ def test_flash_attn_varlen_output(
dropout_p,
causal=causal,
window_size=window_size,
softcap=softcap,
alibi_slopes=alibi_slopes,
deterministic=deterministic,
return_attn_probs=True,
......@@ -1229,6 +1258,7 @@ def test_flash_attn_varlen_output(
dropout_p,
causal=causal,
window_size=window_size,
softcap=softcap,
alibi_slopes=alibi_slopes,
deterministic=deterministic,
return_attn_probs=True,
......@@ -1288,6 +1318,7 @@ def test_flash_attn_varlen_output(
dropout_mask,
causal=causal,
window_size=window_size,
softcap=softcap,
)
out_pt, attn_pt = attention_kvpacked_ref(
q,
......@@ -1299,6 +1330,7 @@ def test_flash_attn_varlen_output(
dropout_mask,
causal=causal,
window_size=window_size,
softcap=softcap,
upcast=False,
reorder_ops=True,
)
......@@ -1314,6 +1346,7 @@ def test_flash_attn_varlen_output(
dropout_mask,
causal=causal,
window_size=window_size,
softcap=softcap,
)
out_pt, attn_pt = attention_ref(
q,
......@@ -1326,6 +1359,7 @@ def test_flash_attn_varlen_output(
dropout_mask,
causal=causal,
window_size=window_size,
softcap=softcap,
upcast=False,
reorder_ops=True,
)
......@@ -1339,7 +1373,7 @@ def test_flash_attn_varlen_output(
print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")
g = torch.randn_like(out)
if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) and softcap == 0.0:
if kvpacked:
(
dq_unpad,
......@@ -1396,9 +1430,9 @@ def test_flash_attn_varlen_output(
assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
# With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate
if not alibi:
assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.04)
if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) and softcap == 0.0:
assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item()
assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item()
assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item()
......@@ -1917,9 +1951,11 @@ def test_flash_attn_kvcache(
cache_seqlens = torch.randint(
0 if new_kv else 1,
# If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough
(seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1)
if new_kv
else (seqlen_k + 1),
(
(seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1)
if new_kv
else (seqlen_k + 1)
),
(batch_size,),
dtype=torch.int32,
device=device,
......@@ -2455,12 +2491,12 @@ def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, caus
g = torch.randn_like(out)
if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
dq, dk, dv = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
dq0, dk0, dv0 = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
for _ in range(50):
dq, dk, dv = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
assert torch.equal(dv, dv)
assert torch.equal(dk, dk)
assert torch.equal(dq, dq)
assert torch.equal(dv, dv0)
assert torch.equal(dk, dk0)
assert torch.equal(dq, dq0)
@pytest.mark.parametrize("dtype", [torch.float16])
......
......@@ -85,7 +85,7 @@ RUN pip install transformers==4.25.1 datasets==2.8.0 pytorch-lightning==1.8.6 tr
RUN pip install git+https://github.com/mlcommons/logging.git@2.1.0
# Install FlashAttention
RUN pip install flash-attn==2.5.7
RUN pip install flash-attn==2.6.0
# Install CUDA extensions for fused dense
RUN pip install git+https://github.com/HazyResearch/flash-attention@v2.5.7#subdirectory=csrc/fused_dense_lib
RUN pip install git+https://github.com/HazyResearch/flash-attention@v2.6.0#subdirectory=csrc/fused_dense_lib
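
For completeness, a hedged usage sketch of the feature this commit syncs in, assuming the public flash_attn_func API and the softcap keyword used by the test calls above (softcap=0.0, the default, disables the cap; the tests skip the softcap + dropout combination):

import torch
from flash_attn import flash_attn_func

q = torch.randn(2, 1024, 8, 128, device="cuda", dtype=torch.float16)
k = torch.randn(2, 1024, 8, 128, device="cuda", dtype=torch.float16)
v = torch.randn(2, 1024, 8, 128, device="cuda", dtype=torch.float16)

# Logits are soft-capped to roughly (-50, 50) before the softmax.
out = flash_attn_func(q, k, v, causal=True, softcap=50.0)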