Commit b1fbbd83 authored by Tri Dao

Implement splitKV attention

parent 7a983df7
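For context on what the new "fwd_split" kernels compute: split-KV attention partitions the key/value sequence into chunks, has each chunk produce a partial attention output plus a per-row log-sum-exp, and then combines the partial results into the exact attention output. Below is a minimal PyTorch sketch of that combine step only, assuming partial outputs of shape (batch, seqlen_q, nheads, d) and per-split LSE values of shape (batch, nheads, seqlen_q); it illustrates the math, not the CUDA implementation in this commit.

# Hypothetical sketch of the split-KV combine step (not the CUDA kernel).
import torch

def combine_splits(out_splits, lse_splits):
    # Stack partial results along a new "split" dimension.
    out = torch.stack(out_splits)            # (n_splits, b, sq, h, d)
    lse = torch.stack(lse_splits)            # (n_splits, b, h, sq)
    # The global log-sum-exp across splits is the true softmax normalizer.
    lse_total = torch.logsumexp(lse, dim=0)  # (b, h, sq)
    # Each split's output is rescaled by exp(lse_i - lse_total) and summed.
    scale = torch.exp(lse - lse_total)       # (n_splits, b, h, sq)
    scale = scale.permute(0, 1, 3, 2).unsqueeze(-1)  # (n_splits, b, sq, h, 1)
    return (scale * out).sum(dim=0), lse_total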
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 96>(Flash_fwd_params &params, cudaStream_t stream);
@@ -16,14 +16,21 @@ DTYPE_MAP = {
SM = [80] # Sm80 kernels support up to
HEAD_DIMENSIONS = [32, 64, 96, 128, 160, 192, 224, 256]
KERNEL_IMPL_TEMPLATE_FWD = """
KERNEL_IMPL_TEMPLATE_FWD = """#include "flash_fwd_launch_template.h"
template<>
void run_mha_fwd_<{DTYPE}, {HEAD_DIM}>(Flash_fwd_params &params, cudaStream_t stream) {{
run_mha_fwd_hdim{HEAD_DIM}<{DTYPE}>(params, stream);
}}
"""
KERNEL_IMPL_TEMPLATE_BWD = """
KERNEL_IMPL_TEMPLATE_FWD_SPLIT = """#include "flash_fwd_launch_template.h"
template void run_mha_fwd_splitkv_dispatch<{DTYPE}, {HEAD_DIM}>(Flash_fwd_params &params, cudaStream_t stream);
"""
KERNEL_IMPL_TEMPLATE_BWD = """#include "flash_bwd_launch_template.h"
template<>
void run_mha_bwd_<{DTYPE}, {HEAD_DIM}>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {{
run_mha_bwd_hdim{HEAD_DIM}<{DTYPE}>(params, stream, configure);
@@ -44,10 +51,14 @@ class Kernel:
return KERNEL_IMPL_TEMPLATE_FWD.format(
DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim
)
else:
elif self.direction == "bwd":
return KERNEL_IMPL_TEMPLATE_BWD.format(
DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim
)
else:
return KERNEL_IMPL_TEMPLATE_FWD_SPLIT.format(
DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim
)
@property
def filename(self) -> str:
@@ -56,7 +67,7 @@ class Kernel:
def get_all_kernels() -> List[Kernel]:
for dtype, head_dim, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, SM):
for direction in ["fwd", "bwd"]:
for direction in ["fwd", "bwd", "fwd_split"]:
yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, direction=direction)
@@ -65,8 +76,7 @@ def write_kernel(kernel: Kernel, autogen_dir: Path) -> None:
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"\n
"""
include = f'#include "flash_{kernel.direction}_launch_template.h"\n'
(autogen_dir / kernel.filename).write_text(prelude + include + kernel.template)
(autogen_dir / kernel.filename).write_text(prelude + kernel.template)
def main(output_dir: Optional[str]) -> None:
......
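The Kernel.filename property itself is collapsed out of the hunk above; purely as an illustration (a hypothetical reconstruction, not the actual code), the new "fwd_split" direction slots into the same naming scheme that produces the source files added to setup.py further down:

# Hypothetical helper, inferred from the generated file names such as
# "flash_fwd_split_hdim32_fp16_sm80.cu" listed in setup.py below.
def kernel_filename(direction: str, head_dim: int, dtype: str, sm: int) -> str:
    # direction is one of "fwd", "bwd", "fwd_split"
    return f"flash_{direction}_hdim{head_dim}_{dtype}_sm{sm}.cu"

# kernel_filename("fwd_split", 96, "fp16", 80) -> "flash_fwd_split_hdim96_fp16_sm80.cu"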
@@ -113,7 +113,8 @@ struct Flash_fwd_kernel_traits : public Base {
using SmemLayoutO = decltype(tile_to_shape(
SmemLayoutAtomO{},
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
using SmemCopyAtomO = Copy_Atom<DefaultCopy, elem_type>;
using SmemCopyAtomO = Copy_Atom<DefaultCopy, Element>;
using SmemCopyAtomOaccum = Copy_Atom<DefaultCopy, ElementAccum>;
static constexpr int kSmemQCount = size(SmemLayoutQ{});
static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2;
@@ -158,6 +159,17 @@ struct Flash_fwd_kernel_traits : public Base {
GmemLayoutAtomP{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
using GmemLayoutAtomOaccum = std::conditional_t<
kBlockKSmem == 32,
Layout<Shape <_16, _8>, // Thread layout, 8 threads per row
Stride< _8, _1>>,
Layout<Shape <_8, _16>, // Thread layout, 16 threads per row
Stride< _16, _1>>
>;
using GmemTiledCopyOaccum = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
GmemLayoutAtomOaccum{},
Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store
};
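A quick sanity check on the new Oaccum global-memory copy (a back-of-the-envelope sketch, not part of the source): both branches of GmemLayoutAtomOaccum use 128 threads, and with 4 float32 values per thread each store is 16 bytes (128 bits) wide.

# Rough arithmetic for the Oaccum copy geometry described above
# (illustration only; no CuTe layouts involved).
def oaccum_geometry(kBlockKSmem: int):
    # Thread layout: 16x8 when kBlockKSmem == 32, else 8x16 (see the diff).
    rows, threads_per_row = (16, 8) if kBlockKSmem == 32 else (8, 16)
    vals_per_thread = 4                      # float32 values per store
    threads = rows * threads_per_row         # 128 threads in both cases
    floats_per_row = threads_per_row * vals_per_thread
    bytes_per_thread = vals_per_thread * 4   # 16-byte stores
    return threads, floats_per_row, bytes_per_thread

print(oaccum_geometry(32))   # (128, 32, 16)
print(oaccum_geometry(64))   # (128, 64, 16)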
// Is_V_in_regs is an option to reduce smem usage, but will increase register pressure.
......
@@ -173,6 +173,22 @@ if not SKIP_CUDA_BUILD:
"csrc/flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim224_fp16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu",
],
extra_compile_args={
"cxx": ["-O3", "-std=c++17"] + generator_flag,
......
@@ -1367,6 +1367,109 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype):
assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [128])
@pytest.mark.parametrize("swap_sq_sk", [False, True])
# @pytest.mark.parametrize("swap_sq_sk", [False])
@pytest.mark.parametrize(
"seqlen_q,seqlen_k",
[
(3, 1024),
(1, 339),
(3, 799),
(64, 2048),
(16, 20000),
(16, 100000),
(128, 128),
(256, 256),
],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype):
if (
max(seqlen_q, seqlen_k) >= 2048
and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
):
pytest.skip() # Reference implementation OOM
if swap_sq_sk:
seqlen_q, seqlen_k = seqlen_k, seqlen_q
device = "cuda"
# set seed
torch.random.manual_seed(0)
batch_size = 1
nheads = 12
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
out, lse, _ = flash_attn_func(q, k, v, 0.0, causal=causal, return_attn_probs=True)
out_ref, attn_ref = attention_ref(q, k, v, None, None, 0.0, None, causal=causal)
out_pt, attn_pt = attention_ref(
q,
k,
v,
None,
None,
0.0,
None,
causal=causal,
upcast=False,
reorder_ops=True,
)
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
g = torch.randn_like(out)
do_o = (g.float() * out.float()).sum(-1)
if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
(
dq,
dk,
dv,
) = torch.autograd.grad(out, (q, k, v), g)
(
dq_ref,
dk_ref,
dv_ref,
) = torch.autograd.grad(out_ref, (q, k, v), g)
(
dq_pt,
dk_pt,
dv_pt,
) = torch.autograd.grad(out_pt, (q, k, v), g)
print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
# Check that FlashAttention's numerical error is at most twice the numerical error
# of a Pytorch implementation.
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5
if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 2e-4
assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 2e-4
assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 2e-4
# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("causal", [False, True])
......