Commit 4f285b35 authored by Tri Dao

FlashAttention-2 release

parent 6d48e14a
// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
// and https://github.com/facebookresearch/xformers/blob/main/xformers/csrc/attention/cuda/fmha/gemm_kernel_utils.h#L8
#pragma once
@@ -10,31 +9,57 @@
///
/// Usage:
/// ```
/// BOOL_SWITCH(flag, BoolConst, ([&] {
/// BOOL_SWITCH(flag, BoolConst, [&] {
/// some_function<BoolConst>(...);
/// }));
/// });
/// ```
/// We need "({" and "})" to make sure that the code is a single argument being passed to the macro.
#define BOOL_SWITCH(COND, CONST_NAME, F) \
{ \
if (COND) { \
constexpr bool CONST_NAME = true; \
F(); \
} else { \
constexpr bool CONST_NAME = false; \
F(); \
} \
}
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
[&] { \
if (COND) { \
constexpr bool CONST_NAME = true; \
return __VA_ARGS__(); \
} else { \
constexpr bool CONST_NAME = false; \
return __VA_ARGS__(); \
} \
}()
// modified from BOOL_SWITCH
// because MSVC cannot handle std::conditional with constexpr variable
#define FP16_SWITCH(COND, F) \
{ \
if (COND) { \
using elem_type = __nv_bfloat16; \
F(); \
} else { \
using elem_type = __half; \
F(); \
} \
}
#define FP16_SWITCH(COND, ...) \
[&] { \
if (COND) { \
using elem_type = cutlass::half_t; \
return __VA_ARGS__(); \
} else { \
using elem_type = cutlass::bfloat16_t; \
return __VA_ARGS__(); \
} \
}()
#define FWD_HEADDIM_SWITCH(HEADDIM, ...) \
[&] { \
if (HEADDIM <= 32) { \
constexpr int kHeadDim = 32; \
return __VA_ARGS__(); \
} else if (HEADDIM <= 64) { \
constexpr int kHeadDim = 64; \
return __VA_ARGS__(); \
} else if (HEADDIM <= 96) { \
constexpr int kHeadDim = 96; \
return __VA_ARGS__(); \
} else if (HEADDIM <= 128) { \
constexpr int kHeadDim = 128; \
return __VA_ARGS__(); \
} else if (HEADDIM <= 160) { \
constexpr int kHeadDim = 160; \
return __VA_ARGS__(); \
} else if (HEADDIM <= 192) { \
constexpr int kHeadDim = 192; \
return __VA_ARGS__(); \
} else if (HEADDIM <= 224) { \
constexpr int kHeadDim = 224; \
return __VA_ARGS__(); \
} else if (HEADDIM <= 256) { \
constexpr int kHeadDim = 256; \
return __VA_ARGS__(); \
} \
}()
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
#pragma once
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>
#include <cuda_fp16.h>
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
#endif
#include <cute/algorithm/copy.hpp>
#include <cute/algorithm/gemm.hpp>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace flash {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
inline __device__ uint32_t relu2(const uint32_t x);
template<>
inline __device__ uint32_t relu2<cutlass::half_t>(const uint32_t x) {
uint32_t res;
const uint32_t zero = 0u;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile("max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero));
#else
asm volatile( \
"{\n" \
"\t .reg .f16x2 sela;\n" \
"\t set.gtu.u32.f16x2 sela, %1, %2;\n" \
"\t and.b32 %0, sela, %1;\n"
"}\n" : "=r"(res) : "r"(x), "r"(zero));
#endif
return res;
}
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
template<>
inline __device__ uint32_t relu2<cutlass::bfloat16_t>(const uint32_t x) {
uint32_t res;
const uint32_t zero = 0u;
asm volatile("max.bf16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero));
return res;
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
template<typename T>
inline __device__ uint32_t convert_relu2(const float2 x);
template<>
inline __device__ uint32_t convert_relu2<cutlass::half_t>(const float2 x) {
uint32_t res;
const uint32_t a = reinterpret_cast<const uint32_t&>(x.x);
const uint32_t b = reinterpret_cast<const uint32_t&>(x.y);
asm volatile("cvt.rn.relu.f16x2.f32 %0, %1, %2;\n" : "=r"(res) : "r"(b), "r"(a));
return res;
}
template<>
inline __device__ uint32_t convert_relu2<cutlass::bfloat16_t>(const float2 x) {
uint32_t res;
const uint32_t a = reinterpret_cast<const uint32_t&>(x.x);
const uint32_t b = reinterpret_cast<const uint32_t&>(x.y);
asm volatile("cvt.rn.relu.bf16x2.f32 %0, %1, %2;\n" : "=r"(res) : "r"(b), "r"(a));
return res;
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
inline __device__ float2 half2_unpack(uint32_t a);
template <>
inline __device__ float2 half2_unpack<__half>(uint32_t a) {
return __half22float2(reinterpret_cast<__half2 (&)>(a));
}
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
template <>
inline __device__ float2 half2_unpack<__nv_bfloat16>(uint32_t a) {
return __bfloat1622float2(reinterpret_cast<__nv_bfloat162 (&)>(a));
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
// Convert two half2's or bf162's into float, then take their dot product.
template <typename T>
inline __device__ float hfma2_to_float(const uint32_t a, const uint32_t b) {
float2 af = flash::half2_unpack<T>(a);
float2 bf = flash::half2_unpack<T>(b);
return af.x * bf.x + af.y * bf.y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Convert two vectors of 8 halfs or bf16s into floats, then take their dot product.
template<typename T>
inline __device__ float hmulsum8(const uint4 a, const uint4 b) {
float sum;
sum = flash::hfma2_to_float<T>(a.x, b.x);
sum += flash::hfma2_to_float<T>(a.y, b.y);
sum += flash::hfma2_to_float<T>(a.z, b.z);
sum += flash::hfma2_to_float<T>(a.w, b.w);
return sum;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct MaxOp {
__device__ inline T operator()(T const & x, T const & y) { return x > y ? x : y; }
};
template <>
struct MaxOp<float> {
// This is slightly faster
__device__ inline float operator()(float const &x, float const &y) { return max(x, y); }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct SumOp {
__device__ inline T operator()(T const & x, T const & y) { return x + y; }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int THREADS>
struct Allreduce {
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
template<typename T, typename Operator>
static __device__ inline T run(T x, Operator &op) {
constexpr int OFFSET = THREADS / 2;
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
return Allreduce<OFFSET>::run(x, op);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
struct Allreduce<2> {
template<typename T, typename Operator>
static __device__ inline T run(T x, Operator &op) {
x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
return x;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<bool A_in_regs=false, bool B_in_regs=false, typename Tensor0, typename Tensor1,
typename Tensor2, typename Tensor3, typename Tensor4,
typename TiledMma, typename TiledCopy0, typename TiledCopy1>
inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA,
Tensor4 const& tCsB, TiledMma tiled_mma,
TiledCopy0 smem_thr_copy_A, TiledCopy1 smem_thr_copy_B) {
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N
CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K
Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA);
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M
Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N
if (!A_in_regs) { copy(smem_thr_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); }
if (!B_in_regs) { copy(smem_thr_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); }
#pragma unroll
for (int i = 0; i < size<2>(tCrA); ++i) {
if (i < size<2>(tCrA) - 1) {
if (!A_in_regs) { copy(smem_thr_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); }
if (!B_in_regs) { copy(smem_thr_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); }
}
cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,
typename TiledMma, typename TiledCopy>
inline __device__ void gemm_A_in_regs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
TiledMma tiled_mma, TiledCopy smem_thr_copy_B) {
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N
CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K
Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N
copy(smem_thr_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{}));
#pragma unroll
for (int i = 0; i < size<2>(tCrA); ++i) {
if (i < size<2>(tCrA) - 1) {
copy(smem_thr_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1));
}
cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
template<typename Layout>
inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
static_assert(decltype(size<0>(acc_layout))::value == 4);
static_assert(decltype(rank(acc_layout))::value == 3);
auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N)
return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Convert rowcol_layout from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
// if using m16n8k16, or to ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
template<typename MMA_traits, typename Layout>
inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) {
using X = Underscore;
static_assert(decltype(size<0, 0>(rowcol_layout))::value == 2);
static_assert(decltype(size<1, 0>(rowcol_layout))::value == 2);
constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{});
static_assert(mma_shape_K == 8 || mma_shape_K == 16);
constexpr int MMA_N_divisor = mma_shape_K == 8 ? 1 : 2;
auto l = logical_divide(rowcol_layout, Shape<X, Shape<X, Int<MMA_N_divisor>>>{}); // ((2, MMA_M), (2, (2, MMA_N / 2)))
return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)),
get<0, 1>(l),
get<1, 1, 1>(l));
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename To_type, typename Engine, typename Layout>
inline __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
using From_type = typename Engine::value_type;
constexpr int numel = decltype(size(tensor))::value;
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
// HACK: this requires tensor to be "contiguous"
auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Engine, typename Layout>
inline __device__ void relu_(Tensor<Engine, Layout> &tensor) {
constexpr int numel = decltype(size(tensor))::value;
static_assert(numel % 2 == 0);
using value_t = typename Engine::value_type;
// HACK: this requires tensor to be "contiguous"
Tensor tensor_uint32 = recast<uint32_t>(tensor);
#pragma unroll
for (int i = 0; i < size(tensor_uint32); ++i) {
tensor_uint32(i) = relu2<value_t>(tensor_uint32(i));
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// On SM80 and above, we can fuse fp32 -> fp16/bf16 conversion and relu into 1 instruction
template <typename To_type, typename Engine, typename Layout>
inline __device__ auto convert_type_relu(Tensor<Engine, Layout> const &tensor) {
using From_type = typename Engine::value_type;
static_assert(std::is_same_v<To_type, cutlass::half_t> || std::is_same_v<To_type, cutlass::bfloat16_t>);
static_assert(std::is_same_v<float, From_type>);
constexpr int numel = decltype(size(tensor))::value;
static_assert(numel % 2 == 0);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// HACK: this requires tensor to be "contiguous"
Tensor tensor_float2 = recast<float2>(tensor);
Tensor out_uint32 = make_tensor<uint32_t>(tensor_float2.layout());
#pragma unroll
for (int i = 0; i < size(out_uint32); ++i) {
out_uint32(i) = convert_relu2<To_type>(tensor_float2(i));
}
Tensor out = make_tensor(make_rmem_ptr<To_type>(out_uint32.data()), tensor.layout());
#else
Tensor out = flash::convert_type<To_type>(tensor);
flash::relu_(out);
#endif
return out;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Blocks until all but N previous cp.async.commit_group operations have committed.
// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all
// (which is equivalent to commit_group then wait_group 0).
// Instead we just call cp.async.wait_group 0, which is slightly faster.
// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113
template <int N>
CUTE_HOST_DEVICE
void cp_async_wait() {
#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)
asm volatile("cp.async.wait_group %0;\n" :: "n"(N));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
typename Engine2, typename Layout2, typename Engine3, typename Layout3>
inline __device__ void copy(TiledCopy thr_copy, Tensor<Engine0, Layout0> const &S,
Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
Tensor<Engine3, Layout3> const &predicate_K, int max_MN=0) {
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
// There's no case where !Clear_OOB_K && Clear_OOB_MN
static_assert(!(Clear_OOB_MN && !Clear_OOB_K));
#pragma unroll
for (int m = 0; m < size<1>(S); ++m) {
if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
#pragma unroll
for (int k = 0; k < size<2>(S); ++k) {
if (Is_even_K || predicate_K(k)) {
copy(thr_copy, S(_, m, k), D(_, m, k));
} else if (Clear_OOB_K) {
clear(D(_, m, k));
}
}
} else if (Clear_OOB_MN) {
clear(D(_, m, _));
}
}
// TD [2023-04-13]: Strange that the code below can cause race condition.
// I think it's because the copies are under an if statement.
// if (Is_even_K) {
// #pragma unroll
// for (int m = 0; m < size<1>(S); ++m) {
// if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
// copy(thr_copy, S(_, m, _), D(_, m, _));
// } else if (Clear_OOB_MN) {
// clear(D(_, m, _));
// }
// }
// } else { // It's slightly faster in this case if iterate over K first
// #pragma unroll
// for (int k = 0; k < size<2>(S); ++k) {
// if (predicate_K(k)) {
// #pragma unroll
// for (int m = 0; m < size<1>(S); ++m) {
// if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
// copy(thr_copy, S(_, m, k), D(_, m, k));
// } else if (Clear_OOB_MN) {
// clear(D(_, m, k));
// }
// }
// } else if (Clear_OOB_K) { // There's no case where !Clear_OOB_K && Clear_OOB_MN
// if (Clear_OOB_MN || Is_even_MN) {
// clear(D(_, _, k));
// } else {
// #pragma unroll
// for (int m = 0; m < size<1>(S); ++m) {
// if (!(Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN)) {
// clear(D(_, m, k));
// }
// }
// }
// }
// }
// }
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace flash
__version__ = "1.0.9"
__version__ = "2.0.0.post1"
from flash_attn.flash_attn_interface import flash_attn_func
from flash_attn.flash_attn_interface import flash_attn_kvpacked_func
from flash_attn.flash_attn_interface import flash_attn_qkvpacked_func
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
from flash_attn.flash_attn_interface import flash_attn_varlen_kvpacked_func
from flash_attn.flash_attn_interface import flash_attn_varlen_func
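As a minimal sketch of how the renamed public API lines up with the v1 names that this release removes (the mapping follows the imports above and the function definitions in flash_attn_interface.py further down; it assumes the package is installed with its CUDA extension):
```
# Hypothetical migration sketch: the v1 "unpadded" names map onto the v2 "varlen" names.
#   flash_attn_unpadded_qkvpacked_func  ->  flash_attn_varlen_qkvpacked_func
#   flash_attn_unpadded_kvpacked_func   ->  flash_attn_varlen_kvpacked_func
#   flash_attn_unpadded_func            ->  flash_attn_varlen_func
# Fixed-length (batch, seqlen, ...) inputs now go through the plain functions instead:
from flash_attn import flash_attn_func, flash_attn_qkvpacked_func, flash_attn_varlen_func
```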
import math
import torch
import torch.nn as nn
from einops import rearrange
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
from flash_attn.bert_padding import unpad_input, pad_input
class FlashAttention(nn.Module):
"""Implement the scaled dot product attention with softmax.
Arguments
---------
softmax_scale: The temperature to use for the softmax attention.
(default: 1/sqrt(d_keys) where d_keys is computed at
runtime)
attention_dropout: The dropout rate to apply to the attention
(default: 0.0)
"""
def __init__(self, softmax_scale=None, attention_dropout=0.0):
super().__init__()
self.softmax_scale = softmax_scale
self.dropout_p = attention_dropout
def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None,
max_s=None, need_weights=False):
"""Implements the multihead softmax attention.
Arguments
---------
qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
if unpadded: (nnz, 3, h, d)
key_padding_mask: a bool tensor of shape (B, S)
"""
assert not need_weights
assert qkv.dtype in [torch.float16, torch.bfloat16]
assert qkv.is_cuda
if cu_seqlens is None:
batch_size = qkv.shape[0]
seqlen = qkv.shape[1]
if key_padding_mask is None:
qkv = rearrange(qkv, 'b s ... -> (b s) ...')
max_s = seqlen
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
device=qkv.device)
output = flash_attn_unpadded_qkvpacked_func(
qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
softmax_scale=self.softmax_scale, causal=causal
)
output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
else:
nheads = qkv.shape[-2]
x = rearrange(qkv, 'b s three h d -> b s (three h d)')
x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
output_unpad = flash_attn_unpadded_qkvpacked_func(
x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
softmax_scale=self.softmax_scale, causal=causal
)
output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
indices, batch_size, seqlen),
'b s (h d) -> b s h d', h=nheads)
else:
assert max_s is not None
output = flash_attn_unpadded_qkvpacked_func(
qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
softmax_scale=self.softmax_scale, causal=causal
)
return output, None
class FlashMHA(nn.Module):
def __init__(self, embed_dim, num_heads, bias=True, batch_first=True, attention_dropout=0.0,
causal=False, device=None, dtype=None) -> None:
assert batch_first
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.embed_dim = embed_dim
self.causal = causal
self.num_heads = num_heads
assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
self.head_dim = self.embed_dim // num_heads
assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
self.inner_attn = FlashAttention(attention_dropout=attention_dropout)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
def forward(self, x, key_padding_mask=None, need_weights=False):
"""x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim)
key_padding_mask: bool tensor of shape (batch, seqlen)
"""
qkv = self.Wqkv(x)
qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)
context, attn_weights = self.inner_attn(qkv, key_padding_mask=key_padding_mask,
need_weights=need_weights, causal=self.causal)
return self.out_proj(rearrange(context, 'b s h d -> b s (h d)')), attn_weights
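A minimal usage sketch for the FlashMHA wrapper above, assuming a CUDA device and fp16 inputs (the inner FlashAttention asserts a CUDA tensor in fp16 or bf16); the sizes below are illustrative only:
```
import torch

# Illustrative sizes: embed_dim = num_heads * head_dim, with head_dim <= 128 and divisible by 8.
mha = FlashMHA(embed_dim=512, num_heads=8, attention_dropout=0.1, causal=True,
               device='cuda', dtype=torch.float16)
x = torch.randn(2, 1024, 512, device='cuda', dtype=torch.float16)        # (batch, seqlen, hidden_dim)
key_padding_mask = torch.ones(2, 1024, dtype=torch.bool, device='cuda')  # True = token is kept
out, attn_weights = mha(x, key_padding_mask=key_padding_mask)            # out: (2, 1024, 512), attn_weights: None
```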
import torch
import torch.nn as nn
import torch.nn.functional as F
import flash_attn_cuda
def _get_block_size(device, head_dim, is_dropout):
assert head_dim % 8 == 0 and head_dim <= 128
return 256 if head_dim <= 64 else 128
def _flash_attn_forward(q, k, v, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, softmax_scale, causal, return_softmax, num_splits=0,
generator=None):
"""
num_splits: how much to parallelize over the seqlen_q dimension. num_splits=0 means
it will be set by an internal heuristic. We're exposing num_splits mostly for benchmarking.
Don't change it unless you know what you're doing.
"""
softmax_lse, rng_state, *rest = flash_attn_cuda.fwd(
q, k, v, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
softmax_scale, False, causal, return_softmax, num_splits, generator
import flash_attn_2_cuda as flash_attn_cuda
from einops import rearrange
def _get_block_size(device, head_dim, is_dropout, is_causal):
# This should match the block sizes in the CUDA kernel
assert head_dim <= 256
major, minor = torch.cuda.get_device_capability(device)
is_sm8x = major == 8 and minor > 0 # Only include sm86 and sm89, exclude sm80 (A100)
is_sm80 = major == 8 and minor == 0
is_sm90 = major == 9 and minor == 0
if head_dim <= 32:
return 128, 128
if head_dim <= 64:
return (128, 128) if not is_dropout else (128, 64)
elif head_dim <= 96:
return (64, 64) if (is_sm8x and is_causal) else (128, 64)
elif head_dim <= 128:
if is_sm8x:
return (64, 64) if (not is_dropout and is_causal) else (128, 32)
else:
return 128, (64 if not is_dropout else 32)
elif head_dim <= 160:
if is_sm8x:
return (128, 64) if not is_causal else (64, 64)
else:
return 128, 32
elif head_dim <= 192:
return (128, 64) if not is_dropout else (64, 64)
elif head_dim <= 224:
return (128, 64) if (is_sm80 or is_sm90) else (64, 64)
elif head_dim <= 256:
return (128, 64) if is_sm80 else (64, 64)
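To make the heuristic above concrete, a small sketch of what it returns for a few configurations; the expected values are read directly off the branches above, and calling it requires a visible CUDA device since torch.cuda.get_device_capability is queried:
```
# Tracing the branches of _get_block_size above (illustrative, run inside this module):
#   sm80 (A100), head_dim=128, no dropout, causal or not -> (128, 64)
#   sm86/sm89,   head_dim=128, no dropout, causal        -> (64, 64)
#   sm86/sm89,   head_dim=128, with dropout              -> (128, 32)
block_m, block_n = _get_block_size(torch.device('cuda'), 128, False, False)
```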
def _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, return_softmax):
if q.stride(-1) != 1:
q = q.contiguous()
if k.stride(-1) != 1:
k = k.contiguous()
if v.stride(-1) != 1:
v = v.contiguous()
out, q, k, v, out_padded, softmax_lse, S_dmask = flash_attn_cuda.fwd(
q, k, v, None, dropout_p, softmax_scale, causal, return_softmax, None
)
return out, q, k, v, out_padded, softmax_lse, S_dmask
def _flash_attn_varlen_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, softmax_scale, causal, return_softmax):
if q.stride(-1) != 1:
q = q.contiguous()
if k.stride(-1) != 1:
k = k.contiguous()
if v.stride(-1) != 1:
v = v.contiguous()
out, q, k, v, out_padded, softmax_lse, S_dmask = flash_attn_cuda.varlen_fwd(
q, k, v, None, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
softmax_scale, False, causal, return_softmax, None
)
# if out.isnan().any() or softmax_lse.isnan().any():
# breakpoint()
S_dmask = rest[0] if return_softmax else None
return out, softmax_lse, rng_state, S_dmask
return out, q, k, v, out_padded, softmax_lse, S_dmask
def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal,
rng_state=None, num_splits=0, generator=None):
"""
num_splits: whether to parallelize over the seqlen_k dimension (num_splits > 1) or
not (num_splits = 1). num_splits=0 means it will be set by an internal heuristic.
Any value above 1 will call the same kernel (i.e. num_splits=2 would call the same kernel
as num_splits=3), so effectively the choices are 0, 1, and 2.
This hyperparameter can be tuned for performance, but default value (heuristic) should work fine.
"""
dout = dout.contiguous() # CUDA code assumes that dout is contiguous
_, _, _, softmax_d = flash_attn_cuda.bwd(
def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv,
dropout_p, softmax_scale, causal):
dq, dk, dv, softmax_d, = flash_attn_cuda.bwd(
dout, q, k, v, out, softmax_lse, dq, dk, dv, dropout_p, softmax_scale, causal, None
)
return dq, dk, dv, softmax_d
def _flash_attn_varlen_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv,
cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, softmax_scale, causal):
dq, dk, dv, softmax_d, = flash_attn_cuda.varlen_bwd(
dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal,
num_splits, generator, rng_state)
max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal, None
)
# if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
# breakpoint()
return dq, dk, dv, softmax_d
@@ -51,191 +89,249 @@ def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens
class FlashAttnQKVPackedFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal,
return_softmax, deterministic):
def forward(ctx, qkv, dropout_p, softmax_scale, causal, return_softmax):
# Save rng_state because the backward pass will regenerate the dropout mask
rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
if softmax_scale is None:
softmax_scale = qkv.shape[-1] ** (-0.5)
out, softmax_lse, rng_state, S_dmask = _flash_attn_forward(
qkv[:, 0], qkv[:, 1], qkv[:, 2], torch.empty_like(qkv[:, 0]), cu_seqlens, cu_seqlens,
max_seqlen, max_seqlen, dropout_p, softmax_scale, causal=causal,
return_softmax=return_softmax
out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_forward(
qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], dropout_p, softmax_scale,
causal=causal, return_softmax=return_softmax and dropout_p > 0
)
ctx.save_for_backward(qkv, out, softmax_lse, cu_seqlens, rng_state)
ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
ctx.dropout_p = dropout_p
ctx.max_seqlen = max_seqlen
ctx.softmax_scale = softmax_scale
ctx.causal = causal
ctx.deterministic = deterministic
return out if not return_softmax else (out, softmax_lse, S_dmask)
@staticmethod
def backward(ctx, dout, *args):
qkv, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
dqkv = torch.empty_like(qkv)
q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
if rng_state is not None:
cur_rng_state = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(rng_state)
qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
_flash_attn_backward(
dout, qkv[:, 0], qkv[:, 1], qkv[:, 2], out, softmax_lse,
dqkv[:, 0], dqkv[:, 1], dqkv[:, 2], cu_seqlens, cu_seqlens,
ctx.max_seqlen, ctx.max_seqlen, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
rng_state=rng_state, num_splits=1 if ctx.deterministic else 0,
dout, q, k, v, out, softmax_lse, dqkv[:, :, 0], dqkv[:, :, 1], dqkv[:, :, 2],
ctx.dropout_p, ctx.softmax_scale, ctx.causal
)
dqkv = dqkv[..., :dout.shape[-1]] # We could have padded the head dimension
if rng_state is not None:
torch.cuda.set_rng_state(cur_rng_state)
return dqkv, None, None, None, None
class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_softmax):
# Save rng_state because the backward pass will regenerate the dropout mask
rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
if softmax_scale is None:
softmax_scale = qkv.shape[-1] ** (-0.5)
out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_varlen_forward(
qkv[:, 0], qkv[:, 1], qkv[:, 2], cu_seqlens, cu_seqlens, max_seqlen, max_seqlen,
dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax and dropout_p > 0
)
return dqkv, None, None, None, None, None, None, None
ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
ctx.dropout_p = dropout_p
ctx.max_seqlen = max_seqlen
ctx.softmax_scale = softmax_scale
ctx.causal = causal
return out if not return_softmax else (out, softmax_lse, S_dmask)
@staticmethod
def backward(ctx, dout, *args):
q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
if rng_state is not None:
cur_rng_state = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(rng_state)
qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
_flash_attn_varlen_backward(
dout, q, k, v, out, softmax_lse, dqkv[:, 0], dqkv[:, 1], dqkv[:, 2],
cu_seqlens, cu_seqlens, ctx.max_seqlen, ctx.max_seqlen,
ctx.dropout_p, ctx.softmax_scale, ctx.causal
)
dqkv = dqkv[..., :dout.shape[-1]] # We could have padded the head dimension
if rng_state is not None:
torch.cuda.set_rng_state(cur_rng_state)
return dqkv, None, None, None, None, None, None
class FlashAttnKVPackedFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
softmax_scale, causal, return_softmax, deterministic):
def forward(ctx, q, kv, dropout_p, softmax_scale, causal, return_softmax):
# Save rng_state because the backward pass will regenerate the dropout mask
rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
out, softmax_lse, rng_state, S_dmask = _flash_attn_forward(
q, kv[:, 0], kv[:, 1], torch.empty_like(q), cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
max_seqlen_k, dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax
out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_forward(
q, kv[:, :, 0], kv[:, :, 1], dropout_p, softmax_scale, causal=causal,
return_softmax=return_softmax and dropout_p > 0
)
ctx.save_for_backward(q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state)
ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
ctx.dropout_p = dropout_p
ctx.max_seqlen_q = max_seqlen_q
ctx.max_seqlen_k = max_seqlen_k
ctx.softmax_scale = softmax_scale
ctx.causal = causal
ctx.deterministic = deterministic
return out if not return_softmax else (out, softmax_lse, S_dmask)
@staticmethod
def backward(ctx, dout, *args):
q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
if rng_state is not None:
cur_rng_state = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(rng_state)
dq = torch.empty_like(q)
dkv = torch.empty_like(kv)
kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
_flash_attn_backward(
dout, q, kv[:, 0], kv[:, 1], out, softmax_lse,
dq, dkv[:, 0], dkv[:, 1], cu_seqlens_q, cu_seqlens_k,
ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
rng_state=rng_state, num_splits=1 if ctx.deterministic else 0,
dout, q, k, v, out, softmax_lse,
dq, dkv[:, :, 0], dkv[:, :, 1], ctx.dropout_p, ctx.softmax_scale, ctx.causal
)
return dq, dkv, None, None, None, None, None, None, None, None, None
dq = dq[..., :dout.shape[-1]] # We could have padded the head dimension
dkv = dkv[..., :dout.shape[-1]]
if rng_state is not None:
torch.cuda.set_rng_state(cur_rng_state)
return dq, dkv, None, None, None, None
class FlashAttnFunc(torch.autograd.Function):
class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
softmax_scale, causal, return_softmax, deterministic):
def forward(ctx, q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
softmax_scale, causal, return_softmax):
# Save rng_state because the backward pass will regenerate the dropout mask
rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
out, softmax_lse, rng_state, S_dmask = _flash_attn_forward(
q, k, v, torch.empty_like(q), cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax
out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_varlen_forward(
q, kv[:, 0], kv[:, 1], cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax and dropout_p > 0
)
ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state)
ctx.save_for_backward(q, k, v, out_padded, softmax_lse,
cu_seqlens_q, cu_seqlens_k, rng_state)
ctx.dropout_p = dropout_p
ctx.max_seqlen_q = max_seqlen_q
ctx.max_seqlen_k = max_seqlen_k
ctx.softmax_scale = softmax_scale
ctx.causal = causal
ctx.deterministic = deterministic
return out if not return_softmax else (out, softmax_lse, S_dmask)
@staticmethod
def backward(ctx, dout, *args):
q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
if rng_state is not None:
cur_rng_state = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(rng_state)
dq = torch.empty_like(q)
kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
_flash_attn_varlen_backward(
dout, q, k, v, out, softmax_lse, dq, dkv[:, 0], dkv[:, 1],
cu_seqlens_q, cu_seqlens_k, ctx.max_seqlen_q, ctx.max_seqlen_k,
ctx.dropout_p, ctx.softmax_scale, ctx.causal
)
dq = dq[..., :dout.shape[-1]] # We could have padded the head dimension
dkv = dkv[..., :dout.shape[-1]]
if rng_state is not None:
torch.cuda.set_rng_state(cur_rng_state)
return dq, dkv, None, None, None, None, None, None, None, None
class FlashAttnFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, dropout_p, softmax_scale, causal, return_softmax):
# Save rng_state because the backward pass will regenerate the dropout mask
rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_forward(
q, k, v, dropout_p, softmax_scale, causal=causal,
return_softmax=return_softmax and dropout_p > 0
)
ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
ctx.dropout_p = dropout_p
ctx.softmax_scale = softmax_scale
ctx.causal = causal
return out if not return_softmax else (out, softmax_lse, S_dmask)
@staticmethod
def backward(ctx, dout, *args):
q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
if rng_state is not None:
cur_rng_state = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(rng_state)
dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
_flash_attn_backward(
dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
rng_state=rng_state, num_splits=1 if ctx.deterministic else 0,
dout, q, k, v, out, softmax_lse,
dq, dk, dv, ctx.dropout_p, ctx.softmax_scale, ctx.causal
)
return dq, dk, dv, None, None, None, None, None, None, None, None, None
dq = dq[..., :dout.shape[-1]] # We could have padded the head dimension
dk = dk[..., :dout.shape[-1]]
dv = dv[..., :dout.shape[-1]]
if rng_state is not None:
torch.cuda.set_rng_state(cur_rng_state)
return dq, dk, dv, None, None, None, None, None, None, None, None
class FlashAttnQKVPackedSplitFunc(torch.autograd.Function):
class FlashAttnVarlenFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, qkv, cu_seqlens, max_seqlen0, max_seqlen1, batch_size0, dropout_p,
softmax_scale, causal, return_softmax, deterministic):
def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
softmax_scale, causal, return_softmax):
# Save rng_state because the backward pass will regenerate the dropout mask
if dropout_p > 0:
rng_state0 = torch.cuda.get_rng_state()
generator1 = torch.Generator(device='cuda')
rng_state1 = generator1.get_state()
else:
rng_state0, generator1, rng_state1 = None, None, None
rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
if softmax_scale is None:
softmax_scale = qkv.shape[-1] ** (-0.5)
out = torch.empty_like(qkv[:, 0])
_, softmax_lse0, S_dmask0 = _flash_attn_forward(
qkv[:, 0], qkv[:, 1], qkv[:, 2], out, cu_seqlens[:batch_size0 + 1],
cu_seqlens[:batch_size0 + 1], max_seqlen0, max_seqlen0, dropout_p, softmax_scale,
causal=causal, return_softmax=return_softmax
softmax_scale = q.shape[-1] ** (-0.5)
out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_varlen_forward(
q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax and dropout_p > 0
)
s = torch.cuda.Stream()
with torch.cuda.stream(s):
_, softmax_lse1, S_dmask1 = _flash_attn_forward(
qkv[:, 0], qkv[:, 1], qkv[:, 2], out, cu_seqlens[batch_size0:],
cu_seqlens[batch_size0:], max_seqlen1, max_seqlen1, dropout_p, softmax_scale,
causal=causal, return_softmax=return_softmax, generator=generator1
)
torch.cuda.current_stream().wait_stream(s)
ctx.save_for_backward(qkv, out, softmax_lse0, softmax_lse1, cu_seqlens,
rng_state0, rng_state1)
ctx.save_for_backward(q, k, v, out_padded, softmax_lse,
cu_seqlens_q, cu_seqlens_k, rng_state)
ctx.dropout_p = dropout_p
ctx.max_seqlen0 = max_seqlen0
ctx.max_seqlen1 = max_seqlen1
ctx.batch_size0 = batch_size0
ctx.max_seqlen_q = max_seqlen_q
ctx.max_seqlen_k = max_seqlen_k
ctx.softmax_scale = softmax_scale
ctx.causal = causal
ctx.deterministic = deterministic
if not return_softmax:
return out
else:
max_seqlen_q = max(softmax_lse0.shape[2], softmax_lse1.shape[2])
max_seqlen_k = max(S_dmask0.shape[3], S_dmask1.shape[3])
softmax_lse = torch.cat([F.pad(softmax_lse0, (0, max_seqlen_q - softmax_lse0.shape[2])),
F.pad(softmax_lse1, (0, max_seqlen_q - softmax_lse1.shape[2]))],
dim=0)
return out, softmax_lse, S_dmask0, S_dmask1
return out if not return_softmax else (out, softmax_lse, S_dmask)
@staticmethod
def backward(ctx, dout, *args):
qkv, out, softmax_lse0, softmax_lse1, cu_seqlens, rng_state0, rng_state1 = ctx.saved_tensors
batch_size0 = ctx.batch_size0
if rng_state0 is not None:
q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
if rng_state is not None:
cur_rng_state = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(rng_state0)
if rng_state1 is not None:
generator1 = torch.Generator(device='cuda')
generator1.set_state(rng_state1)
else:
generator1 = None
dqkv = torch.empty_like(qkv)
_flash_attn_backward(
dout, qkv[:, 0], qkv[:, 1], qkv[:, 2], out, softmax_lse0,
dqkv[:, 0], dqkv[:, 1], dqkv[:, 2], cu_seqlens[:batch_size0 + 1],
cu_seqlens[:batch_size0 + 1], ctx.max_seqlen0, ctx.max_seqlen0, ctx.dropout_p,
ctx.softmax_scale, ctx.causal, num_splits=1 if ctx.deterministic else 0,
torch.cuda.set_rng_state(rng_state)
dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
_flash_attn_varlen_backward(
dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal
)
s = torch.cuda.Stream()
with torch.cuda.stream(s):
_flash_attn_backward(
dout, qkv[:, 0], qkv[:, 1], qkv[:, 2], out, softmax_lse1,
dqkv[:, 0], dqkv[:, 1], dqkv[:, 2], cu_seqlens[batch_size0:],
cu_seqlens[batch_size0:], ctx.max_seqlen1, ctx.max_seqlen1, ctx.dropout_p,
ctx.softmax_scale, ctx.causal, generator=generator1,
num_splits=1 if ctx.deterministic else 0,
)
torch.cuda.current_stream().wait_stream(s)
if rng_state0 is not None:
dq = dq[..., :dout.shape[-1]] # We could have padded the head dimension
dk = dk[..., :dout.shape[-1]]
dv = dv[..., :dout.shape[-1]]
if rng_state is not None:
torch.cuda.set_rng_state(cur_rng_state)
return dqkv, None, None, None, None, None, None, None, None, None
return dq, dk, dv, None, None, None, None, None, None, None, None
def flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale=None,
causal=False, return_attn_probs=False, deterministic=False):
def flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False,
return_attn_probs=False):
"""dropout_p should be set to 0.0 during evaluation
If Q, K, V are already stacked into 1 tensor, this function will be faster than
calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
of the gradients of Q, K, V.
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.
Arguments:
qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into qkv.
max_seqlen: int. Maximum sequence length in the batch.
qkv: (batch_size, seqlen, 3, nheads, headdim)
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
@@ -243,9 +339,8 @@ def flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p, s
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
deterministic: bool. Whether or not to ensure deterministic execution.
Return:
out: (total, nheads, headdim).
out: (batch_size, seqlen, nheads, headdim).
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
@@ -253,23 +348,87 @@ def flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p, s
The output of softmax (possibly with different scaling). It also encodes the dropout
pattern (negative means that location was dropped, nonnegative means it was kept).
"""
return FlashAttnQKVPackedFunc.apply(qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale,
causal, return_attn_probs, deterministic)
return FlashAttnQKVPackedFunc.apply(qkv, dropout_p, softmax_scale, causal, return_attn_probs)
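A minimal usage sketch for the packed-QKV path, following the shapes in the docstring above; it assumes a CUDA device and fp16 (or bf16) tensors:
```
import torch
from flash_attn import flash_attn_qkvpacked_func

batch_size, seqlen, nheads, headdim = 2, 1024, 16, 64
# (batch_size, seqlen, 3, nheads, headdim), as documented above
qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim,
                  device='cuda', dtype=torch.float16, requires_grad=True)
out = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, causal=True)   # (2, 1024, 16, 64)
out.sum().backward()                                               # dqkv lands in qkv.grad
```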
def flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, softmax_scale=None, causal=False,
return_attn_probs=False, deterministic=False):
def flash_attn_kvpacked_func(q, kv, dropout_p=0.0, softmax_scale=None, causal=False,
return_attn_probs=False):
"""dropout_p should be set to 0.0 during evaluation
If K, V are already stacked into 1 tensor, this function will be faster than
calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
of the gradients of K, V.
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.
Arguments:
q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
kv: (total_k, 2, nheads, headdim), where total_k = total number of key tokens in the batch.
cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into q.
cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into kv.
max_seqlen_q: int. Maximum query sequence length in the batch.
max_seqlen_k: int. Maximum key sequence length in the batch.
q: (batch_size, seqlen, nheads, headdim)
kv: (batch_size, seqlen, 2, nheads_k, headdim)
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
Return:
out: (batch_size, seqlen, nheads, headdim).
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
The output of softmax (possibly with different scaling). It also encodes the dropout
pattern (negative means that location was dropped, nonnegative means it was kept).
"""
return FlashAttnKVPackedFunc.apply(q, kv, dropout_p, softmax_scale, causal, return_attn_probs)
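A usage sketch for the KV-packed path with grouped-query attention, mirroring the 6-head / 2-head example in the docstring above; it assumes a CUDA device and fp16 tensors:
```
import torch
from flash_attn import flash_attn_kvpacked_func

batch_size, seqlen, headdim = 2, 512, 64
nheads, nheads_k = 6, 2   # GQA: 6 query heads share 2 KV heads
q = torch.randn(batch_size, seqlen, nheads, headdim, device='cuda', dtype=torch.float16)
kv = torch.randn(batch_size, seqlen, 2, nheads_k, headdim, device='cuda', dtype=torch.float16)
out = flash_attn_kvpacked_func(q, kv, dropout_p=0.0, causal=False)   # (2, 512, 6, 64)
```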
def flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False,
return_attn_probs=False):
"""dropout_p should be set to 0.0 during evaluation
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.
Arguments:
q: (batch_size, seqlen, nheads, headdim)
k: (batch_size, seqlen, nheads_k, headdim)
v: (batch_size, seqlen, nheads_k, headdim)
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
Return:
out: (batch_size, seqlen, nheads, headdim).
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
The output of softmax (possibly with different scaling). It also encodes the dropout
pattern (negative means that location was dropped, nonnegative means it was kept).
"""
return FlashAttnFunc.apply(q, k, v, dropout_p, softmax_scale, causal, return_attn_probs)
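A minimal sketch of the unpacked fixed-length API, following the shapes in the docstring above; it assumes a CUDA device with bf16 support (Ampere or newer):
```
import torch
from flash_attn import flash_attn_func

batch_size, seqlen, nheads, headdim = 4, 2048, 8, 128
q = torch.randn(batch_size, seqlen, nheads, headdim, device='cuda', dtype=torch.bfloat16)
k = torch.randn(batch_size, seqlen, nheads, headdim, device='cuda', dtype=torch.bfloat16)
v = torch.randn(batch_size, seqlen, nheads, headdim, device='cuda', dtype=torch.bfloat16)
# softmax_scale defaults to 1/sqrt(headdim); causal=True applies the auto-regressive mask
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)   # (4, 2048, 8, 128)
```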
def flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p=0.0, softmax_scale=None,
causal=False, return_attn_probs=False):
"""dropout_p should be set to 0.0 during evaluation
If Q, K, V are already stacked into 1 tensor, this function will be faster than
calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation
of the gradients of Q, K, V.
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.
Arguments:
qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into qkv.
max_seqlen: int. Maximum sequence length in the batch.
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
@@ -277,9 +436,8 @@ def flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seq
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
deterministic: bool. Whether or not to ensure deterministic execution.
Return:
out: (total_q, nheads, headdim).
out: (total, nheads, headdim).
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
@@ -287,19 +445,26 @@ def flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seq
The output of softmax (possibly with different scaling). It also encodes the dropout
pattern (negative means that location was dropped, nonnegative means it was kept).
"""
return FlashAttnKVPackedFunc.apply(q, kv, cu_seqlens_q, cu_seqlens_k,
max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal,
return_attn_probs, deterministic)
return FlashAttnVarlenQKVPackedFunc.apply(
qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_attn_probs
)
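A sketch of the variable-length packed-QKV path, showing one way to build cu_seqlens as zero-prefixed cumulative sequence lengths; it assumes a CUDA device and fp16 tensors:
```
import torch
import torch.nn.functional as F
from flash_attn import flash_attn_varlen_qkvpacked_func

nheads, headdim = 8, 64
seqlens = torch.tensor([100, 57, 300], dtype=torch.int32, device='cuda')
# cu_seqlens: (batch_size + 1,) int32 cumulative lengths -> tensor([0, 100, 157, 457])
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
total, max_seqlen = int(seqlens.sum()), int(seqlens.max())
qkv = torch.randn(total, 3, nheads, headdim, device='cuda', dtype=torch.float16)
out = flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens, max_seqlen,
                                       dropout_p=0.0, causal=True)   # (457, 8, 64)
```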
def flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, softmax_scale=None, causal=False, return_attn_probs=False,
deterministic=False):
def flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p=0.0, softmax_scale=None, causal=False,
return_attn_probs=False):
"""dropout_p should be set to 0.0 during evaluation
If K, V are already stacked into 1 tensor, this function will be faster than
calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
of the gradients of K, V.
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.
Arguments:
q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
k: (total_k, nheads, headdim), where total_k = total number of key tokens in the batch.
v: (total_k, nheads, headdim), where total_k = total number of key tokens in the batch.
kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch.
cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into q.
cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
@@ -313,9 +478,8 @@ def flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
deterministic: bool. Whether or not to ensure deterministic execution.
Return:
out: (total_q, nheads, headdim).
out: (total, nheads, headdim).
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
@@ -323,27 +487,31 @@ def flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
The output of softmax (possibly with different scaling). It also encodes the dropout
pattern (negative means that location was dropped, nonnegative means it was kept).
"""
return FlashAttnFunc.apply(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, softmax_scale, causal, return_attn_probs, deterministic)
return FlashAttnVarlenKVPackedFunc.apply(
q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, softmax_scale, causal, return_attn_probs
)
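A sketch of the variable-length KV-packed path for cross-attention, where queries and keys have different lengths per sequence; it assumes a CUDA device and fp16 tensors:
```
import torch
import torch.nn.functional as F
from flash_attn import flash_attn_varlen_kvpacked_func

nheads, nheads_k, headdim = 8, 2, 64
seqlens_q = torch.tensor([10, 25], dtype=torch.int32, device='cuda')
seqlens_k = torch.tensor([512, 300], dtype=torch.int32, device='cuda')
cu_seqlens_q = F.pad(torch.cumsum(seqlens_q, dim=0, dtype=torch.int32), (1, 0))
cu_seqlens_k = F.pad(torch.cumsum(seqlens_k, dim=0, dtype=torch.int32), (1, 0))
q = torch.randn(int(seqlens_q.sum()), nheads, headdim, device='cuda', dtype=torch.float16)
kv = torch.randn(int(seqlens_k.sum()), 2, nheads_k, headdim, device='cuda', dtype=torch.float16)
out = flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k,
                                      int(seqlens_q.max()), int(seqlens_k.max()),
                                      dropout_p=0.0, causal=False)   # (35, 8, 64)
```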
def flash_attn_unpadded_qkvpacked_split_func(
qkv, cu_seqlens, max_seqlen0, max_seqlen1, batch_size0, dropout_p, softmax_scale=None,
causal=False, return_attn_probs=False, deterministic=False):
"""
Split attention into 2 kernels running on 2 separate streams for performance reason:
e.g., if the batch has some sequences of length <= 128 and some > 128, it might be faster to
have one kernel dealing with seqlen <= 128 and one kernel for seqlen > 128.
dropout_p should be set to 0.0 during evaluation.
def flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p=0.0, softmax_scale=None, causal=False,
return_attn_probs=False):
"""dropout_p should be set to 0.0 during evaluation
Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
than Q. Note that the number of heads in Q must be divisible by the number of heads in K, V.
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.
Arguments:
qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into qkv.
max_seqlen0: int. Maximum sequence length in 1st part of the batch.
max_seqlen1: int. Maximum sequence length in 2nd part of the batch.
batch_size0: int. Number of sequences in the 1st part of the batch.
q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into q.
cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into kv.
max_seqlen_q: int. Maximum query sequence length in the batch.
max_seqlen_k: int. Maximum key sequence length in the batch.
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
@@ -351,7 +519,6 @@ def flash_attn_unpadded_qkvpacked_split_func(
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
deterministic: bool. Whether or not to ensure deterministic execution.
Return:
out: (total, nheads, headdim).
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
@@ -361,15 +528,7 @@ def flash_attn_unpadded_qkvpacked_split_func(
The output of softmax (possibly with different scaling). It also encodes the dropout
pattern (negative means that location was dropped, nonnegative means it was kept).
"""
return FlashAttnQKVPackedSplitFunc.apply(qkv, cu_seqlens, max_seqlen0, max_seqlen1, batch_size0,
dropout_p, softmax_scale, causal, return_attn_probs,
deterministic)
def flash_attn_func(qkv, cu_seqlens, dropout_p, max_s, softmax_scale=None, causal=False,
return_attn_probs=False):
"""For backward-compatibility only, will remove soon.
dropout_p should be set to 0.0 during evaluation
"""
return flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, max_s, dropout_p, softmax_scale,
causal, return_attn_probs)
return FlashAttnVarlenFunc.apply(
q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, softmax_scale, causal, return_attn_probs
)
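A sketch of driving flash_attn_varlen_func from a padded batch by unpadding with flash_attn.bert_padding.unpad_input and re-padding the output, following the same pattern as the deleted FlashAttention module earlier in this diff; it assumes a CUDA device, fp16 tensors, and a single shared key_padding_mask (self-attention):
```
import torch
from flash_attn import flash_attn_varlen_func
from flash_attn.bert_padding import unpad_input, pad_input

batch_size, seqlen, nheads, headdim = 2, 128, 8, 64
key_padding_mask = torch.ones(batch_size, seqlen, dtype=torch.bool, device='cuda')
key_padding_mask[1, 100:] = False   # the second sequence only has 100 real tokens

q = torch.randn(batch_size, seqlen, nheads, headdim, device='cuda', dtype=torch.float16)
k = torch.randn(batch_size, seqlen, nheads, headdim, device='cuda', dtype=torch.float16)
v = torch.randn(batch_size, seqlen, nheads, headdim, device='cuda', dtype=torch.float16)

# unpad_input packs the kept tokens and returns the gather indices, cu_seqlens and max_seqlen,
# as in the FlashAttention module above; the same mask is reused for q, k and v here.
q_unpad, indices, cu_seqlens, max_seqlen = unpad_input(q, key_padding_mask)
k_unpad, _, _, _ = unpad_input(k, key_padding_mask)
v_unpad, _, _, _ = unpad_input(v, key_padding_mask)

out_unpad = flash_attn_varlen_func(q_unpad, k_unpad, v_unpad, cu_seqlens, cu_seqlens,
                                   max_seqlen, max_seqlen, dropout_p=0.0, causal=True)
out = pad_input(out_unpad, indices, batch_size, seqlen)   # back to (2, 128, 8, 64)
```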
@@ -8,7 +8,7 @@ import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torchvision.ops import StochasticDepth
# from torchvision.ops import StochasticDepth
from flash_attn.modules.mha import MHA
from flash_attn.modules.mlp import Mlp
@@ -70,12 +70,12 @@ class Block(nn.Module):
mlp_cls = partial(Mlp, hidden_features=4 * dim)
self.mixer = mixer_cls(dim)
self.dropout1 = dropout_cls(resid_dropout1)
self.drop_path1 = StochasticDepth(drop_path1, mode='row')
# self.drop_path1 = StochasticDepth(drop_path1, mode='row')
self.norm1 = norm_cls(dim)
self.mlp = mlp_cls(dim)
if not isinstance(self.mlp, nn.Identity):
self.dropout2 = dropout_cls(resid_dropout2)
self.drop_path2 = StochasticDepth(drop_path2, mode='row')
# self.drop_path2 = StochasticDepth(drop_path2, mode='row')
self.norm2 = norm_cls(dim)
if self.fused_dropout_add_ln:
@@ -129,13 +129,14 @@ class Block(nn.Module):
if self.residual_in_fp32:
residual = residual.to(torch.float32)
else:
if self.drop_path1.p == 0 or not self.training:
rowscale1 = None
else:
rowscale1 = self.drop_path1(torch.ones(
hidden_states.shape[:-1], device=hidden_states.device,
dtype=hidden_states.dtype)
)
rowscale1 = None
# if self.drop_path1.p == 0 or not self.training:
# rowscale1 = None
# else:
# rowscale1 = self.drop_path1(torch.ones(
# hidden_states.shape[:-1], device=hidden_states.device,
# dtype=hidden_states.dtype)
# )
hidden_states, residual = fused_add_norm_fn(
hidden_states, residual, self.norm1.weight, self.norm1.bias,
self.dropout1.p if self.training else 0.0, self.norm1.eps,
......@@ -156,13 +157,14 @@ class Block(nn.Module):
if self.residual_in_fp32:
residual = residual.to(torch.float32)
else:
if self.drop_path2.p == 0 or not self.training:
rowscale2 = None
else:
rowscale2 = self.drop_path2(torch.ones(
hidden_states.shape[:-1], device=hidden_states.device,
dtype=hidden_states.dtype)
)
# if self.drop_path2.p == 0 or not self.training:
# rowscale2 = None
# else:
# rowscale2 = self.drop_path2(torch.ones(
# hidden_states.shape[:-1], device=hidden_states.device,
# dtype=hidden_states.dtype)
# )
rowscale2 = None
hidden_states, residual = fused_add_norm_fn(
hidden_states, residual, self.norm2.weight, self.norm2.bias,
self.dropout2.p if self.training else 0.0, self.norm2.eps,
......
......@@ -10,14 +10,10 @@ import torch.nn.functional as F
from einops import rearrange
try:
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func
except ImportError:
flash_attn_unpadded_qkvpacked_func, flash_attn_unpadded_kvpacked_func = None, None
try:
from flash_attn.ops.flash_attn_triton import flash_attn_qkvpacked_func, flash_attn_kvpacked_func
from flash_attn import flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func
from flash_attn import flash_attn_qkvpacked_func, flash_attn_kvpacked_func
except ImportError:
flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None
flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None
try:
......@@ -46,17 +42,13 @@ class FlashSelfAttention(nn.Module):
attention_dropout: The dropout rate to apply to the attention
(default: 0.0)
"""
def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
triton=False):
def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
super().__init__()
if attention_dropout != 0.0 or not triton:
assert flash_attn_unpadded_qkvpacked_func is not None, 'FlashAttention is not installed'
if attention_dropout == 0.0 and triton:
assert flash_attn_qkvpacked_func is not None, 'FlashAttention Triton is not installed'
assert flash_attn_varlen_qkvpacked_func is not None, 'FlashAttention is not installed'
assert flash_attn_qkvpacked_func is not None, 'FlashAttention is not installed'
self.causal = causal
self.softmax_scale = softmax_scale
self.drop = nn.Dropout(attention_dropout)
self.triton = triton
def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
"""Implements the multihead softmax attention.
......@@ -83,26 +75,13 @@ class FlashSelfAttention(nn.Module):
assert cu_seqlens.dtype == torch.int32
assert max_seqlen is not None
assert isinstance(max_seqlen, int)
return flash_attn_unpadded_qkvpacked_func(
return flash_attn_varlen_qkvpacked_func(
qkv, cu_seqlens, max_seqlen, self.drop.p if self.training else 0.0,
softmax_scale=self.softmax_scale, causal=causal
)
else:
batch_size, seqlen = qkv.shape[0], qkv.shape[1]
# Triton version doesn't support dropout
if self.triton and (self.drop.p == 0 or not self.training):
output = flash_attn_qkvpacked_func(qkv, None, causal, self.softmax_scale)
else:
qkv = rearrange(qkv, 'b s ... -> (b s) ...')
max_seqlen = seqlen
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
device=qkv.device)
output = flash_attn_unpadded_qkvpacked_func(
qkv, cu_seqlens, max_seqlen, self.drop.p if self.training else 0.0,
softmax_scale=self.softmax_scale, causal=causal
)
output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
return output
return flash_attn_qkvpacked_func(qkv, self.drop.p if self.training else 0.0,
softmax_scale=self.softmax_scale, causal=causal)
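A short sketch of the padded (non-varlen) path above; the batch, sequence length, and head sizes are assumptions for illustration:

# qkv: (B, S, 3, H, D); dropout is only applied when the module is in training mode.
attn = FlashSelfAttention(causal=True, attention_dropout=0.1)
qkv = torch.randn(4, 512, 3, 16, 64, device='cuda', dtype=torch.float16)
out = attn(qkv)   # (B, S, H, D)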
class FlashCrossAttention(nn.Module):
......@@ -115,17 +94,13 @@ class FlashCrossAttention(nn.Module):
attention_dropout: The dropout rate to apply to the attention
(default: 0.0)
"""
def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
triton=False):
def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
super().__init__()
if attention_dropout != 0.0 or not triton:
assert flash_attn_unpadded_kvpacked_func is not None, 'FlashAttention is not installed'
if attention_dropout == 0.0 and triton:
assert flash_attn_kvpacked_func is not None, 'FlashAttention Triton is not installed'
assert flash_attn_varlen_kvpacked_func is not None, 'FlashAttention is not installed'
assert flash_attn_kvpacked_func is not None, 'FlashAttention is not installed'
self.causal = causal
self.softmax_scale = softmax_scale
self.drop = nn.Dropout(attention_dropout)
self.triton = triton
def forward(self, q, kv, causal=None, cu_seqlens=None, max_seqlen=None,
cu_seqlens_k=None, max_seqlen_k=None):
......@@ -133,7 +108,7 @@ class FlashCrossAttention(nn.Module):
Arguments
---------
q: The tensor containing the query. (B, Sq, H, D)
kv: The tensor containing the key and value. (B, Sk, 2, H, D)
kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
causal: if passed, will override self.causal
cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into q.
......@@ -154,7 +129,7 @@ class FlashCrossAttention(nn.Module):
assert cu_seqlens_k.dtype == torch.int32
assert max_seqlen_k is not None
assert isinstance(max_seqlen_k, int)
return flash_attn_unpadded_kvpacked_func(
return flash_attn_varlen_kvpacked_func(
q, kv, cu_seqlens, cu_seqlens_k, max_seqlen, max_seqlen_k,
self.drop.p if self.training else 0.0,
softmax_scale=self.softmax_scale, causal=causal
......@@ -162,23 +137,9 @@ class FlashCrossAttention(nn.Module):
else:
batch_size, seqlen_q = q.shape[0], q.shape[1]
seqlen_k = kv.shape[1]
assert kv.shape[0] == batch_size and kv.shape[3] == q.shape[2] and kv.shape[4] == q.shape[3]
if self.triton and (self.drop.p == 0.0 or not self.training): # Triton version doesn't support dropout
output = flash_attn_kvpacked_func(q, kv, None, causal, self.softmax_scale)
else:
q = rearrange(q, 'b s ... -> (b s) ...')
kv = rearrange(kv, 'b s ... -> (b s) ...')
cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q,
dtype=torch.int32, device=q.device)
cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k,
dtype=torch.int32, device=kv.device)
output = flash_attn_unpadded_kvpacked_func(
q, kv, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
self.drop.p if self.training else 0.0,
softmax_scale=self.softmax_scale, causal=causal
)
output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
return output
assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
return flash_attn_kvpacked_func(q, kv, self.drop.p if self.training else 0.0,
causal=causal, softmax_scale=self.softmax_scale)
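A short sketch of the dense path above, with matching query and key/value head counts (the sizes are assumptions for illustration):

# q: (B, Sq, H, D), kv: (B, Sk, 2, H_k, D); here H_k == H.
xattn = FlashCrossAttention(causal=False)
q = torch.randn(2, 128, 16, 64, device='cuda', dtype=torch.float16)
kv = torch.randn(2, 256, 2, 16, 64, device='cuda', dtype=torch.float16)
out = xattn(q, kv)   # (B, Sq, H, D)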
class SelfAttention(nn.Module):
......
......@@ -111,28 +111,52 @@ cc_flag = []
_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
if bare_metal_version < Version("11.0"):
raise RuntimeError("FlashAttention is only supported on CUDA 11 and above")
cc_flag.append("-gencode")
cc_flag.append("arch=compute_75,code=sm_75")
# cc_flag.append("-gencode")
# cc_flag.append("arch=compute_75,code=sm_75")
cc_flag.append("-gencode")
cc_flag.append("arch=compute_80,code=sm_80")
if bare_metal_version >= Version("11.8"):
cc_flag.append("-gencode")
cc_flag.append("arch=compute_90,code=sm_90")
subprocess.run(["git", "submodule", "update", "--init", "csrc/flash_attn/cutlass"])
subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"])
ext_modules.append(
CUDAExtension(
name="flash_attn_cuda",
name="flash_attn_2_cuda",
sources=[
"csrc/flash_attn/fmha_api.cpp",
"csrc/flash_attn/src/fmha_fwd_hdim32.cu",
"csrc/flash_attn/src/fmha_fwd_hdim64.cu",
"csrc/flash_attn/src/fmha_fwd_hdim128.cu",
"csrc/flash_attn/src/fmha_bwd_hdim32.cu",
"csrc/flash_attn/src/fmha_bwd_hdim64.cu",
"csrc/flash_attn/src/fmha_bwd_hdim128.cu",
"csrc/flash_attn/src/fmha_block_fprop_fp16_kernel.sm80.cu",
"csrc/flash_attn/src/fmha_block_dgrad_fp16_kernel_loop.sm80.cu",
"csrc/flash_attn/flash_api.cpp",
"csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu",
],
extra_compile_args={
"cxx": ["-O3", "-std=c++17"] + generator_flag,
......@@ -157,11 +181,12 @@ ext_modules.append(
include_dirs=[
Path(this_dir) / 'csrc' / 'flash_attn',
Path(this_dir) / 'csrc' / 'flash_attn' / 'src',
Path(this_dir) / 'csrc' / 'flash_attn' / 'cutlass' / 'include',
Path(this_dir) / 'csrc' / 'cutlass' / 'include',
],
)
)
def get_package_version():
with open(Path(this_dir) / "flash_attn" / "__init__.py", "r") as f:
version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
......@@ -172,6 +197,7 @@ def get_package_version():
else:
return str(public_version)
setup(
name="flash_attn",
version=get_package_version(),
......@@ -179,11 +205,9 @@ setup(
exclude=("build", "csrc", "include", "tests", "dist", "docs", "benchmarks", "flash_attn.egg-info",)
),
author="Tri Dao",
author_email="trid@stanford.edu",
author_email="trid@cs.stanford.edu",
description="Flash Attention: Fast and Memory-Efficient Exact Attention",
long_description=long_description,
long_description_content_type="text/markdown",
url="https://github.com/HazyResearch/flash-attention",
url="https://github.com/Dao-AILab/flash-attention",
classifiers=[
"Programming Language :: Python :: 3",
"License :: OSI Approved :: BSD License",
......
import math
from functools import partial
import torch
import torch.nn.functional as F
......@@ -8,100 +7,87 @@ import pytest
from einops import rearrange, repeat
from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_unpadded_qkvpacked_func, _get_block_size, flash_attn_unpadded_kvpacked_func, flash_attn_unpadded_func
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_split_func
from flash_attn import flash_attn_func, flash_attn_kvpacked_func, flash_attn_qkvpacked_func
from flash_attn import flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func
from flash_attn import flash_attn_varlen_func
from flash_attn.flash_attn_interface import _get_block_size
from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis
try:
from flash_attn.flash_attn_triton import flash_attn_func
except (ImportError, AttributeError): # Older version of Triton doesn't have tl.constexpr
flash_attn_func = None
MAX_HEADDIM_SM8x = 192
is_sm75 = torch.cuda.get_device_capability('cuda') == (7, 5)
is_sm8x = torch.cuda.get_device_capability('cuda')[0] == 8
is_sm80 = torch.cuda.get_device_capability('cuda') == (8, 0)
is_sm90 = torch.cuda.get_device_capability('cuda') == (9, 0)
def generate_random_padding_mask(max_seqlen, batch_size, device, mode='random'):
assert mode in ['full', 'random', 'third', 'split']
assert mode in ['full', 'random', 'third']
if mode == 'full':
lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)
elif mode == 'random':
lengths = torch.randint(max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device)
lengths = torch.randint(max(1, max_seqlen - 20), max_seqlen, (batch_size, 1), device=device)
elif mode == 'third':
lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device)
elif mode == 'split':
lengths0 = torch.randint(min(128, max_seqlen), max_seqlen + 1,
(batch_size // 4 * 3, 1), device=device)
lengths1 = torch.randint(min(max(1, max_seqlen - 20), 128), min(max_seqlen, 128) + 1,
(batch_size - batch_size // 4 * 3, 1), device=device)
lengths = torch.cat([lengths0, lengths1], dim=0)
lengths = torch.randint(max_seqlen // 3, max_seqlen, (batch_size, 1), device=device)
padding_mask = repeat(torch.arange(max_seqlen, device=device), 's -> b s', b=batch_size) < lengths
return padding_mask
def generate_qkv(x, Wqkv, nheads, query_padding_mask=None, key_padding_mask=None,
def generate_qkv(q, k, v, query_padding_mask=None, key_padding_mask=None,
kvpacked=False, qkvpacked=False):
"""
Arguments:
x: (batch_size, seqlen, nheads * d)
Wqkv: nn.Linear(nheads * d, 3 * nheads * d)
q: (batch_size, seqlen_q, nheads, d)
k: (batch_size, seqlen_k, nheads_k, d)
v: (batch_size, seqlen_k, nheads_k, d)
query_padding_mask: (batch_size, seqlen), bool
key_padding_mask: (batch_size, seqlen), bool
"""
assert not (kvpacked and qkvpacked)
batch_size, seqlen, dim = x.shape
q, k, v = Wqkv(x).chunk(3, dim=-1)
batch_size, seqlen_q, nheads, d = q.shape
_, seqlen_k, nheads_k, _ = k.shape
assert k.shape == (batch_size, seqlen_k, nheads_k, d)
assert v.shape == (batch_size, seqlen_k, nheads_k, d)
if query_padding_mask is not None:
q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, query_padding_mask)
q_unpad = rearrange(q_unpad, 'nnz (h d) -> nnz h d', h=nheads)
output_pad_fn = lambda output_unpad: rearrange(
pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size, seqlen),
'b s (h d) -> b s h d', h=nheads
)
output_pad_fn = lambda output_unpad: pad_input(output_unpad, indices_q, batch_size, seqlen_q)
else:
q_unpad = rearrange(q, 'b s (h d) -> (b s) h d', h=nheads)
cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
q_unpad = rearrange(q, 'b s h d -> (b s) h d')
cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32,
device=q_unpad.device)
max_seqlen_q = seqlen
max_seqlen_q = seqlen_q
output_pad_fn = lambda output_unpad: rearrange(output_unpad, '(b s) h d -> b s h d', b=batch_size)
if key_padding_mask is not None:
k_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
k_unpad = rearrange(k_unpad, 'nnz (h d) -> nnz h d', h=nheads)
v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
v_unpad = rearrange(v_unpad, 'nnz (h d) -> nnz h d', h=nheads)
else:
k_unpad = rearrange(k, 'b s (h d) -> (b s) h d', h=nheads)
v_unpad = rearrange(v, 'b s (h d) -> (b s) h d', h=nheads)
cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
device=q_unpad.device)
max_seqlen_k = seqlen
k_unpad = rearrange(k, 'b s h d -> (b s) h d')
v_unpad = rearrange(v, 'b s h d -> (b s) h d')
cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32,
device=k_unpad.device)
max_seqlen_k = seqlen_k
if qkvpacked:
assert (query_padding_mask == key_padding_mask).all()
assert nheads == nheads_k
qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
qkv = rearrange(torch.stack([q, k, v], dim=2), 'b s t (h d) -> b s t h d', h=nheads)
qkv = torch.stack([q, k, v], dim=2)
if query_padding_mask is not None:
dqkv_pad_fn = lambda dqkv_unpad: rearrange(
pad_input(rearrange(dqkv_unpad, 'nnz t h d -> nnz (t h d)'), indices_q, batch_size, seqlen),
'b s (t h d) -> b s t h d', t=3, h=nheads
)
dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q)
else:
dqkv_pad_fn = lambda dqkv_unpad: rearrange(dqkv_unpad, '(b s) t h d -> b s t h d', b=batch_size)
return (qkv_unpad.detach().requires_grad_(), cu_seqlens_q, max_seqlen_q,
qkv.detach().requires_grad_(), output_pad_fn, dqkv_pad_fn)
elif kvpacked:
kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
q = rearrange(q, 'b s (h d) -> b s h d', h=nheads)
kv = rearrange(torch.stack([k, v], dim=2), 'b s t (h d) -> b s t h d', h=nheads)
kv = torch.stack([k, v], dim=2)
dq_pad_fn = output_pad_fn
if key_padding_mask is not None:
dkv_pad_fn = lambda dkv_unpad: rearrange(
pad_input(rearrange(dkv_unpad, 'nnz t h d -> nnz (t h d)'), indices_k, batch_size, seqlen),
'b s (t h d) -> b s t h d', t=2, h=nheads
)
dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k)
else:
dkv_pad_fn = lambda dkv_unpad: rearrange(dkv_unpad, '(b s) t h d -> b s t h d', b=batch_size)
return (q_unpad.detach().requires_grad_(), kv_unpad.detach().requires_grad_(),
......@@ -109,35 +95,30 @@ def generate_qkv(x, Wqkv, nheads, query_padding_mask=None, key_padding_mask=None
q.detach().requires_grad_(), kv.detach().requires_grad_(),
output_pad_fn, dq_pad_fn, dkv_pad_fn)
else:
q, k, v = [rearrange(z, 'b s (h d) -> b s h d', h=nheads).detach().requires_grad_()
for z in [q, k, v]]
dq_pad_fn = output_pad_fn
if key_padding_mask is not None:
dk_pad_fn = lambda dk_unpad: rearrange(
pad_input(rearrange(dk_unpad, 'nnz h d -> nnz (h d)'), indices_k, batch_size, seqlen),
'b s (h d) -> b s h d', h=nheads
)
dk_pad_fn = lambda dk_unpad: pad_input(dk_unpad, indices_k, batch_size, seqlen_k)
else:
dk_pad_fn = lambda dk_unpad: rearrange(dk_unpad, '(b s) h d -> b s h d', b=batch_size)
return (q_unpad.detach().requires_grad_(), k_unpad.detach().requires_grad_(),
v_unpad.detach().requires_grad_(),
cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
q, k, v,
q.detach().requires_grad_(), k.detach().requires_grad_(),
v.detach().requires_grad_(),
output_pad_fn, dq_pad_fn, dk_pad_fn)
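A hypothetical sketch of how the tests combine these helpers: unpad dense q/k/v with a random mask, run the varlen kernel, then re-pad the output (shapes are assumptions for illustration):

q = torch.randn(2, 128, 8, 64, device='cuda', dtype=torch.float16)
k = torch.randn(2, 128, 8, 64, device='cuda', dtype=torch.float16)
v = torch.randn(2, 128, 8, 64, device='cuda', dtype=torch.float16)
mask = generate_random_padding_mask(128, 2, 'cuda', mode='random')
(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
 q_d, k_d, v_d, output_pad_fn, dq_pad_fn, dk_pad_fn) = generate_qkv(q, k, v, mask, mask)
out_unpad = flash_attn_varlen_func(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k,
                                   max_seqlen_q, max_seqlen_k, 0.0)
out = output_pad_fn(out_unpad)   # back to (2, 128, 8, 64)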
def attention_ref(q, k, v, query_padding_mask=None, key_padding_mask=None, dropout_p=0.0,
dropout_mask=None, causal=False, bias=None, upcast=True, reorder_ops=False):
dropout_mask=None, causal=False, upcast=True, reorder_ops=False):
"""
Arguments:
q: (batch_size, seqlen_q, nheads, head_dim)
k: (batch_size, seqlen_k, nheads, head_dim)
v: (batch_size, seqlen_k, nheads, head_dim)
k: (batch_size, seqlen_k, nheads_k, head_dim)
v: (batch_size, seqlen_k, nheads_k, head_dim)
query_padding_mask: (batch_size, seqlen_q)
key_padding_mask: (batch_size, seqlen_k)
dropout_p: float
dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
bias: (batch_size, nheads, seqlen_q, seqlen_k)
upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
output back to fp16/bf16.
reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
......@@ -151,13 +132,13 @@ def attention_ref(q, k, v, query_padding_mask=None, key_padding_mask=None, dropo
if upcast:
q, k, v = q.float(), k.float(), v.float()
seqlen_q, seqlen_k = q.shape[1], k.shape[1]
k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
d = q.shape[-1]
if not reorder_ops:
scores = torch.einsum('bthd,bshd->bhts', q / math.sqrt(d), k)
else:
scores = torch.einsum('bthd,bshd->bhts', q, k / math.sqrt(d))
if bias is not None:
scores = (scores + bias).to(dtype=scores.dtype)
if key_padding_mask is not None:
scores.masked_fill_(rearrange(~key_padding_mask, 'b s -> b 1 1 s'), float('-inf'))
if causal:
......@@ -238,37 +219,40 @@ def convert_flash_attn_S_to_softmax(S, query_padding_mask, key_padding_mask, hea
causal=False):
"""FlashAttention stores the S matrix in a different way.
Arguments:
S: (batch_size, nheads, seqlen_q, seqlen_k)
S: (batch_size, nheads, seqlen_q_rounded, seqlen_k_rounded)
query_padding_mask: (batch_size, seqlen_q)
key_padding_mask: (batch_size, seqlen_k)
"""
S_flat = rearrange(S, 'b h t s -> b h (t s)')
seqlen_q, seqlen_k = S.shape[-2:]
block_size = _get_block_size(S.device, head_dim, is_dropout)
loop_steps = (seqlen_k + block_size - 1) // block_size
warps_n = 4
mmas_n = (seqlen_k // warps_n // 16) if seqlen_k <= block_size else (block_size // warps_n // 16)
S_converted = rearrange(S_flat, 'b h (loop nsteps mmas_n warps_n eight t r c0 c1) -> b h (nsteps r eight) (loop mmas_n warps_n c0 t c1)',
loop=loop_steps, nsteps=seqlen_q // 16, mmas_n=mmas_n, warps_n=warps_n, eight=8, t=4,
r=2, c0=2, c1=2)
# Need to zero out things not in attention_mask in case S was initialized with random values
# and some of those values aren't overwritten.
seqlen_q_og = query_padding_mask.shape[-1]
if seqlen_q_og < seqlen_q:
query_padding_mask = F.pad(query_padding_mask, (0, seqlen_q - seqlen_q_og))
else:
query_padding_mask = query_padding_mask[:, :seqlen_q]
S_converted = S_converted.masked_fill(rearrange(~query_padding_mask, 'b s -> b 1 s 1'), 0.0)
seqlen_k_og = key_padding_mask.shape[-1]
if seqlen_k_og < seqlen_k:
key_padding_mask = F.pad(key_padding_mask, (0, seqlen_k - seqlen_k_og))
else:
key_padding_mask = key_padding_mask[:, :seqlen_k]
S_converted = S_converted.masked_fill(rearrange(~key_padding_mask, 'b s -> b 1 1 s'), 0.0)
blocksize_m, blocksize_n = _get_block_size(S.device, head_dim, is_dropout, causal)
nblocks_n = (seqlen_k + blocksize_n - 1) // blocksize_n
nblocks_m = (seqlen_q + blocksize_m - 1) // blocksize_m
mmas_n = (blocksize_n + 16 - 1) // 16
S_flat = rearrange(S, 'b h (nblocks_m blocksize_m) (nblocks_n blocksize_n) -> b h nblocks_m nblocks_n (blocksize_m blocksize_n)',
blocksize_m=blocksize_m, blocksize_n=blocksize_n)
S_converted = rearrange(S_flat, 'b h nblocks_m nblocks_n (mmas_n mmas_m warps_n eight four c2 c1 c0) -> b h (nblocks_m mmas_m warps_n c1 eight) (nblocks_n mmas_n c2 four c0)',
mmas_n=mmas_n, warps_n=warps_n, eight=8, c0=2, c1=2, c2=2, four=4)
if causal:
causal_mask = torch.triu(torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=S.device), 1)
S_converted.masked_fill_(causal_mask, 0.0)
# Need to zero out things not in attention_mask in case S was initialized with random values
# and some of those values aren't overwritten.
seqlen_q_og = query_padding_mask.shape[-1] if query_padding_mask is not None else seqlen_q
if query_padding_mask is not None:
if seqlen_q_og < seqlen_q:
query_padding_mask = F.pad(query_padding_mask, (0, seqlen_q - seqlen_q_og))
else:
query_padding_mask = query_padding_mask[:, :seqlen_q]
S_converted = S_converted.masked_fill(rearrange(~query_padding_mask, 'b s -> b 1 s 1'), 0.0)
seqlen_k_og = key_padding_mask.shape[-1] if key_padding_mask is not None else seqlen_k
if key_padding_mask is not None:
if seqlen_k_og < seqlen_k:
key_padding_mask = F.pad(key_padding_mask, (0, seqlen_k - seqlen_k_og))
else:
key_padding_mask = key_padding_mask[:, :seqlen_k]
S_converted = S_converted.masked_fill(rearrange(~key_padding_mask, 'b s -> b 1 1 s'), 0.0)
if seqlen_q_og < seqlen_q:
S_converted = S_converted[:, :, :seqlen_q_og, :]
else:
......@@ -300,16 +284,15 @@ def normalize_flash_attn_S(attn_unnorm, q, k, v, query_padding_mask=None, key_pa
if causal:
causal_mask = torch.triu(torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1)
scores.masked_fill_(causal_mask, float('-inf'))
block_size = _get_block_size(scores.device, head_dim, is_dropout)
scores_block = scores.split(block_size, dim=-1)
_, block_size_n = _get_block_size(scores.device, head_dim, is_dropout, causal)
scores_block = scores.split(block_size_n, dim=-1)
lse_block = torch.stack([torch.logsumexp(s, dim=-1) for s in scores_block], dim=-1)
lcse_block = torch.logcumsumexp(lse_block, dim=-1).unbind(dim=-1)
scores_max_block = ([torch.amax(scores_block[0], dim=-1)]
+ [torch.maximum(torch.amax(s, dim=-1), lcse)
for s, lcse in zip(scores_block[1:], lcse_block[:-1])])
attn_unnorm_block = attn_unnorm.split(block_size, dim=-1)
attn_norm = torch.cat([a / rearrange(torch.exp(lcse_block[-1] - m), 'b h s -> b h s 1')
for a, m in zip(attn_unnorm_block, scores_max_block)], dim=-1)
lse = torch.logsumexp(lse_block, dim=-1)
scores_max_block = torch.stack([torch.amax(s, dim=-1) for s in scores_block], dim=-1)
cummax_block = torch.cummax(scores_max_block.flip(-1), dim=-1).values.flip(-1).unbind(dim=-1)
attn_unnorm_block = attn_unnorm.split(block_size_n, dim=-1)
attn_norm = torch.cat([a / rearrange(torch.exp(lse - m), 'b h s -> b h s 1')
for a, m in zip(attn_unnorm_block, cummax_block)], dim=-1)
if query_padding_mask is not None:
attn_norm.masked_fill_(rearrange(~query_padding_mask, 'b s -> b 1 s 1'), 0.0)
return attn_norm.to(dtype=attn_unnorm.dtype)
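A brief note on the renormalization above (a reading of the code, under the assumption that m_j is the row max the kernel had accumulated when it wrote key block j, which is what the suffix cummax computes): the kernel's stored unnormalized probabilities are \tilde{P}_j = \exp(s - m_j), so

\frac{\tilde{P}_j}{\exp(\mathrm{lse} - m_j)} = \frac{\exp(s - m_j)}{\exp(\mathrm{lse} - m_j)} = \exp(s - \mathrm{lse}),

which is exactly the reference softmax that attention_ref computes.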
......@@ -350,68 +333,79 @@ def get_dropout_fraction(dropout_mask, query_padding_mask=None, key_padding_mask
@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('causal', [False, True])
# @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize('d', [128, 64, 80, 40, 32, 16])
# @pytest.mark.parametrize('causal', [True])
@pytest.mark.parametrize('d', [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128])
# @pytest.mark.parametrize('d', [64])
# @pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048])
@pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
# @pytest.mark.parametrize('seqlen', [128])
# @pytest.mark.parametrize('seqlen', [97])
@pytest.mark.parametrize('dropout_p', [0.0, 0.17])
# @pytest.mark.parametrize('dropout_p', [0.0])
def test_flash_attn_unpadded_qkvpacked(seqlen, d, dropout_p, causal, dtype):
# @pytest.mark.parametrize('dropout_p', [0.17])
def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, dtype):
if seqlen >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
pytest.skip() # Reference implementation OOM
device = 'cuda'
# if dtype == torch.float16:
# rtol, atol = (1e-3, 3e-4) if not causal else (1e-3, 1e-3)
# else: # torch.bfloat16
# rtol, atol = (3e-3, 3e-3) if not causal else (1e-3, 1e-3)
# set seed
torch.random.manual_seed(0)
# Set smaller batch size so it would trigger num_splits > 1
batch_size = 8
nheads = 4
x = torch.randn(batch_size, seqlen, nheads * d, device=device, dtype=dtype, requires_grad=True)
Wqkv = torch.nn.Linear(nheads * d, 3 * nheads * d, device=device, dtype=dtype)
key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='random')
# key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='full')
qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv(
x, Wqkv, nheads, key_padding_mask, key_padding_mask, qkvpacked=True
)
output_unpad, sm_lse, S_dmask = flash_attn_unpadded_qkvpacked_func(
qkv_unpad, cu_seqlens, max_seqlen, dropout_p, return_attn_probs=True, causal=causal
batch_size = 16
nheads = 9
qkv = torch.randn(batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype,
requires_grad=True)
out, lse, S_dmask = flash_attn_qkvpacked_func(
qkv, dropout_p, return_attn_probs=True, causal=causal
)
output = output_pad_fn(output_unpad)
S_dmask_converted = convert_flash_attn_S_to_softmax(
S_dmask, key_padding_mask, key_padding_mask, d, dropout_p > 0.0, causal=causal
)
dropout_mask = S_dmask_converted >= 0
attn_unnorm = S_dmask_converted.abs()
attn = normalize_flash_attn_S(attn_unnorm, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2],
key_padding_mask, key_padding_mask, dropout_p > 0.0, causal=causal)
dropout_fraction = get_dropout_fraction(dropout_mask, key_padding_mask, key_padding_mask,
causal=causal).item()
output_ref, attn_ref = attention_qkvpacked_ref(qkv, key_padding_mask, dropout_p, dropout_mask,
causal=causal)
output_pt, attn_pt = attention_qkvpacked_ref(qkv, key_padding_mask, dropout_p, dropout_mask,
causal=causal, upcast=False, reorder_ops=True)
print(f'Actual dropout fraction: {dropout_fraction}')
print(f'Output max diff: {(output - output_ref).abs().max().item()}')
print(f'Output mean diff: {(output - output_ref).abs().mean().item()}')
print(f'Pytorch max diff: {(output_pt - output_ref).abs().max().item()}')
print(f'Pytorch mean diff: {(output_pt - output_ref).abs().mean().item()}')
print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}')
print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}')
if is_sm80 or d <= 64: # Only run backward for d=128 on A100
g = torch.randn_like(output)
dqkv_unpad, = torch.autograd.grad(output, qkv_unpad, g)
dqkv = dqkv_pad_fn(dqkv_unpad)
dqkv_ref, = torch.autograd.grad(output_ref, qkv, g)
dqkv_pt, = torch.autograd.grad(output_pt, qkv, g)
if dropout_p > 0.0:
S_dmask_converted = convert_flash_attn_S_to_softmax(
S_dmask, None, None, d, dropout_p > 0.0, causal=causal
)[:, :, :seqlen, :seqlen]
dropout_mask = S_dmask_converted >= 0
attn_unnorm = S_dmask_converted.abs()
attn = normalize_flash_attn_S(attn_unnorm, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2],
None, None, dropout_p > 0.0, causal=causal)
dropout_fraction = get_dropout_fraction(dropout_mask, None, None, causal=causal).item()
print(f'Actual dropout fraction: {dropout_fraction}')
else:
dropout_mask = None
out_ref, attn_ref = attention_qkvpacked_ref(qkv, None, dropout_p, dropout_mask, causal=causal)
out_pt, attn_pt = attention_qkvpacked_ref(qkv, None, dropout_p, dropout_mask, causal=causal,
upcast=False, reorder_ops=True)
# v = qkv[:, :, 2].float()
# qk = torch.einsum('bshd,bthd->bhst', qkv[:, :, 0], qkv[:, :, 1]).float()
# if causal:
# causal_mask = torch.triu(torch.ones(seqlen, seqlen, dtype=torch.bool, device=qkv.device), 1)
# qk.masked_fill_(causal_mask, float('-inf'))
# m = qk.amax(-1, keepdim=True)
# s_tmp = torch.exp((qk - m) / math.sqrt(d))
# p_tmp = torch.softmax(qk / math.sqrt(d), -1)
# p_dropped = p_tmp if dropout_mask is None else p_tmp.masked_fill(~dropout_mask, 0)
# lse_ref = torch.logsumexp(qk / math.sqrt(d), -1)
# qk_max1 = torch.max(qk[:, :, 128:, 192:], -1, keepdim=True).values
# qk_max2 = torch.max(qk[:, :, 128:, 128:], -1, keepdim=True).values
# qk_max3 = torch.max(qk[:, :, 128:, 64:], -1, keepdim=True).values
# qk_max4 = torch.max(qk[:, :, 128:, :], -1, keepdim=True).values
# o1 = torch.einsum('bhst,bthd->bshd', torch.exp((qk[:, :, 128:, 192:] - qk_max1) / math.sqrt(d)), v[:, 192:])
# o2 = torch.einsum('bhst,bthd->bshd', torch.exp((qk[:, :, 128:, 128:] - qk_max2) / math.sqrt(d)), v[:, 128:])
# o3 = torch.einsum('bhst,bthd->bshd', torch.exp((qk[:, :, 128:, 64:] - qk_max3) / math.sqrt(d)), v[:, 64:])
# o4 = torch.einsum('bhst,bthd->bshd', torch.exp((qk[:, :, 128:, :] - qk_max4) / math.sqrt(d)), v[:, :])
print(f'Output max diff: {(out - out_ref).abs().max().item()}')
print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
print(f'Pytorch max diff: {(out_pt - out_ref).abs().max().item()}')
print(f'Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}')
if dropout_p > 0.0:
print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}')
print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}')
g = torch.randn_like(out)
# do_o = (g.float() * out.float()).sum(-1)
# dv_tmp = torch.einsum('bhts,bthd->bshd', attn_pt[:, :, :64], g[:, :64])
# dv_tmp1 = torch.einsum('bhts,bthd->bshd', attn_pt[:, :, 64:], g[:, 64:])
if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
dqkv, = torch.autograd.grad(out, qkv, g)
dqkv_ref, = torch.autograd.grad(out_ref, qkv, g)
dqkv_pt, = torch.autograd.grad(out_pt, qkv, g)
print(f'dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}')
print(f'dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}')
print(f'dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}')
......@@ -423,584 +417,411 @@ def test_flash_attn_unpadded_qkvpacked(seqlen, d, dropout_p, causal, dtype):
# Check that FlashAttention's numerical error is at most twice the numerical error
# of a Pytorch implementation.
assert (output - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item()
# assert torch.allclose(output, output_ref, rtol=rtol, atol=atol)
assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
# assert torch.allclose(attn, attn_ref, rtol=rtol, atol=atol)
if dropout_p == 0.0:
assert dropout_mask.all()
else:
assert 0.98 <= dropout_fraction / dropout_p <= 1.02
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
if dropout_p > 0.0:
assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
assert abs(dropout_fraction - dropout_p) <= 0.01
if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()
if is_sm80 or d <= 64: # Only run backward for d=128 on A100
# Error for dK and dV could be a bit higher if we're splitting along seqlen_q dimension
assert (dqkv - dqkv_ref).abs().max().item() <= 4 * (dqkv_pt - dqkv_ref).abs().max().item()
# assert torch.allclose(dqkv, dqkv_ref, rtol=rtol, atol=atol)
@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('causal', [False, True])
@pytest.mark.parametrize('d', [128, 64, 80, 40, 32, 16])
# @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize('d', [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [64])
@pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
# @pytest.mark.parametrize('seqlen', [128])
@pytest.mark.parametrize('dropout_p', [0.0, 0.17])
# @pytest.mark.parametrize('dropout_p', [0.0])
def test_flash_attn_unpadded_kvpacked(seqlen, d, dropout_p, causal, dtype):
def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype):
if seqlen >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
pytest.skip() # Reference implementation OOM
device = 'cuda'
# if dtype == torch.float16:
# rtol, atol = (1e-3, 3e-4) if not causal else (1e-3, 1e-3)
# else: # torch.bfloat16
# rtol, atol = (3e-3, 3e-3) if not causal else (1e-3, 1e-3)
# set seed
torch.random.manual_seed(0)
batch_size = 32
nheads = 4
x = torch.randn(batch_size, seqlen, nheads * d, device=device, dtype=dtype, requires_grad=True)
Wqkv = torch.nn.Linear(nheads * d, 3 * nheads * d, device=device, dtype=dtype)
batch_size = 5
nheads = 6
qkv = torch.randn(batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype,
requires_grad=True)
query_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='random')
key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='random')
# key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='full')
(q_unpad, kv_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, q, kv,
output_pad_fn, dq_pad_fn, dkv_pad_fn) = generate_qkv(
x, Wqkv, nheads, query_padding_mask, key_padding_mask, kvpacked=True
)
output_unpad, sm_lse, S_dmask = flash_attn_unpadded_kvpacked_func(
q_unpad, kv_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, return_attn_probs=True, causal=causal
qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv(
*qkv.unbind(dim=2), key_padding_mask, key_padding_mask, qkvpacked=True
)
output = output_pad_fn(output_unpad)
S_dmask_converted = convert_flash_attn_S_to_softmax(
S_dmask, query_padding_mask, key_padding_mask, d, dropout_p > 0.0, causal=causal
out_unpad, sm_lse, S_dmask = flash_attn_varlen_qkvpacked_func(
qkv_unpad, cu_seqlens, max_seqlen, dropout_p, return_attn_probs=True, causal=causal
)
dropout_mask = S_dmask_converted >= 0
attn_unnorm = S_dmask_converted.abs()
attn = normalize_flash_attn_S(attn_unnorm, q, kv[:, :, 0], kv[:, :, 1],
query_padding_mask, key_padding_mask, dropout_p > 0.0, causal=causal)
dropout_fraction = get_dropout_fraction(dropout_mask, query_padding_mask, key_padding_mask,
causal=causal)
output_ref, attn_ref = attention_kvpacked_ref(q, kv, query_padding_mask, key_padding_mask,
dropout_p, dropout_mask, causal=causal)
output_pt, attn_pt = attention_kvpacked_ref(q, kv, query_padding_mask, key_padding_mask,
dropout_p, dropout_mask, causal=causal,
upcast=False, reorder_ops=True)
print(f'Actual dropout fraction: {dropout_fraction}')
print(f'Output max diff: {(output - output_ref).abs().max().item()}')
print(f'Output mean diff: {(output - output_ref).abs().mean().item()}')
print(f'Pytorch max diff: {(output_pt - output_ref).abs().max().item()}')
print(f'Pytorch mean diff: {(output_pt - output_ref).abs().mean().item()}')
print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}')
print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}')
if is_sm80 or d <= 64: # Only run backward for d=128 on A100
g = torch.randn_like(output)
dq_unpad, dkv_unpad, = torch.autograd.grad(output, (q_unpad, kv_unpad), g)
dq = dq_pad_fn(dq_unpad)
dkv = dkv_pad_fn(dkv_unpad)
dq_ref, dkv_ref, = torch.autograd.grad(output_ref, (q, kv), g)
dq_pt, dkv_pt = torch.autograd.grad(output_pt, (q, kv), g)
print(f'dQ max diff: {(dq - dq_ref).abs().max().item()}')
print(f'dK max diff: {(dkv[:, :, 0] - dkv_ref[:, :, 0]).abs().max().item()}')
print(f'dV max diff: {(dkv[:, :, 1] - dkv_ref[:, :, 1]).abs().max().item()}')
print(f'dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}')
print(f'dK Pytorch max diff: {(dkv_pt[:, :, 0] - dkv_ref[:, :, 0]).abs().max().item()}')
print(f'dV Pytorch max diff: {(dkv_pt[:, :, 1] - dkv_ref[:, :, 1]).abs().max().item()}')
out = output_pad_fn(out_unpad)
if dropout_p > 0.0:
S_dmask_converted = convert_flash_attn_S_to_softmax(
S_dmask, key_padding_mask, key_padding_mask, d, dropout_p > 0.0, causal=causal
)[:, :, :seqlen, :seqlen]
dropout_mask = S_dmask_converted >= 0
attn_unnorm = S_dmask_converted.abs()
attn = normalize_flash_attn_S(attn_unnorm, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2],
key_padding_mask, key_padding_mask, dropout_p > 0.0,
causal=causal)
dropout_fraction = get_dropout_fraction(dropout_mask, key_padding_mask, key_padding_mask,
causal=causal).item()
print(f'Actual dropout fraction: {dropout_fraction}')
else:
dropout_mask = None
out_ref, attn_ref = attention_qkvpacked_ref(qkv, key_padding_mask, dropout_p, dropout_mask,
causal=causal)
out_pt, attn_pt = attention_qkvpacked_ref(qkv, key_padding_mask, dropout_p, dropout_mask,
causal=causal, upcast=False, reorder_ops=True)
print(f'Output max diff: {(out - out_ref).abs().max().item()}')
print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
print(f'Pytorch max diff: {(out_pt - out_ref).abs().max().item()}')
print(f'Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}')
if dropout_p > 0.0:
print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}')
print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}')
g = torch.randn_like(out)
if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
dqkv_unpad, = torch.autograd.grad(out, qkv_unpad, g)
dqkv = dqkv_pad_fn(dqkv_unpad)
dqkv_ref, = torch.autograd.grad(out_ref, qkv, g)
dqkv_pt, = torch.autograd.grad(out_pt, qkv, g)
print(f'dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}')
print(f'dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}')
print(f'dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}')
print(f'dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}')
print(f'dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}')
print(f'dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}')
print(f'dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}')
print(f'dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}')
# Check that FlashAttention's numerical error is at most twice the numerical error
# of a Pytorch implementation.
assert (output - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item()
# assert torch.allclose(output, output_ref, rtol=rtol, atol=atol)
assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
# assert torch.allclose(attn, attn_ref, rtol=rtol, atol=atol)
if dropout_p == 0.0:
assert dropout_mask.all()
else:
assert 0.99 <= dropout_fraction / dropout_p <= 1.01
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
if is_sm80 or d <= 64: # Only run backward for d=128 on A100
assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
assert (dkv - dkv_ref).abs().max().item() <= 2 * (dkv_pt - dkv_ref).abs().max().item()
# assert torch.allclose(dq, dq_ref, rtol=rtol, atol=atol)
# assert torch.allclose(dkv, dkv_ref, rtol=rtol, atol=atol)
if dropout_p > 0.0:
assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
assert abs(dropout_fraction - dropout_p) <= 0.01
if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()
@pytest.mark.parametrize('kvpacked', [True, False])
# @pytest.mark.parametrize('kvpacked', [False])
@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.float16])
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize('mha_type', ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize('mha_type', ["mha"])
@pytest.mark.parametrize('causal', [False, True])
@pytest.mark.parametrize('d', [128, 64, 80, 40, 32, 16])
# @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize('d', [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize('d', [64])
@pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
# @pytest.mark.parametrize('seqlen', [128])
@pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (113, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (1023, 1024), (1024, 1023), (2048, 2048)])
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
@pytest.mark.parametrize('dropout_p', [0.0, 0.17])
# @pytest.mark.parametrize('dropout_p', [0.0])
def test_flash_attn_unpadded(seqlen, d, dropout_p, causal, dtype):
if seqlen >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, dtype, kvpacked):
if max(seqlen_q, seqlen_k) >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
pytest.skip() # Reference implementation OOM
device = 'cuda'
# if dtype == torch.float16:
# rtol, atol = (1e-3, 3e-4) if not causal else (1e-3, 1e-3)
# else: # torch.bfloat16
# rtol, atol = (3e-3, 3e-3) if not causal else (1e-3, 1e-3)
# set seed
torch.random.manual_seed(0)
batch_size = 32
nheads = 4
x = torch.randn(batch_size, seqlen, nheads * d, device=device, dtype=dtype, requires_grad=True)
Wqkv = torch.nn.Linear(nheads * d, 3 * nheads * d, device=device, dtype=dtype)
query_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='random')
key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='random')
(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, q, k, v,
output_pad_fn, dq_pad_fn, dk_pad_fn) = generate_qkv(
x, Wqkv, nheads, query_padding_mask, key_padding_mask
)
batch_size = 16
nheads = 9
nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
assert nheads % nheads_k == 0
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
if kvpacked:
kv = torch.randn(batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype,
requires_grad=True)
else:
k = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype,
requires_grad=True)
v = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype,
requires_grad=True)
if kvpacked:
out, lse, S_dmask = flash_attn_kvpacked_func(
q, kv, dropout_p, return_attn_probs=True, causal=causal
)
else:
out, lse, S_dmask = flash_attn_func(
q, k, v, dropout_p, return_attn_probs=True, causal=causal
)
if dropout_p > 0.0:
S_dmask_converted = convert_flash_attn_S_to_softmax(
S_dmask, None, None, d, dropout_p > 0.0, causal=causal
)[:, :, :seqlen_q, :seqlen_k]
dropout_mask = S_dmask_converted >= 0
attn_unnorm = S_dmask_converted.abs()
if kvpacked:
kv_rep = repeat(kv, "b s two h d -> b s two (h g) d", g=nheads // nheads_k)
k_rep, v_rep = kv_rep.unbind(dim=2)
else:
k_rep = repeat(k, "b s h d -> b s (h g) d", g=nheads // nheads_k)
v_rep = repeat(v, "b s h d -> b s (h g) d", g=nheads // nheads_k)
attn = normalize_flash_attn_S(attn_unnorm, q, k_rep, v_rep,
None, None, dropout_p > 0.0, causal=causal)
dropout_fraction = get_dropout_fraction(dropout_mask, None, None, causal=causal).item()
print(f'Actual dropout fraction: {dropout_fraction}')
else:
dropout_mask = None
output_unpad, sm_lse, S_dmask = flash_attn_unpadded_func(
q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, return_attn_probs=True, causal=causal
)
output = output_pad_fn(output_unpad)
S_dmask_converted = convert_flash_attn_S_to_softmax(
S_dmask, query_padding_mask, key_padding_mask, d, dropout_p > 0.0, causal=causal
)
dropout_mask = S_dmask_converted >= 0
attn_unnorm = S_dmask_converted.abs()
attn = normalize_flash_attn_S(attn_unnorm, q, k, v, query_padding_mask, key_padding_mask,
dropout_p > 0.0, causal=causal)
dropout_fraction = get_dropout_fraction(dropout_mask, query_padding_mask, key_padding_mask,
causal=causal)
output_ref, attn_ref = attention_ref(q, k, v, query_padding_mask, key_padding_mask,
dropout_p, dropout_mask, causal=causal)
output_pt, attn_pt = attention_ref(q, k, v, query_padding_mask, key_padding_mask,
dropout_p, dropout_mask, causal=causal,
upcast=False, reorder_ops=True)
print(f'Actual dropout fraction: {dropout_fraction}')
print(f'Output max diff: {(output - output_ref).abs().max().item()}')
print(f'Output mean diff: {(output - output_ref).abs().mean().item()}')
print(f'Pytorch max diff: {(output_pt - output_ref).abs().max().item()}')
print(f'Pytorch mean diff: {(output_pt - output_ref).abs().mean().item()}')
print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}')
print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}')
if is_sm80 or d <= 64: # Only run backward for d=128 on A100
g = torch.randn_like(output)
dq_unpad, dk_unpad, dv_unpad, = torch.autograd.grad(output, (q_unpad, k_unpad, v_unpad), g)
dq = dq_pad_fn(dq_unpad)
dk = dk_pad_fn(dk_unpad)
dv = dk_pad_fn(dv_unpad)
dq_ref, dk_ref, dv_ref, = torch.autograd.grad(output_ref, (q, k, v), g)
dq_pt, dk_pt, dv_pt, = torch.autograd.grad(output_pt, (q, k, v), g)
if kvpacked:
out_ref, attn_ref = attention_kvpacked_ref(q, kv, None, None, dropout_p, dropout_mask,
causal=causal)
out_pt, attn_pt = attention_kvpacked_ref(q, kv, None, None, dropout_p, dropout_mask,
causal=causal, upcast=False, reorder_ops=True)
else:
out_ref, attn_ref = attention_ref(q, k, v, None, None, dropout_p, dropout_mask,
causal=causal)
out_pt, attn_pt = attention_ref(q, k, v, None, None, dropout_p, dropout_mask,
causal=causal, upcast=False, reorder_ops=True)
print(f'Output max diff: {(out - out_ref).abs().max().item()}')
print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
print(f'Pytorch max diff: {(out_pt - out_ref).abs().max().item()}')
print(f'Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}')
if dropout_p > 0.0:
print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}')
print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}')
g = torch.randn_like(out)
do_o = (g.float() * out.float()).sum(-1)
if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
if kvpacked:
dq, dkv, = torch.autograd.grad(out, (q, kv), g)
dk, dv = dkv.unbind(2)
dq_ref, dkv_ref, = torch.autograd.grad(out_ref, (q, kv), g)
dk_ref, dv_ref = dkv_ref.unbind(2)
dq_pt, dkv_pt, = torch.autograd.grad(out_pt, (q, kv), g)
dk_pt, dv_pt = dkv_pt.unbind(2)
else:
dq, dk, dv, = torch.autograd.grad(out, (q, k, v), g)
dq_ref, dk_ref, dv_ref, = torch.autograd.grad(out_ref, (q, k, v), g)
dq_pt, dk_pt, dv_pt, = torch.autograd.grad(out_pt, (q, k, v), g)
print(f'dQ max diff: {(dq - dq_ref).abs().max().item()}')
print(f'dK max diff: {(dk - dk_ref).abs().max().item()}')
print(f'dV max diff: {(dv - dv_ref).abs().max().item()}')
print(f'dQ mean diff: {(dq - dq_ref).abs().mean().item()}')
print(f'dK mean diff: {(dk - dk_ref).abs().mean().item()}')
print(f'dV mean diff: {(dv - dv_ref).abs().mean().item()}')
print(f'dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}')
print(f'dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}')
print(f'dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}')
print(f'dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}')
print(f'dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}')
print(f'dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}')
# Check that FlashAttention's numerical error is at most twice the numerical error
# of a Pytorch implementation.
assert (output - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item()
# assert torch.allclose(output, output_ref, rtol=rtol, atol=atol)
assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
# assert torch.allclose(attn, attn_ref, rtol=rtol, atol=atol)
if dropout_p == 0.0:
assert dropout_mask.all()
else:
assert 0.99 <= dropout_fraction / dropout_p <= 1.01
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
if dropout_p > 0.0:
assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
assert abs(dropout_fraction - dropout_p) <= 0.01
if is_sm80 or d <= 64: # Only run backward for d=128 on A100
if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
# assert torch.allclose(dq, dq_ref, rtol=rtol, atol=atol)
# assert torch.allclose(dk, dk_ref, rtol=rtol, atol=atol)
# assert torch.allclose(dv, dv_ref, rtol=rtol, atol=atol)
@pytest.mark.skipif(True, reason='Experimental, not being used')
@pytest.mark.parametrize('kvpacked', [True, False])
# @pytest.mark.parametrize('kvpacked', [False])
@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('mha_type', ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize('mha_type', ["mqa"])
@pytest.mark.parametrize('causal', [False, True])
# @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize('d', [128, 64, 80, 40, 32, 16])
# @pytest.mark.parametrize('causal', [True])
@pytest.mark.parametrize('d', [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [64])
@pytest.mark.parametrize('seqlen', [512])
@pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (113, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (1023, 1024), (1024, 1023), (2048, 2048)])
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
@pytest.mark.parametrize('dropout_p', [0.0, 0.17])
# @pytest.mark.parametrize('dropout_p', [0.0])
def test_flash_attn_split(seqlen, d, dropout_p, causal, dtype):
if seqlen >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
def test_flash_attn_varlen_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, dtype,
kvpacked):
if max(seqlen_q, seqlen_k) >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
pytest.skip() # Reference implementation OOM
device = 'cuda'
# if dtype == torch.float16:
# rtol, atol = (1e-3, 3e-4) if not causal else (1e-3, 1e-3)
# else: # torch.bfloat16
# rtol, atol = (3e-3, 3e-3) if not causal else (1e-3, 1e-3)
# set seed
torch.random.manual_seed(0)
batch_size = 32
nheads = 4
x = torch.randn(batch_size, seqlen, nheads * d, device=device, dtype=dtype, requires_grad=True)
Wqkv = torch.nn.Linear(nheads * d, 3 * nheads * d, device=device, dtype=dtype)
key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='split')
batch_size0 = batch_size // 4 * 3 # this must match what's in generate_random_padding_mask
# key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='full')
qkv_unpad, cu_seqlens, max_seqlen0, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv(
x, Wqkv, nheads, key_padding_mask, key_padding_mask, qkvpacked=True
)
max_seqlen1 = 128
output_unpad, sm_lse, S_dmask0, S_dmask1 = flash_attn_unpadded_qkvpacked_split_func(
qkv_unpad, cu_seqlens, max_seqlen0, max_seqlen1, batch_size0, dropout_p,
return_attn_probs=True, causal=causal
)
output = output_pad_fn(output_unpad)
S_dmask0_converted = convert_flash_attn_S_to_softmax(
S_dmask0, key_padding_mask[:batch_size0], key_padding_mask[:batch_size0], d, dropout_p > 0.0, causal=causal
)
S_dmask1_converted = convert_flash_attn_S_to_softmax(
S_dmask1, key_padding_mask[batch_size0:, :max_seqlen1], key_padding_mask[batch_size0:, :max_seqlen1], d, dropout_p > 0.0, causal=causal
)
padding = (S_dmask0_converted.shape[-1] - S_dmask1_converted.shape[-1],
S_dmask0_converted.shape[-2] - S_dmask1_converted.shape[-2])
S_dmask_converted = torch.cat([S_dmask0_converted,
F.pad(S_dmask1_converted, (0, padding[0], 0, padding[1]))], dim=0)
dropout_mask = S_dmask_converted >= 0
attn_unnorm = S_dmask_converted.abs()
attn = normalize_flash_attn_S(attn_unnorm, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2],
key_padding_mask, key_padding_mask, dropout_p > 0.0, causal=causal)
dropout_fraction = get_dropout_fraction(dropout_mask, key_padding_mask, key_padding_mask,
causal=causal).item()
output_ref, attn_ref = attention_qkvpacked_ref(qkv, key_padding_mask, dropout_p, dropout_mask,
causal=causal)
output_pt, attn_pt = attention_qkvpacked_ref(qkv, key_padding_mask, dropout_p, dropout_mask,
causal=causal, upcast=False, reorder_ops=True)
print(f'Actual dropout fraction: {dropout_fraction}')
print(f'Output max diff: {(output - output_ref).abs().max().item()}')
print(f'Output mean diff: {(output - output_ref).abs().mean().item()}')
print(f'Pytorch max diff: {(output_pt - output_ref).abs().max().item()}')
print(f'Pytorch mean diff: {(output_pt - output_ref).abs().mean().item()}')
print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}')
print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}')
if is_sm80 or d <= 64: # Only run backward for d=128 on A100
g = torch.randn_like(output)
dqkv_unpad, = torch.autograd.grad(output, qkv_unpad, g)
dqkv = dqkv_pad_fn(dqkv_unpad)
dqkv_ref, = torch.autograd.grad(output_ref, qkv, g)
dqkv_pt, = torch.autograd.grad(output_pt, qkv, g)
print(f'dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}')
print(f'dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}')
print(f'dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}')
print(f'dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}')
print(f'dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}')
print(f'dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}')
print(f'dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}')
print(f'dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}')
# Check that FlashAttention's numerical error is at most twice the numerical error
# of a Pytorch implementation.
assert (output - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item()
# assert torch.allclose(output, output_ref, rtol=rtol, atol=atol)
assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
# assert torch.allclose(attn, attn_ref, rtol=rtol, atol=atol)
if dropout_p == 0.0:
assert dropout_mask.all()
batch_size = 16
nheads = 9
nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
assert nheads % nheads_k == 0
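# nheads_k selects the attention variant: equal to nheads for standard MHA, 1 for multi-query
# attention (MQA), and a divisor of nheads (here 3) for grouped-query attention (GQA); nheads
# must be divisible by nheads_k so each KV head serves a whole group of query heads.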
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
if kvpacked:
kv = torch.randn(batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype,
requires_grad=True)
else:
assert 0.99 <= dropout_fraction / dropout_p <= 1.01
if is_sm80 or d <= 64:  # Only run backward for d > 64 on A100
assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()
# assert torch.allclose(dqkv, dqkv_ref, rtol=rtol, atol=atol)
@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('causal', [False, True])
@pytest.mark.parametrize('d', [128, 64, 80, 40, 32, 16])
# @pytest.mark.parametrize('d', [64])
@pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
# @pytest.mark.parametrize('seqlen', [128])
@pytest.mark.parametrize('dropout_p', [0.0, 0.17])
# @pytest.mark.parametrize('dropout_p', [0.0])
def test_flash_attn_race_condition(seqlen, d, dropout_p, causal, dtype):
if seqlen >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
pytest.skip() # Reference implementation OOM
device = 'cuda'
# set seed
torch.random.manual_seed(0)
batch_size = 32
nheads = 4
x = torch.randn(batch_size, seqlen, nheads * d, device=device, dtype=dtype, requires_grad=True)
Wqkv = torch.nn.Linear(nheads * d, 3 * nheads * d, device=device, dtype=dtype)
query_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='random')
key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='random')
(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, q, k, v,
output_pad_fn, dq_pad_fn, dk_pad_fn) = generate_qkv(
x, Wqkv, nheads, query_padding_mask, key_padding_mask
)
torch.random.manual_seed(0)
output_unpad_0, sm_lse_0, S_dmask_0 = flash_attn_unpadded_func(
q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, return_attn_probs=True, causal=causal
)
S_dmask_converted_0 = convert_flash_attn_S_to_softmax(
S_dmask_0, query_padding_mask, key_padding_mask, d, dropout_p > 0.0, causal=causal
)
if is_sm80 or d <= 64:  # Only run backward for d > 64 on A100
g = torch.randn_like(output_unpad_0)
dq_unpad_0, dk_unpad_0, dv_unpad_0, = torch.autograd.grad(output_unpad_0,
(q_unpad, k_unpad, v_unpad), g)
# Parallelizing over seqlen_k makes dq non-deterministic
deterministic_dq = False
# Numerical error if we just do any arithmetic on dq
dq_atol = ((dq_unpad_0 + 0.3 - 0.3) - dq_unpad_0).abs().max().item()
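# Rough sketch of why this tolerance is nonzero (illustrative, not part of the test): 0.3 is not
# exactly representable in fp16/bf16, so `x + 0.3 - 0.3` can round to a value that differs from
# `x` in the last bit; the largest such round-trip error over dq is a reasonable allclose
# tolerance for a dq that is numerically equivalent but not bitwise identical.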
equal_fn = torch.equal if deterministic_dq else partial(torch.allclose, atol=dq_atol)
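# Re-run the kernel with the same seed and inputs; a race condition would show up as outputs
# that differ from the first run (bitwise for the output, dK, dV; beyond dq_atol for dQ).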
for _ in range(10):
torch.random.manual_seed(0)
output_unpad, sm_lse, S_dmask = flash_attn_unpadded_func(
k = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype,
requires_grad=True)
v = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype,
requires_grad=True)
query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode='random')
key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode='random')
# key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode='full')
if kvpacked:
(q_unpad, kv_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, q, kv,
output_pad_fn, dq_pad_fn, dkv_pad_fn) = generate_qkv(
q, *kv.unbind(dim=2), query_padding_mask, key_padding_mask, kvpacked=True
)
out_unpad, sm_lse, S_dmask = flash_attn_varlen_kvpacked_func(
q_unpad, kv_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, return_attn_probs=True, causal=causal
)
else:
(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, q, k, v,
output_pad_fn, dq_pad_fn, dk_pad_fn) = generate_qkv(
q, k, v, query_padding_mask, key_padding_mask, kvpacked=False
)
out_unpad, sm_lse, S_dmask = flash_attn_varlen_func(
q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, return_attn_probs=True, causal=causal
)
out = output_pad_fn(out_unpad)
if dropout_p > 0.0:
S_dmask_converted = convert_flash_attn_S_to_softmax(
S_dmask, query_padding_mask, key_padding_mask, d, dropout_p > 0.0, causal=causal
)
assert torch.equal(output_unpad, output_unpad_0)
# sm_lse has some parts that are uninitialized from torch.empty
# assert torch.equal(sm_lse, sm_lse_0)
assert torch.equal(S_dmask_converted, S_dmask_converted_0)
if is_sm80 or d <= 64:  # Only run backward for d > 64 on A100
dq_unpad, dk_unpad, dv_unpad, = torch.autograd.grad(output_unpad,
(q_unpad, k_unpad, v_unpad), g)
assert equal_fn(dq_unpad, dq_unpad_0)
assert torch.equal(dk_unpad, dk_unpad_0)
assert torch.equal(dv_unpad, dv_unpad_0)
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason='requires multiple GPUs')
def test_flash_attn_multigpu():
seqlen = 256
d = 64
dropout_p = 0.0
causal = False
dtype = torch.float16
device = 'cuda:1'
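# Run the whole test on a non-default GPU to check that the kernels do not implicitly assume
# device 0.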
torch.random.manual_seed(0)
batch_size = 32
nheads = 4
x = torch.randn(batch_size, seqlen, nheads * d, device=device, dtype=dtype, requires_grad=True)
Wqkv = torch.nn.Linear(nheads * d, 3 * nheads * d, device=device, dtype=dtype)
key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='random')
# key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='full')
qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv(
x, Wqkv, nheads, key_padding_mask, key_padding_mask, qkvpacked=True
)
output_unpad, sm_lse, S_dmask = flash_attn_unpadded_qkvpacked_func(
qkv_unpad, cu_seqlens, max_seqlen, dropout_p, return_attn_probs=True, causal=causal
)
output = output_pad_fn(output_unpad)
S_dmask_converted = convert_flash_attn_S_to_softmax(
S_dmask, key_padding_mask, key_padding_mask, d, dropout_p > 0.0, causal=causal
)
dropout_mask = S_dmask_converted >= 0
attn_unnorm = S_dmask_converted.abs()
attn = normalize_flash_attn_S(attn_unnorm, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2],
key_padding_mask, key_padding_mask, dropout_p > 0.0, causal=causal)
dropout_fraction = get_dropout_fraction(dropout_mask, key_padding_mask, key_padding_mask,
causal=causal).item()
output_ref, attn_ref = attention_qkvpacked_ref(qkv, key_padding_mask, dropout_p, dropout_mask,
causal=causal)
output_pt, attn_pt = attention_qkvpacked_ref(qkv, key_padding_mask, dropout_p, dropout_mask,
causal=causal, upcast=False, reorder_ops=True)
print(f'Actual dropout fraction: {dropout_fraction}')
print(f'Output max diff: {(output - output_ref).abs().max().item()}')
print(f'Output mean diff: {(output - output_ref).abs().mean().item()}')
print(f'Pytorch max diff: {(output_pt - output_ref).abs().max().item()}')
print(f'Pytorch mean diff: {(output_pt - output_ref).abs().mean().item()}')
print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}')
print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}')
g = torch.randn_like(output)
dqkv_unpad, = torch.autograd.grad(output, qkv_unpad, g)
dqkv = dqkv_pad_fn(dqkv_unpad)
dqkv_ref, = torch.autograd.grad(output_ref, qkv, g)
dqkv_pt, = torch.autograd.grad(output_pt, qkv, g)
print(f'dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}')
print(f'dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}')
print(f'dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}')
print(f'dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}')
print(f'dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}')
print(f'dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}')
print(f'dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}')
print(f'dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}')
)[:, :, :seqlen_q, :seqlen_k]
dropout_mask = S_dmask_converted >= 0
attn_unnorm = S_dmask_converted.abs()
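# For MQA/GQA the returned probabilities are per query head, so replicate each KV head
# nheads // nheads_k times to rebuild full-sized K/V for the reference normalization.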
if kvpacked:
kv_rep = repeat(kv, "b s two h d -> b s two (h g) d", g=nheads // nheads_k)
k_rep, v_rep = kv_rep.unbind(dim=2)
else:
k_rep = repeat(k, "b s h d -> b s (h g) d", g=nheads // nheads_k)
v_rep = repeat(v, "b s h d -> b s (h g) d", g=nheads // nheads_k)
attn = normalize_flash_attn_S(attn_unnorm, q, k_rep, v_rep,
query_padding_mask, key_padding_mask,
dropout_p > 0.0, causal=causal)
dropout_fraction = get_dropout_fraction(dropout_mask, query_padding_mask,
key_padding_mask, causal=causal).item()
print(f'Actual dropout fraction: {dropout_fraction}')
else:
dropout_mask = None
if kvpacked:
out_ref, attn_ref = attention_kvpacked_ref(q, kv, query_padding_mask, key_padding_mask,
dropout_p, dropout_mask, causal=causal)
out_pt, attn_pt = attention_kvpacked_ref(q, kv, query_padding_mask, key_padding_mask,
dropout_p, dropout_mask,
causal=causal, upcast=False, reorder_ops=True)
else:
out_ref, attn_ref = attention_ref(q, k, v, query_padding_mask, key_padding_mask,
dropout_p, dropout_mask, causal=causal)
out_pt, attn_pt = attention_ref(q, k, v, query_padding_mask, key_padding_mask,
dropout_p, dropout_mask,
causal=causal, upcast=False, reorder_ops=True)
print(f'Output max diff: {(out - out_ref).abs().max().item()}')
print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
print(f'Pytorch max diff: {(out_pt - out_ref).abs().max().item()}')
print(f'Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}')
if dropout_p > 0.0:
print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}')
print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}')
g = torch.randn_like(out)
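# Only exercise the backward pass when this head dimension has a backward kernel on the current
# GPU; head dims above MAX_HEADDIM_SM8x appear to require sm80/sm90 (A100/H100).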
if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
if kvpacked:
dq_unpad, dkv_unpad, = torch.autograd.grad(out, (q_unpad, kv_unpad), g)
dk, dv = dkv_pad_fn(dkv_unpad).unbind(2)
dq_ref, dkv_ref, = torch.autograd.grad(out_ref, (q, kv), g)
dk_ref, dv_ref = dkv_ref.unbind(2)
dq_pt, dkv_pt, = torch.autograd.grad(out_pt, (q, kv), g)
dk_pt, dv_pt = dkv_pt.unbind(2)
else:
dq_unpad, dk_unpad, dv_unpad, = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g)
dk = dk_pad_fn(dk_unpad)
dv = dk_pad_fn(dv_unpad)
dq_ref, dk_ref, dv_ref, = torch.autograd.grad(out_ref, (q, k, v), g)
dq_pt, dk_pt, dv_pt, = torch.autograd.grad(out_pt, (q, k, v), g)
dq = dq_pad_fn(dq_unpad)
print(f'dQ max diff: {(dq - dq_ref).abs().max().item()}')
print(f'dK max diff: {(dk - dk_ref).abs().max().item()}')
print(f'dV max diff: {(dv - dv_ref).abs().max().item()}')
print(f'dQ mean diff: {(dq - dq_ref).abs().mean().item()}')
print(f'dK mean diff: {(dk - dk_ref).abs().mean().item()}')
print(f'dV mean diff: {(dv - dv_ref).abs().mean().item()}')
print(f'dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}')
print(f'dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}')
print(f'dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}')
print(f'dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}')
print(f'dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}')
print(f'dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}')
# Check that FlashAttention's numerical error is at most twice the numerical error
# of a Pytorch implementation.
assert (output - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item()
# assert torch.allclose(output, output_ref, rtol=rtol, atol=atol)
assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
# assert torch.allclose(attn, attn_ref, rtol=rtol, atol=atol)
if dropout_p == 0.0:
assert dropout_mask.all()
else:
assert 0.99 <= dropout_fraction / dropout_p <= 1.01
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()
if dropout_p > 0.0:
assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
assert abs(dropout_fraction - dropout_p) <= 0.01
if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
@pytest.mark.skipif(flash_attn_func is None, reason='Triton is not installed or is too old')
@pytest.mark.skipif(not is_sm80, reason='Triton version is only tested on A100')
@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
# @pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
@pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('causal', [False, True])
# @pytest.mark.parametrize('causal', [True])
@pytest.mark.parametrize('d', [40, 48, 64, 128, 80, 88, 96])
# @pytest.mark.parametrize('d', [48])
@pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (113, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (1023, 1024), (1024, 1023), (2048, 2048)])
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(1024, 1023)])
@pytest.mark.parametrize('bias_shape', ([None, '1h1k', '1hqk', 'b11k', 'b1qk']))
# @pytest.mark.parametrize('bias_shape', (['1hqk']))
def test_flash_attn_triton_output(seqlen_q, seqlen_k, d, causal, dtype, bias_shape):
if seqlen_q >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
# @pytest.mark.parametrize('d', [32, 56, 64, 80, 96, 128])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
@pytest.mark.parametrize('d', [64])
# @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
@pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048])
# @pytest.mark.parametrize('seqlen', [193])
# @pytest.mark.parametrize('dropout_p', [0.0, 0.17])
@pytest.mark.parametrize('dropout_p', [0.0])
def test_flash_attn_race_condition(seqlen, d, dropout_p, causal, dtype):
if seqlen >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
pytest.skip() # Reference implementation OOM
device = 'cuda'
# set seed
torch.random.manual_seed(0)
batch_size = 32
nheads = 4
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
k, v = torch.randn(batch_size, seqlen_k, 2, nheads, d, device=device, dtype=dtype).unbind(dim=2)
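# The bias_shape string encodes which dimensions are broadcast: e.g. '1h1k' is a per-head bias
# shared across the batch and all query positions, while 'b1qk' is a full (seqlen_q, seqlen_k)
# bias per batch element shared across heads.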
if bias_shape == '1h1k':
bias = torch.randn(1, nheads, 1, seqlen_k, dtype=torch.float, device=device)
elif bias_shape == '1hqk':
bias = torch.randn(1, nheads, seqlen_q, seqlen_k, dtype=torch.float, device=device)
elif bias_shape == 'b11k':
bias = torch.randn(batch_size, 1, 1, seqlen_k, dtype=torch.float, device=device)
elif bias_shape == 'b1qk':
bias = torch.randn(batch_size, 1, seqlen_q, seqlen_k, dtype=torch.float, device=device)
else:
bias = None
q, k, v = [x.detach().requires_grad_() for x in [q, k, v]]
output = flash_attn_func(q, k, v, bias, causal)
output_ref, attn_ref = attention_ref(q, k, v, bias=bias, causal=causal)
output_pt, attn_pt = attention_ref(q, k, v, bias=bias, causal=causal, upcast=False,
reorder_ops=True)
print(f'Output max diff: {(output - output_ref).abs().max().item()}')
print(f'Output mean diff: {(output - output_ref).abs().mean().item()}')
print(f'Pytorch max diff: {(output_pt - output_ref).abs().max().item()}')
print(f'Pytorch mean diff: {(output_pt - output_ref).abs().mean().item()}')
g = torch.randn_like(output)
dq, dk, dv = torch.autograd.grad(output, (q, k, v), g)
dq_ref, dk_ref, dv_ref, = torch.autograd.grad(output_ref, (q, k, v), g)
dq_pt, dk_pt, dv_pt, = torch.autograd.grad(output_pt, (q, k, v), g)
print(f'dQ max diff: {(dq - dq_ref).abs().max().item()}')
print(f'dK max diff: {(dk - dk_ref).abs().max().item()}')
print(f'dV max diff: {(dv - dv_ref).abs().max().item()}')
print(f'dQ mean diff: {(dq - dq_ref).abs().mean().item()}')
print(f'dK mean diff: {(dk - dk_ref).abs().mean().item()}')
print(f'dV mean diff: {(dv - dv_ref).abs().mean().item()}')
print(f'dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}')
print(f'dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}')
print(f'dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}')
print(f'dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}')
print(f'dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}')
print(f'dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}')
# Check that FlashAttention's numerical error is at most twice the numerical error
# of a Pytorch implementation.
assert (output - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item()
# assert torch.allclose(output, output_ref, rtol=rtol, atol=atol)
assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
qkv = torch.randn(batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True)
out0, lse0, _ = flash_attn_qkvpacked_func(
qkv, dropout_p, return_attn_probs=True, causal=causal
)
g = torch.randn_like(out0)
dqkv0, = torch.autograd.grad(out0, qkv, g)
for _ in range(200):
torch.random.manual_seed(0)
out, lse, S_dmask = flash_attn_qkvpacked_func(
qkv, dropout_p, return_attn_probs=True, causal=causal
)
assert torch.equal(out, out0)
assert torch.equal(lse, lse0)
# sm_lse has some parts that are uninitialized from torch.empty
# assert torch.equal(sm_lse, sm_lse_0)
@pytest.mark.skipif(flash_attn_func is None, reason='Triton is not installed or is too old')
@pytest.mark.skipif(not is_sm80, reason='Triton version is only tested on A100')
@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize('causal', [False, True])
# @pytest.mark.parametrize('causal', [True])
@pytest.mark.parametrize('d', [40, 48, 64, 128, 80, 88, 96])
# @pytest.mark.parametrize('d', [64])
@pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (91, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (1023, 1024), (1024, 1023), (2048, 2048)])
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203)])
@pytest.mark.parametrize('bias_shape', ([None, '1h1k', '1hqk', 'b11k', 'b1qk']))
# @pytest.mark.parametrize('bias_shape', (['b1qk']))
def test_flash_attn_triton_race_condition(seqlen_q, seqlen_k, d, causal, dtype, bias_shape):
if seqlen_q >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
pytest.skip() # Reference implementation OOM
device = 'cuda'
# set seed
torch.random.manual_seed(0)
batch_size = 32
nheads = 4
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
k, v = torch.randn(batch_size, seqlen_k, 2, nheads, d, device=device, dtype=dtype).unbind(dim=2)
if bias_shape == '1h1k':
bias = torch.randn(1, nheads, 1, seqlen_k, dtype=torch.float, device=device)
elif bias_shape == '1hqk':
bias = torch.randn(1, nheads, seqlen_q, seqlen_k, dtype=torch.float, device=device)
elif bias_shape == 'b11k':
bias = torch.randn(batch_size, 1, 1, seqlen_k, dtype=torch.float, device=device)
elif bias_shape == 'b1qk':
bias = torch.randn(batch_size, 1, seqlen_q, seqlen_k, dtype=torch.float, device=device)
else:
bias = None
q, k, v = [x.detach().requires_grad_() for x in [q, k, v]]
output_0 = flash_attn_func(q, k, v, bias, causal)
g = torch.randn_like(output_0)
dq_0, dk_0, dv_0 = torch.autograd.grad(output_0, (q, k, v), g)
# The SEQUENCE_PARALLEL option for the bwd makes dq non-deterministic
deterministic_dq = False
# Numerical error if we just do any arithmetic on dq
dq_atol = ((dq_0 + 0.3 - 0.3) - dq_0).abs().max().item()
equal_fn = torch.equal if deterministic_dq else partial(torch.allclose, atol=dq_atol)
# Run 10000 times and check that the results don't change
for i in range(10000):
output = flash_attn_func(q, k, v, bias, causal)
output_equal = torch.equal(output, output_0)
if not output_equal: # Printing / computing diff sometimes makes the race condition disappear
print(f'{dtype = }, {causal = }, {d = }, {seqlen_q = }, {seqlen_k = }, {bias_shape = }, {i = }')
print(f'Output max diff: {(output - output_0).abs().max().item()}')
assert torch.equal(output, output_0)
dq, dk, dv = torch.autograd.grad(output, (q, k, v), g)
dq_equal = equal_fn(dq, dq_0)
dk_equal = torch.equal(dk, dk_0)
dv_equal = torch.equal(dv, dv_0)
if not (dq_equal and dk_equal and dv_equal):
print(f'{dtype = }, {causal = }, {d = }, {seqlen_q = }, {seqlen_k = }, {bias_shape = }, {i = }')
print(f'dQ max diff: {(dq - dq_0).abs().max().item()}')
print(f'dK max diff: {(dk - dk_0).abs().max().item()}')
print(f'dV max diff: {(dv - dv_0).abs().max().item()}')
assert equal_fn(dq, dq_0)
assert torch.equal(dk, dk_0)
assert torch.equal(dv, dv_0)
if not (is_sm75 and d == 128):
dqkv, = torch.autograd.grad(out, qkv, g)
assert torch.equal(dqkv[:, :, 0], dqkv0[:, :, 0])
assert torch.equal(dqkv[:, :, 1], dqkv0[:, :, 1])
assert torch.equal(dqkv[:, :, 2], dqkv0[:, :, 2])
......@@ -85,11 +85,11 @@ RUN pip install transformers==4.25.1 datasets==2.8.0 pytorch-lightning==1.8.6 tr
RUN pip install git+https://github.com/mlcommons/logging.git@2.1.0
# Install FlashAttention
RUN pip install flash-attn==1.0.9
RUN pip install flash-attn==2.0.0.post1
# Install CUDA extensions for cross-entropy, fused dense, layer norm
RUN git clone https://github.com/HazyResearch/flash-attention \
&& cd flash-attention && git checkout v1.0.9 \
&& cd flash-attention && git checkout v2.0.0.post1 \
&& cd csrc/fused_softmax && pip install . && cd ../../ \
&& cd csrc/rotary && pip install . && cd ../../ \
&& cd csrc/xentropy && pip install . && cd ../../ \