Unverified commit d562aa63 authored by Woosuk Kwon, committed by GitHub

Sync with FA v2.6.0 to support soft capping (#13)

parent 12375706
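For context on the feature itself: soft capping bounds each attention score to (-softcap, softcap) by passing it through tanh, exactly as in the reference implementation added to attention_ref further down in this diff (scores /= softcap; scores = scores.tanh(); scores *= softcap). Below is a minimal illustrative sketch of that transform; the helper name soft_cap is hypothetical and only mirrors the reference code, it is not part of this commit.

import torch

def soft_cap(scores: torch.Tensor, softcap: float) -> torch.Tensor:
    # Bound scores to (-softcap, softcap); softcap == 0.0 means the feature is disabled.
    if softcap <= 0:
        return scores
    return softcap * torch.tanh(scores / softcap)

# Small scores pass through almost unchanged; large scores saturate near +/- softcap.
s = torch.tensor([0.5, 10.0, 100.0])
print(soft_cap(s, 50.0))  # approximately [0.5, 9.87, 48.2]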
@@ -4,4 +4,4 @@
 #include "flash_fwd_launch_template.h"
-template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 32>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 32, false>(Flash_fwd_params &params, cudaStream_t stream);
...
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+#include "flash_fwd_launch_template.h"
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 32, true>(Flash_fwd_params &params, cudaStream_t stream);
...
@@ -4,4 +4,4 @@
 #include "flash_fwd_launch_template.h"
-template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 32>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 32, false>(Flash_fwd_params &params, cudaStream_t stream);
...
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+#include "flash_fwd_launch_template.h"
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 64, true>(Flash_fwd_params &params, cudaStream_t stream);
...
@@ -4,4 +4,4 @@
 #include "flash_fwd_launch_template.h"
-template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 64>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 64, false>(Flash_fwd_params &params, cudaStream_t stream);
...
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+#include "flash_fwd_launch_template.h"
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 64, true>(Flash_fwd_params &params, cudaStream_t stream);
...
@@ -4,4 +4,4 @@
 #include "flash_fwd_launch_template.h"
-template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 64>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 64, false>(Flash_fwd_params &params, cudaStream_t stream);
...
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+#include "flash_fwd_launch_template.h"
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 96, true>(Flash_fwd_params &params, cudaStream_t stream);
...
@@ -4,4 +4,4 @@
 #include "flash_fwd_launch_template.h"
-template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 96>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 96, false>(Flash_fwd_params &params, cudaStream_t stream);
...
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+#include "flash_fwd_launch_template.h"
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 96, true>(Flash_fwd_params &params, cudaStream_t stream);
...
@@ -4,4 +4,4 @@
 #include "flash_fwd_launch_template.h"
-template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 96>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 96, false>(Flash_fwd_params &params, cudaStream_t stream);
...
@@ -16,17 +16,18 @@ DTYPE_MAP = {
 SM = [80]  # Sm80 kernels support up to
 HEAD_DIMENSIONS = [32, 64, 96, 128, 160, 192, 224, 256]
+IS_CAUSAL = ["false", "true"]
 KERNEL_IMPL_TEMPLATE_FWD = """#include "flash_fwd_launch_template.h"
 template<>
-void run_mha_fwd_<{DTYPE}, {HEAD_DIM}>(Flash_fwd_params &params, cudaStream_t stream) {{
-    run_mha_fwd_hdim{HEAD_DIM}<{DTYPE}>(params, stream);
+void run_mha_fwd_<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_fwd_params &params, cudaStream_t stream) {{
+    run_mha_fwd_hdim{HEAD_DIM}<{DTYPE}, {IS_CAUSAL}>(params, stream);
 }}
 """
 KERNEL_IMPL_TEMPLATE_FWD_SPLIT = """#include "flash_fwd_launch_template.h"
-template void run_mha_fwd_splitkv_dispatch<{DTYPE}, {HEAD_DIM}>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_splitkv_dispatch<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_fwd_params &params, cudaStream_t stream);
 """
 KERNEL_IMPL_TEMPLATE_BWD = """#include "flash_bwd_launch_template.h"
@@ -43,13 +44,14 @@ class Kernel:
     sm: int
     dtype: str
     head_dim: int
+    is_causal: bool
     direction: str
     @property
     def template(self) -> str:
         if self.direction == "fwd":
             return KERNEL_IMPL_TEMPLATE_FWD.format(
-                DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim
+                DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, IS_CAUSAL=self.is_causal
             )
         elif self.direction == "bwd":
             return KERNEL_IMPL_TEMPLATE_BWD.format(
@@ -57,18 +59,21 @@ class Kernel:
             )
         else:
             return KERNEL_IMPL_TEMPLATE_FWD_SPLIT.format(
-                DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim
+                DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, IS_CAUSAL=self.is_causal
             )
     @property
     def filename(self) -> str:
-        return f"flash_{self.direction}_hdim{self.head_dim}_{self.dtype}_sm{self.sm}.cu"
+        return f"flash_{self.direction}_hdim{self.head_dim}_{self.dtype}_{'causal_' if self.is_causal == 'true' else ''}sm{self.sm}.cu"
 def get_all_kernels() -> List[Kernel]:
-    for dtype, head_dim, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, SM):
-        for direction in ["fwd", "bwd", "fwd_split"]:
-            yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, direction=direction)
+    for direction in ["fwd", "fwd_split"]:
+        for dtype, head_dim, is_causal, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, IS_CAUSAL, SM):
+            yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, is_causal=is_causal, direction=direction)
+    for direction in ["bwd"]:
+        for dtype, head_dim, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, SM):
+            yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, is_causal="false", direction=direction)
 def write_kernel(kernel: Kernel, autogen_dir: Path) -> None:
...
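To make the generator change above concrete, here is a small standalone sketch (not part of the diff) of how the new IS_CAUSAL axis expands the forward-kernel filenames; the DTYPE_MAP keys are assumed to be "fp16" and "bf16" as in the upstream generator.

import itertools

DTYPE_MAP = {"fp16": "cutlass::half_t", "bf16": "cutlass::bfloat16_t"}  # assumed to match generate_kernels.py
HEAD_DIMENSIONS = [32, 64, 96, 128, 160, 192, 224, 256]
IS_CAUSAL = ["false", "true"]

def fwd_filenames():
    # Mirrors Kernel.filename for direction="fwd" and sm=80.
    for dtype, head_dim, is_causal in itertools.product(DTYPE_MAP, HEAD_DIMENSIONS, IS_CAUSAL):
        causal_tag = "causal_" if is_causal == "true" else ""
        yield f"flash_fwd_hdim{head_dim}_{dtype}_{causal_tag}sm80.cu"

names = list(fwd_filenames())
print(len(names))  # 32: 2 dtypes x 8 head dims x 2 causal settings
print(names[:2])   # ['flash_fwd_hdim32_fp16_sm80.cu', 'flash_fwd_hdim32_fp16_causal_sm80.cu']

The causal variants produced this way match the new *_causal_sm80.cu source entries added to setup.py later in this commit.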
@@ -4,7 +4,7 @@
 #pragma once
-#include "cute/algorithm/copy.hpp"
+#include "cute/tensor.hpp"
 #include "cutlass/cutlass.h"
 #include "cutlass/layout/layout.h"
...
@@ -13,7 +13,7 @@ using namespace cute;
 template <typename Engine, typename Layout>
 __forceinline__ __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_seqlen_k,
                                            const int col_idx_offset_ = 0) {
-    // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
+    // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
     static_assert(Layout::rank == 2, "Only support 2D Tensor");
     const int lane_id = threadIdx.x % 32;
     const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
@@ -39,7 +39,7 @@ __forceinline__ __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor,
                                                  const int max_seqlen_k, const int row_idx_offset,
                                                  const int max_seqlen_q, const int warp_row_stride,
                                                  const int window_size_left, const int window_size_right) {
-    // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
+    // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
     static_assert(Layout::rank == 2, "Only support 2D Tensor");
     const int lane_id = threadIdx.x % 32;
     const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
@@ -85,7 +85,7 @@ __forceinline__ __device__ void apply_mask_causal_w_idx(
     Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &idx_rowcol,
     const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset)
 {
-    // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
+    // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
     static_assert(Layout0::rank == 2, "Only support 2D Tensor");
     static_assert(Layout1::rank == 2, "Only support 2D Tensor");
     CUTE_STATIC_ASSERT_V(size<0>(tensor) == size<0>(idx_rowcol));
...
@@ -4,7 +4,7 @@
 #pragma once
-#include <cute/algorithm/copy.hpp>
+#include <cute/tensor.hpp>
 #include "utils.h"
...
@@ -56,6 +56,16 @@
 #define EVENK_SWITCH BOOL_SWITCH
 #endif
+#ifdef FLASHATTENTION_DISABLE_SOFTCAP
+#define SOFTCAP_SWITCH(COND, CONST_NAME, ...) \
+  [&] {                                       \
+    constexpr static bool CONST_NAME = false; \
+    return __VA_ARGS__();                     \
+  }()
+#else
+#define SOFTCAP_SWITCH BOOL_SWITCH
+#endif
 #ifdef FLASHATTENTION_DISABLE_LOCAL
 #define LOCAL_SWITCH(COND, CONST_NAME, ...) \
   [&] {                                     \
...
@@ -14,8 +14,7 @@
 #include <cuda_bf16.h>
 #endif
-#include <cute/algorithm/copy.hpp>
-#include <cute/algorithm/gemm.hpp>
+#include <cute/tensor.hpp>
 #include <cutlass/array.h>
 #include <cutlass/cutlass.h>
...
@@ -151,22 +151,22 @@ if not SKIP_CUDA_BUILD:
                 "csrc/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu",
                 "csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu",
                 "csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu",
-                # "csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu",
-                # "csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu",
-                # "csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu",
-                # "csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu",
-                # "csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu",
-                # "csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu",
-                # "csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu",
-                # "csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu",
-                # "csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu",
-                # "csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu",
-                # "csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu",
-                # "csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu",
-                # "csrc/flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu",
-                # "csrc/flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu",
-                # "csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu",
-                # "csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim128_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim192_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim224_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim224_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu",
                 "csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu",
                 "csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu",
                 "csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu",
@@ -183,6 +183,22 @@ if not SKIP_CUDA_BUILD:
                 "csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_sm80.cu",
                 "csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu",
                 "csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim224_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu",
             ],
             extra_compile_args={
                 "cxx": ["-O3", "-std=c++17"] + generator_flag,
@@ -203,6 +219,7 @@ if not SKIP_CUDA_BUILD:
                     # "-DFLASHATTENTION_DISABLE_BACKWARD",
                     "-DFLASHATTENTION_DISABLE_DROPOUT",
                     # "-DFLASHATTENTION_DISABLE_ALIBI",
+                    # "-DFLASHATTENTION_DISABLE_SOFTCAP",
                     "-DFLASHATTENTION_DISABLE_UNEVEN_K",
                     # "-DFLASHATTENTION_DISABLE_LOCAL",
                 ]
...
@@ -216,6 +216,7 @@ def attention_ref(
     dropout_mask=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite window size
+    softcap=0.0,
     upcast=True,
     reorder_ops=False,
 ):
@@ -233,7 +234,7 @@ def attention_ref(
         window_size: (int, int), left and right window size
         upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
             output back to fp16/bf16.
-        reorder_ops: whether to change the order of operations (scaling k instead of scaling k, etc.)
+        reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
             without changing the math. This is to estimate the numerical error from operation
             reordering.
     Output:
@@ -253,6 +254,10 @@ def attention_ref(
         scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
     else:
         scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d))
+    if softcap > 0:
+        scores /= softcap
+        scores = scores.tanh()
+        scores *= softcap
     if key_padding_mask is not None:
         scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
     if window_size[0] >= 0 or window_size[1] >= 0:
@@ -298,6 +303,7 @@ def attention_kvpacked_ref(
     dropout_mask=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite window size
+    softcap=0.0,
     upcast=True,
     reorder_ops=False,
 ):
@@ -313,6 +319,7 @@ def attention_kvpacked_ref(
         upcast=upcast,
         causal=causal,
         window_size=window_size,
+        softcap=softcap,
         reorder_ops=reorder_ops,
     )
@@ -325,6 +332,7 @@ def attention_qkvpacked_ref(
     dropout_mask=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite window size
+    softcap=0.0,
     upcast=True,
     reorder_ops=False,
 ):
@@ -340,6 +348,7 @@ def attention_qkvpacked_ref(
         upcast=upcast,
         causal=causal,
         window_size=window_size,
+        softcap=softcap,
         reorder_ops=reorder_ops,
     )
@@ -877,23 +886,29 @@ def test_flash_attn_varlen_qkvpacked(
 # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
 @pytest.mark.parametrize("dropout_p", [0.0, 0.17])
 # @pytest.mark.parametrize("dropout_p", [0.17])
+@pytest.mark.parametrize("softcap", [0.0, 50.0])
 def test_flash_attn_output(
-    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked
+    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, softcap
 ):
     if (
         max(seqlen_q, seqlen_k) >= 2048
         and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
     ):
         pytest.skip()  # Reference implementation OOM
+    if softcap > 0.0 and dropout_p > 0.0:
+        pytest.skip("Softcap and dropout not supported together")
     device = "cuda"
     # set seed
     torch.random.manual_seed(0)
     batch_size = 4
-    nheads = 9
-    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
+    nheads = 6 if softcap == 0.0 else 4  # softcap reference impl takes more memory
+    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 2)
     assert nheads % nheads_k == 0
     window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
     q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
+    if softcap > 0:
+        # Ensure the values of qk are at least within softcap range.
+        q = q * softcap
     if kvpacked:
         kv = torch.randn(
             batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True
@@ -918,6 +933,7 @@ def test_flash_attn_output(
             dropout_p,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
            alibi_slopes=alibi_slopes,
             deterministic=deterministic,
             return_attn_probs=True,
@@ -930,6 +946,7 @@ def test_flash_attn_output(
             dropout_p,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             alibi_slopes=alibi_slopes,
             deterministic=deterministic,
             return_attn_probs=True,
@@ -984,6 +1001,7 @@ def test_flash_attn_output(
             dropout_mask,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
         )
         out_pt, attn_pt = attention_kvpacked_ref(
             q,
@@ -995,6 +1013,7 @@ def test_flash_attn_output(
             dropout_mask,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             upcast=False,
             reorder_ops=True,
         )
@@ -1010,6 +1029,7 @@ def test_flash_attn_output(
             dropout_mask,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
         )
         out_pt, attn_pt = attention_ref(
             q,
@@ -1022,6 +1042,7 @@ def test_flash_attn_output(
             dropout_mask,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             upcast=False,
             reorder_ops=True,
         )
@@ -1036,7 +1057,7 @@ def test_flash_attn_output(
     g = torch.randn_like(out)
     do_o = (g.float() * out.float()).sum(-1)
-    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
+    if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) and softcap == 0.0:
         if kvpacked:
             (
                 dq,
@@ -1092,7 +1113,7 @@ def test_flash_attn_output(
     if not alibi:
         assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
-    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
+    if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) and softcap == 0.0:
         assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
         assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
         assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
@@ -1133,24 +1154,31 @@ def test_flash_attn_output(
 )
 # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
 @pytest.mark.parametrize("dropout_p", [0.0, 0.17])
+@pytest.mark.parametrize("softcap", [0.0, 50.0])
 # @pytest.mark.parametrize('dropout_p', [0.0])
 def test_flash_attn_varlen_output(
-    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked
+    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, softcap
 ):
     if (
         max(seqlen_q, seqlen_k) >= 2048
         and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
     ):
         pytest.skip()  # Reference implementation OOM
+    if softcap > 0.0 and dropout_p > 0.0:
+        pytest.skip("Softcap and dropout not supported together")
     device = "cuda"
     # set seed
     torch.random.manual_seed(0)
     batch_size = 4
-    nheads = 9
-    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
+    nheads = 6 if softcap == 0.0 else 4  # softcap reference impl takes more memory
+    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 2)
     assert nheads % nheads_k == 0
     window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
     q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
+    if softcap > 0:
+        # Ensure the values of qk are at least within softcap range.
+        q = q * softcap
     if kvpacked:
         kv = torch.randn(
             batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True
@@ -1198,6 +1226,7 @@ def test_flash_attn_varlen_output(
             dropout_p,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             alibi_slopes=alibi_slopes,
             deterministic=deterministic,
             return_attn_probs=True,
@@ -1229,6 +1258,7 @@ def test_flash_attn_varlen_output(
             dropout_p,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             alibi_slopes=alibi_slopes,
             deterministic=deterministic,
             return_attn_probs=True,
@@ -1288,6 +1318,7 @@ def test_flash_attn_varlen_output(
             dropout_mask,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
         )
         out_pt, attn_pt = attention_kvpacked_ref(
             q,
@@ -1299,6 +1330,7 @@
             dropout_mask,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             upcast=False,
             reorder_ops=True,
         )
@@ -1314,6 +1346,7 @@ def test_flash_attn_varlen_output(
             dropout_mask,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
         )
         out_pt, attn_pt = attention_ref(
             q,
@@ -1326,6 +1359,7 @@ def test_flash_attn_varlen_output(
             dropout_mask,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             upcast=False,
             reorder_ops=True,
         )
@@ -1339,7 +1373,7 @@ def test_flash_attn_varlen_output(
     print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")
     g = torch.randn_like(out)
-    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
+    if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) and softcap == 0.0:
         if kvpacked:
             (
                 dq_unpad,
@@ -1396,9 +1430,9 @@ def test_flash_attn_varlen_output(
         assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
     # With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate
     if not alibi:
-        assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
-    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
+        assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.04)
+    if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) and softcap == 0.0:
         assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item()
         assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item()
         assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item()
@@ -1917,9 +1951,11 @@ def test_flash_attn_kvcache(
     cache_seqlens = torch.randint(
         0 if new_kv else 1,
         # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough
+        (
             (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1)
             if new_kv
-            else (seqlen_k + 1),
+            else (seqlen_k + 1)
+        ),
         (batch_size,),
         dtype=torch.int32,
         device=device,
@@ -2455,12 +2491,12 @@ def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, caus
     g = torch.randn_like(out)
     if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
-        dq, dk, dv = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
+        dq0, dk0, dv0 = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
         for _ in range(50):
             dq, dk, dv = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
-            assert torch.equal(dv, dv)
-            assert torch.equal(dk, dk)
-            assert torch.equal(dq, dq)
+            assert torch.equal(dv, dv0)
+            assert torch.equal(dk, dk0)
+            assert torch.equal(dq, dq0)
 @pytest.mark.parametrize("dtype", [torch.float16])
...
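For completeness, here is a hedged usage sketch (not part of the diff) of how the new argument is exercised from the Python API, assuming flash_attn.flash_attn_func exposes softcap as in upstream FlashAttention v2.6.0 and a CUDA device is available:

import torch
from flash_attn import flash_attn_func

q = torch.randn(2, 1024, 8, 128, dtype=torch.float16, device="cuda")
k = torch.randn(2, 1024, 8, 128, dtype=torch.float16, device="cuda")
v = torch.randn(2, 1024, 8, 128, dtype=torch.float16, device="cuda")

# softcap=30.0 bounds the (scaled) attention scores to (-30, 30) via tanh before softmax;
# softcap=0.0 (the default) disables the feature. Note that the tests above skip dropout
# and the backward checks whenever softcap > 0.
out = flash_attn_func(q, k, v, causal=True, softcap=30.0)
print(out.shape)  # torch.Size([2, 1024, 8, 128])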
@@ -85,7 +85,7 @@ RUN pip install transformers==4.25.1 datasets==2.8.0 pytorch-lightning==1.8.6 tr
 RUN pip install git+https://github.com/mlcommons/logging.git@2.1.0
 # Install FlashAttention
-RUN pip install flash-attn==2.5.7
+RUN pip install flash-attn==2.6.0
 # Install CUDA extensions for fused dense
-RUN pip install git+https://github.com/HazyResearch/flash-attention@v2.5.7#subdirectory=csrc/fused_dense_lib
+RUN pip install git+https://github.com/HazyResearch/flash-attention@v2.6.0#subdirectory=csrc/fused_dense_lib