Unverified commit dfe1a59e authored by Ying Zhang, committed by GitHub

Add var-seq-len to FA3 fp16 / bf16 fwd (#1072)



* fwd var-seq-len

* fixes

* benchmark

* fixes

---------
Co-authored-by: Tri Dao <tridao@users.noreply.github.com>
parent cb516f85
from functools import partial
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
try:
import cudnn
except ImportError:
cudnn = None
from einops import rearrange, repeat
# from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler
from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler
from flash_attn.flash_attn_interface import flash_attn_func
from flash_attn_interface import flash_attn_func as flash_attn_func_v3, flash_attn_varlen_func as flash_attn_varlen_func_v3
# Need to install triton nightly:
# pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly
try:
from triton_fused_attention import attention as triton_attention
except ImportError:
triton_attention = None
def flops(batch, nheads, seqlen_q, seqlen_k, headdim, causal=False, mode='fwd'):
assert mode in ["fwd", "bwd", "fwd_bwd"]
    f = 4 * batch * seqlen_q * seqlen_k * nheads * headdim // (2 if causal else 1)
return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)
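# Worked example: for batch=2, nheads=16, seqlen_q = seqlen_k = 8192, headdim=128, a non-causal
# forward pass does 4 * 2 * 8192 * 8192 * 16 * 128 ≈ 1.1e12 FLOPs; causal halves this (integer
# division above) and mode="bwd" scales the forward count by 2.5.
assert flops(2, 16, 8192, 8192, 128, causal=False, mode="fwd") == 4 * 2 * 8192 * 8192 * 16 * 128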
def convert_to_cudnn_type(torch_type):
if torch_type == torch.float16:
return cudnn.data_type.HALF
elif torch_type == torch.bfloat16:
return cudnn.data_type.BFLOAT16
elif torch_type == torch.float32:
return cudnn.data_type.FLOAT
elif torch_type == torch.int32:
return cudnn.data_type.INT32
elif torch_type == torch.int64:
return cudnn.data_type.INT64
else:
raise ValueError("Unsupported tensor data type.")
def cudnn_sdpa_setup(q, k, v, grad, causal=False):
b, nheads, seqlen_q, headdim = q.shape
_, _, seqlen_k, _ = k.shape
assert v.shape == (b, nheads, seqlen_k, headdim)
assert cudnn is not None, 'CUDNN is not available'
q_gpu, k_gpu, v_gpu = q, k, v
o_gpu = torch.empty_like(q_gpu)
stats_gpu = torch.empty(b, nheads, seqlen_q, 1, dtype=torch.float32, device=q.device)
graph_forward = cudnn.pygraph(
io_data_type=convert_to_cudnn_type(q.dtype),
intermediate_data_type=cudnn.data_type.FLOAT,
compute_data_type=cudnn.data_type.FLOAT,
)
q_forward = graph_forward.tensor_like(q_gpu.detach())
k_forward = graph_forward.tensor_like(k_gpu.detach())
v_forward = graph_forward.tensor_like(v_gpu.detach())
o_forward, stats_forward = graph_forward.sdpa(
name="sdpa",
q=q_forward,
k=k_forward,
v=v_forward,
is_inference=False,
attn_scale=1.0 / math.sqrt(headdim),
use_causal_mask=causal,
)
o_forward.set_output(True).set_dim(o_gpu.shape).set_stride(o_gpu.stride())
stats_forward.set_output(True).set_data_type(cudnn.data_type.FLOAT)
graph_forward.validate()
graph_forward.build_operation_graph()
graph_forward.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
graph_forward.check_support()
graph_forward.build_plans()
variant_pack_forward = {
q_forward: q_gpu,
k_forward: k_gpu,
v_forward: v_gpu,
o_forward: o_gpu,
stats_forward: stats_gpu,
}
dQ_gpu = torch.empty_like(q_gpu)
dK_gpu = torch.empty_like(k_gpu)
dV_gpu = torch.empty_like(v_gpu)
dO_gpu = grad
graph_backward = cudnn.pygraph(
io_data_type=cudnn.data_type.HALF,
intermediate_data_type=cudnn.data_type.FLOAT,
compute_data_type=cudnn.data_type.FLOAT,
)
q_backward = graph_backward.tensor_like(q_gpu.detach())
k_backward = graph_backward.tensor_like(k_gpu.detach())
v_backward = graph_backward.tensor_like(v_gpu.detach())
o_backward = graph_backward.tensor_like(o_gpu.detach())
dO_backward = graph_backward.tensor_like(dO_gpu.detach())
stats_backward = graph_backward.tensor_like(stats_gpu.detach())
dQ_backward, dK_backward, dV_backward = graph_backward.sdpa_backward(
name="sdpa_backward",
q=q_backward,
k=k_backward,
v=v_backward,
o=o_backward,
dO=dO_backward,
stats=stats_backward,
attn_scale=1.0 / math.sqrt(headdim),
use_causal_mask=causal,
)
dQ_backward.set_output(True).set_dim(dQ_gpu.size()).set_stride(dQ_gpu.stride())
dK_backward.set_output(True).set_dim(dK_gpu.size()).set_stride(dK_gpu.stride())
dV_backward.set_output(True).set_dim(dV_gpu.size()).set_stride(dV_gpu.stride())
graph_backward.validate()
graph_backward.build_operation_graph()
graph_backward.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
graph_backward.check_support()
graph_backward.build_plans()
variant_pack_backward = {
q_backward: q_gpu,
k_backward: k_gpu,
v_backward: v_gpu,
o_backward: o_gpu,
dO_backward: dO_gpu,
stats_backward: stats_gpu,
dQ_backward: dQ_gpu,
dK_backward: dK_gpu,
dV_backward: dV_gpu,
}
workspace = torch.empty(
max(graph_forward.get_workspace_size(), graph_backward.get_workspace_size()),
device="cuda", dtype=torch.uint8
)
def run_fwd(*args, **kwargs):
graph_forward.execute(variant_pack_forward, workspace)
return o_gpu, stats_gpu
def run_bwd(*args, **kwargs):
graph_backward.execute(variant_pack_backward, workspace)
return dQ_gpu, dK_gpu, dV_gpu
return run_fwd, run_bwd
torch.manual_seed(0)
repeats = 100
dropout_p = 0.0
causal = False
dtype = torch.float16
device = 'cuda'
verbose = False
batch_size = 2
# seqlen = 2048
seqlen = 8192
# seqlen = 4096
# seqlen = 2047
dim = 2048
# headdim = 128
# headdim = 64
headdim = 256
# for mode in ['fwd', 'bwd']:
for mode in ['fwd']:
for headdim in [64, 128, 256]:
# for headdim in [128]:
for seqlen in [1024, 2048, 4096, 8192, 16384, 32768]:
# for seqlen in [8192]:
nheads = dim // headdim
# nheads = 24
# headdim = 64
# batch_size = 64
# seqlen = 512
# nheads = 8
# headdim = 128
nheads_kv = nheads
qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
requires_grad=True)
q = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
k = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
v = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
q_t = q.transpose(1, 2).contiguous().detach().requires_grad_()
k_t = k.transpose(1, 2).contiguous().detach().requires_grad_()
            v_t = v.transpose(1, 2).contiguous().detach().requires_grad_()
grad = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype)
grad_t = grad.transpose(1, 2).contiguous()
bench_fn = benchmark_forward if mode == 'fwd' else partial(benchmark_backward, grad=grad)
for causal in [False, True]:
# for causal in [True]:
print(f"\n### {headdim = }, {seqlen = }, {causal = } ###")
if headdim <= 128 and cudnn is not None:
cudnn_sdpa_fwd, cudnn_sdpa_bwd = cudnn_sdpa_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), grad.transpose(1, 2), causal=causal)
f = flops(batch_size, nheads, seqlen, seqlen, headdim, causal=causal, mode=mode)
_, m0 = bench_fn(flash_attn_func, q, k, v, dropout_p, causal=causal, repeats=repeats, verbose=verbose, desc='Fav2')
if mode == 'bwd':
ref_dv, v.grad = v.grad.clone(), None
ref_dk, k.grad = k.grad.clone(), None
ref_dq, q.grad = q.grad.clone(), None
# pytorch_profiler(flash_attn_func, q, k, v, dropout_p, causal=causal, backward=False)
if headdim <= 128:
if triton_attention is not None:
if mode == 'fwd':
time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark
_, m3 = benchmark_forward(triton_attention, q_t, k_t, v_t, causal, 1 / math.sqrt(headdim), repeats=repeats, verbose=verbose, desc='Triton')
# TODO: fix Triton numeric errors.
# if mode == 'bwd':
# dv, v_t.grad = v_t.grad.clone(), None
# dk, k_t.grad = k_t.grad.clone(), None
# dq, q_t.grad = q_t.grad.clone(), None
# torch.testing.assert_close(ref_dv, dv.transpose(1, 2), atol=0.05, rtol=0.05)
# torch.testing.assert_close(ref_dk, dk.transpose(1, 2), atol=0.05, rtol=0.05)
# torch.testing.assert_close(ref_dq, dq.transpose(1, 2), atol=0.05, rtol=0.05)
if cudnn is not None:
time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark
if mode == 'fwd':
_, m2 = benchmark_forward(cudnn_sdpa_fwd, repeats=repeats, verbose=verbose, desc='CuDNN')
else:
cudnn_sdpa_fwd()
_, m2 = benchmark_forward(cudnn_sdpa_bwd, repeats=repeats, verbose=verbose, desc='CuDNN')
dq, dk, dv = cudnn_sdpa_bwd()
torch.testing.assert_close(ref_dv, dv.transpose(1, 2), atol=0.05, rtol=0.05)
torch.testing.assert_close(ref_dk, dk.transpose(1, 2), atol=0.05, rtol=0.05)
torch.testing.assert_close(ref_dq, dq.transpose(1, 2), atol=0.05, rtol=0.05)
# pytorch_profiler(cudnn_sdpa, backward=False)
if headdim == 128 or mode == 'fwd':
time.sleep(1)
_, m1 = bench_fn(flash_attn_func_v3, q, k, v, causal=causal, repeats=repeats, verbose=verbose, desc='Fav3')
q_var = q.reshape(-1, q.shape[-2], q.shape[-1])
k_var = k.reshape(-1, k.shape[-2], k.shape[-1])
v_var = v.reshape(-1, v.shape[-2], v.shape[-1])
lens = torch.full([q.shape[0]], seqlen, dtype=torch.int32)
cu_seqlens = torch.cat([torch.tensor([0], dtype=torch.int32), torch.cumsum(lens, dim=0, dtype=torch.int32)]).cuda()
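                    # Equal-length packing: cu_seqlens is [0, seqlen, 2*seqlen, ..., batch_size*seqlen],
                    # so the varlen kernel processes the same total work as the fixed-length Fav3 call
                    # above and the two timings are directly comparable.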
time.sleep(1)
_, m1_var = bench_fn(flash_attn_varlen_func_v3, q_var, k_var, v_var, cu_seqlens, cu_seqlens, seqlen, seqlen, causal=causal, repeats=repeats, verbose=verbose, desc='Fav3 var len')
if mode == 'bwd':
dv, v.grad = v.grad.clone(), None
dk, k.grad = k.grad.clone(), None
dq, q.grad = q.grad.clone(), None
torch.testing.assert_close(ref_dv, dv, atol=0.05, rtol=0.05)
torch.testing.assert_close(ref_dk, dk, atol=0.05, rtol=0.05)
torch.testing.assert_close(ref_dq, dq, atol=0.05, rtol=0.05)
# pytorch_profiler(flash_attn_func_v3, q, k, v, causal=causal, backward=False)
print(f'Fav2: {m0.mean * 1e3:.3f}ms, {(f / m0.mean * 1e-12):.1f} TFLOPS')
if headdim <= 128:
if triton_attention is not None:
print(f'Triton: {m3.mean * 1e3:.3f}ms, {(f / m3.mean * 1e-12):.1f} TFLOPS')
if cudnn is not None:
print(f'CuDNN: {m2.mean * 1e3:.3f}ms, {(f / m2.mean * 1e-12):.1f} TFLOPS')
if headdim == 128 or mode == 'fwd':
print(f'Fav3: {m1.mean * 1e3:.3f}ms, {(f / m1.mean * 1e-12):.1f} TFLOPS')
print(f'Fav3 varlen: {m1_var.mean * 1e3:.3f}ms, {(f / m1_var.mean * 1e-12):.1f} TFLOPS')
\ No newline at end of file
...@@ -17,20 +17,15 @@ namespace flash { ...@@ -17,20 +17,15 @@ namespace flash {
using namespace cute; using namespace cute;
// template <int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename Element_> // template <int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename Element_>
template <typename Ktraits> template <typename Ktraits, typename Seqlen_traits>
struct CollectiveEpilogueFwd { struct CollectiveEpilogueFwd {
using Element = typename Ktraits::Element; using Element = typename Ktraits::Element;
static constexpr int kBlockM = Ktraits::kBlockM; static constexpr int kBlockM = Ktraits::kBlockM;
static constexpr int kBlockN = Ktraits::kBlockN; static constexpr int kBlockN = Ktraits::kBlockN;
static constexpr int kHeadDim = Ktraits::kHeadDim; static constexpr int kHeadDim = Ktraits::kHeadDim;
// using Element = Element_;
// static constexpr int kBlockM = kBlockM_;
// static constexpr int kBlockN = kBlockN_;
// static constexpr int kHeadDim = kHeadDim_;
using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>; using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
// static constexpr int kNWarps = kNWarps_;
static constexpr int kNWarps = Ktraits::kNWarps; static constexpr int kNWarps = Ktraits::kNWarps;
static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp; static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
static constexpr bool Is_WS = kNWarps >= 12; static constexpr bool Is_WS = kNWarps >= 12;
...@@ -38,20 +33,6 @@ struct CollectiveEpilogueFwd { ...@@ -38,20 +33,6 @@ struct CollectiveEpilogueFwd {
static constexpr int NumCopyThreads = !Is_WS ? 0 : cutlass::NumThreadsPerWarpGroup; static constexpr int NumCopyThreads = !Is_WS ? 0 : cutlass::NumThreadsPerWarpGroup;
static constexpr int NumMmaThreads = kNThreads - NumCopyThreads; static constexpr int NumMmaThreads = kNThreads - NumCopyThreads;
using GmemTiledCopyOTMA = cute::SM90_TMA_STORE;
// These are for storing the output tensor without TMA (e.g., for setting output to zero)
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
static constexpr int kGmemThreadsPerRow = kHeadDim / kGmemElemsPerLoad;
static_assert(NumMmaThreads % kGmemThreadsPerRow == 0, "NumMmaThreads must be a multiple of kGmemThreadsPerRow");
using GmemLayoutAtom = Layout<Shape <Int<NumMmaThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
Stride<Int<kGmemThreadsPerRow>, _1>>;
using GmemTiledCopyO = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
GmemLayoutAtom{},
Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{})); // Val layout, 8 or 16 vals per store
using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element, using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{}))); using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{})));
...@@ -59,52 +40,72 @@ struct CollectiveEpilogueFwd { ...@@ -59,52 +40,72 @@ struct CollectiveEpilogueFwd {
using SmemCopyAtomO = Copy_Atom<cute::SM90_U32x4_STSM_N, Element>; using SmemCopyAtomO = Copy_Atom<cute::SM90_U32x4_STSM_N, Element>;
using SharedStorage = cute::array_aligned<Element, cute::cosize_v<SmemLayoutO>>; using SharedStorage = cute::array_aligned<Element, cute::cosize_v<SmemLayoutO>>;
using ShapeO = cute::Shape<int32_t, int32_t, int32_t, int32_t>; // (seqlen_q, d, head, batch) using GmemTiledCopyOTMA = cute::SM90_TMA_STORE;
using StrideO = cute::Stride<int64_t, _1, int64_t, int64_t>;
using StrideLSE = cute::Stride<_1, int64_t, int64_t>; // (seqlen_q, head, batch)
using TMA_O = decltype(make_tma_copy( using TMA_O = decltype(make_tma_copy(
GmemTiledCopyOTMA{}, GmemTiledCopyOTMA{},
make_tensor(make_gmem_ptr(static_cast<Element*>(nullptr)), repeat_like(StrideO{}, int32_t(0)), StrideO{}), make_tensor(
make_gmem_ptr(static_cast<Element*>(nullptr)),
typename Seqlen_traits::ShapeT{},
typename Seqlen_traits::StrideT{}
),
SmemLayoutO{}, SmemLayoutO{},
select<0, 2>(TileShape_MNK{}), select<0, 2>(TileShape_MNK{}),
_1{})); // no mcast for O _1{})); // no mcast for O
// These are for storing the output tensor without TMA (e.g., for setting output to zero and var-seq-len)
static constexpr int kNumVecElem = ceil_div(128, sizeof_bits_v<Element>);
static_assert(kHeadDim % kNumVecElem == 0);
static constexpr int kNumThreadsPerRow = kHeadDim / kNumVecElem;
static_assert(NumMmaThreads % kNumThreadsPerRow == 0);
static constexpr int kNumRows = NumMmaThreads / kNumThreadsPerRow;
using TiledCopyOAtom = cute::Copy_Atom<cute::UniversalCopy<cutlass::uint128_t>, Element>;
using TiledCopyOThrLayout = decltype(cute::make_layout(
cute::make_shape(Int<kNumRows>{}, Int<kNumThreadsPerRow>{}),
LayoutRight{}));
using TiledCopyOValLayout = decltype(cute::make_layout(
cute::make_shape(_1{}, Int<kNumVecElem>{}),
LayoutRight{}));
using TiledCopyO = decltype(make_tiled_copy(
TiledCopyOAtom{},
TiledCopyOThrLayout{}, // Thr layout
TiledCopyOValLayout{} // Val layout
));
// Host side kernel arguments // Host side kernel arguments
struct Arguments { struct Arguments {
Element* ptr_O; Element* ptr_O;
ShapeO const shape_O; typename Seqlen_traits::LayoutT const layout_O;
StrideO const stride_O;
float* ptr_LSE; float* ptr_LSE;
StrideLSE const stride_LSE; typename Seqlen_traits::LayoutLseT const layout_LSE;
}; };
// Device side kernel params // Device side kernel params
struct Params { struct Params {
Element* ptr_O; Element* ptr_O;
ShapeO const shape_O; typename Seqlen_traits::LayoutT const layout_O;
StrideO const stride_O;
float* ptr_LSE; float* ptr_LSE;
StrideLSE const stride_LSE; typename Seqlen_traits::LayoutLseT const layout_LSE;
TMA_O tma_store_O; TMA_O tma_store_O;
}; };
static Params static Params
to_underlying_arguments(Arguments const& args) { to_underlying_arguments(Arguments const& args) {
Tensor mO = make_tensor(make_gmem_ptr(args.ptr_O), args.shape_O, args.stride_O); Tensor mO = make_tensor(make_gmem_ptr(args.ptr_O), args.layout_O);
TMA_O tma_store_O = make_tma_copy( TMA_O tma_store_O = make_tma_copy(
GmemTiledCopyOTMA{}, GmemTiledCopyOTMA{},
mO, mO,
SmemLayoutO{}, SmemLayoutO{},
select<0, 2>(TileShape_MNK{}), select<0, 2>(TileShape_MNK{}),
_1{}); // no mcast for O _1{}); // no mcast for O
return {args.ptr_O, args.shape_O, args.stride_O, args.ptr_LSE, args.stride_LSE, tma_store_O}; return {args.ptr_O, args.layout_O, args.ptr_LSE, args.layout_LSE, tma_store_O};
} }
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
CUTLASS_DEVICE CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const& epilogue_params) { static void prefetch_tma_descriptors(Params const& epilogue_params) {
cute::prefetch_tma_descriptor(epilogue_params.tma_store_O.get_tma_descriptor()); if constexpr (!Seqlen_traits::kUseVarSeqLen) {
cute::prefetch_tma_descriptor(epilogue_params.tma_store_O.get_tma_descriptor());
}
} }
template <typename SharedStorage, typename FrgTensorO, typename FrgTensorLSE, typename TiledMma> template <typename SharedStorage, typename FrgTensorO, typename FrgTensorLSE, typename TiledMma>
...@@ -115,7 +116,8 @@ struct CollectiveEpilogueFwd { ...@@ -115,7 +116,8 @@ struct CollectiveEpilogueFwd {
SharedStorage& shared_storage, SharedStorage& shared_storage,
TiledMma tiled_mma, TiledMma tiled_mma,
int thread_idx, int thread_idx,
cute::tuple<int32_t, int32_t, int32_t> const& block_coord cute::tuple<int32_t, int32_t, int32_t> const& block_coord,
const Seqlen_traits& seqlen_traits_q
) { ) {
auto [m_block, bidh, bidb] = block_coord; auto [m_block, bidh, bidb] = block_coord;
...@@ -134,16 +136,9 @@ struct CollectiveEpilogueFwd { ...@@ -134,16 +136,9 @@ struct CollectiveEpilogueFwd {
cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp,
cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
Tensor mO = epilogue_params.tma_store_O.get_tma_tensor(epilogue_params.shape_O); Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_params.ptr_LSE), epilogue_params.layout_LSE);
Tensor gO = local_tile(mO(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) Tensor gLSE = seqlen_traits_q.get_lse_local_tile_tensor(
auto block_tma_O = epilogue_params.tma_store_O.get_slice(_0{}); mLSE, Shape<Int<kBlockM>>{}, bidh, bidb)(_, m_block);
Tensor tOgO = block_tma_O.partition_D(gO); // (TMA, TMA_M, TMA_K)
Tensor tOsO = block_tma_O.partition_S(sO); // (TMA, TMA_M, TMA_K)
auto shape_LSE = select<0, 2, 3>(epilogue_params.shape_O);
Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_params.ptr_LSE), shape_LSE, epilogue_params.stride_LSE);
Tensor gLSE = local_tile(mLSE(_, bidh, bidb), Shape<Int<kBlockM>>{}, make_coord(m_block));
Tensor caccO = cute::make_identity_tensor(select<0, 2>(TileShape_MNK{})); Tensor caccO = cute::make_identity_tensor(select<0, 2>(TileShape_MNK{}));
auto thread_mma = tiled_mma.get_thread_slice(thread_idx); auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
Tensor taccOcO = thread_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) Tensor taccOcO = thread_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K)
...@@ -156,19 +151,23 @@ struct CollectiveEpilogueFwd { ...@@ -156,19 +151,23 @@ struct CollectiveEpilogueFwd {
#pragma unroll #pragma unroll
for (int mi = 0; mi < size(lse); ++mi) { for (int mi = 0; mi < size(lse); ++mi) {
const int row = get<0>(taccOcO_row(mi)); const int row = get<0>(taccOcO_row(mi));
if (row < get<0>(shape_LSE) - m_block * kBlockM) { gLSE(row) = lse(mi); } if (row < seqlen_traits_q.actual_seq_len - m_block * kBlockM) { gLSE(row) = lse(mi); }
} }
} }
if (cutlass::canonical_warp_idx_sync() == kNWarps - 1) { int write_warp_idx = kNWarps - 1;
cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, if (cutlass::canonical_warp_idx_sync() == write_warp_idx) {
cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); cutlass::arch::NamedBarrier::sync(
int const lane_predicate = cute::elect_one_sync(); NumMmaThreads + cutlass::NumThreadsPerWarp,
if (lane_predicate) { cutlass::arch::ReservedNamedBarriers::EpilogueBarrier
cute::copy(epilogue_params.tma_store_O, tOsO, tOgO); );
tma_store_arrive();
}
} }
TiledCopyO gmem_tiled_copy_O;
flash::write_O<!Seqlen_traits::kUseVarSeqLen, NumCopyThreads>(
epilogue_params.ptr_O, epilogue_params.tma_store_O, gmem_tiled_copy_O,
epilogue_params.layout_O, select<0, 2>(TileShape_MNK{}), sO,
m_block, bidh, bidb, seqlen_traits_q, write_warp_idx
);
} }
CUTLASS_DEVICE void CUTLASS_DEVICE void
...@@ -177,20 +176,25 @@ struct CollectiveEpilogueFwd { ...@@ -177,20 +176,25 @@ struct CollectiveEpilogueFwd {
} }
// Write 0 to output and -inf to LSE // Write 0 to output and -inf to LSE
template<typename SharedStorage>
CUTLASS_DEVICE void CUTLASS_DEVICE void
store_zero( store_zero(
Params const& epilogue_params, Params const& epilogue_params,
int thread_idx, SharedStorage& shared_storage,
cute::tuple<int32_t, int32_t, int32_t> const& block_coord int thread_idx,
) { cute::tuple<int32_t, int32_t, int32_t> const& block_coord,
const Seqlen_traits& seqlen_traits_q
) {
auto [m_block, bidh, bidb] = block_coord; auto [m_block, bidh, bidb] = block_coord;
Tensor mO = make_tensor(make_gmem_ptr(epilogue_params.ptr_O), epilogue_params.shape_O, epilogue_params.stride_O); Tensor mO = make_tensor(make_gmem_ptr(epilogue_params.ptr_O), epilogue_params.layout_O);
Tensor gO = local_tile(mO(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) Tensor gO = seqlen_traits_q.get_local_tile_tensor(
auto shape_LSE = select<0, 2, 3>(epilogue_params.shape_O); mO, select<0, 2>(TileShape_MNK{}), bidh, bidb
Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_params.ptr_LSE), shape_LSE, epilogue_params.stride_LSE); )(_, _, m_block); // (M, K)
Tensor gLSE = local_tile(mLSE(_, bidh, bidb), Shape<Int<kBlockM>>{}, make_coord(m_block)); Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_params.ptr_LSE), epilogue_params.layout_LSE);
Tensor gLSE = seqlen_traits_q.get_lse_local_tile_tensor(
GmemTiledCopyO gmem_tiled_copy_O; mLSE, Shape<Int<kBlockM>>{}, bidh, bidb)(_, m_block);
TiledCopyO gmem_tiled_copy_O;
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
Tensor tOgO = gmem_thr_copy_O.partition_D(gO); Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
Tensor tOrO = make_fragment_like(tOgO); Tensor tOrO = make_fragment_like(tOgO);
...@@ -201,13 +205,13 @@ struct CollectiveEpilogueFwd { ...@@ -201,13 +205,13 @@ struct CollectiveEpilogueFwd {
Tensor tOcO = gmem_thr_copy_O.partition_D(cO); Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO))); Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
#pragma unroll #pragma unroll
for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(epilogue_params.shape_O); } for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(epilogue_params.layout_O.shape()); }
// Clear_OOB_K must be false since we don't want to write zeros to gmem // Clear_OOB_K must be false since we don't want to write zeros to gmem
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>( flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, get<0>(epilogue_params.shape_O) - m_block * kBlockM gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, get<0>(epilogue_params.layout_O.shape()) - m_block * kBlockM
); );
static_assert(kBlockM <= NumMmaThreads); static_assert(kBlockM <= NumMmaThreads);
if (thread_idx < get<0>(shape_LSE) - m_block * kBlockM) { gLSE(thread_idx) = INFINITY; } if (thread_idx < get<0>(epilogue_params.layout_LSE.shape()) - m_block * kBlockM) { gLSE(thread_idx) = INFINITY; }
} }
}; };
......
...@@ -57,7 +57,7 @@ struct Flash_fwd_params : public Qkv_params { ...@@ -57,7 +57,7 @@ struct Flash_fwd_params : public Qkv_params {
void * __restrict__ softmax_lseaccum_ptr; void * __restrict__ softmax_lseaccum_ptr;
// The dimensions. // The dimensions.
int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim; int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim, total_q, total_k;
// The scaling factors for the kernel. // The scaling factors for the kernel.
float scale_softmax; float scale_softmax;
...@@ -128,6 +128,8 @@ struct Flash_fwd_params : public Qkv_params { ...@@ -128,6 +128,8 @@ struct Flash_fwd_params : public Qkv_params {
void * __restrict__ alibi_slopes_ptr; void * __restrict__ alibi_slopes_ptr;
index_t alibi_slopes_batch_stride; index_t alibi_slopes_batch_stride;
bool unpadded_lse; // For varlen paths: LSE is in [nheads, total_seqlen_q] format instead of [b, nheads, seqlen_q].
int * __restrict__ tile_count_semaphore; int * __restrict__ tile_count_semaphore;
}; };
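With unpadded_lse set on the varlen path, softmax_lse is laid out as (nheads, total_seqlen_q) with all sequences packed back-to-back. A minimal sketch (illustrative Python; softmax_lse and cu_seqlens_q are assumed to come from the varlen forward) of recovering the LSE of sequence i:

# Illustrative only: slice the unpadded (nheads, total_q) LSE for one sequence.
lse_i = softmax_lse[:, cu_seqlens_q[i] : cu_seqlens_q[i + 1]]   # (nheads, seqlen_q of sequence i)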
......
...@@ -43,7 +43,8 @@ void set_params_fprop(Flash_fwd_params &params, ...@@ -43,7 +43,8 @@ void set_params_fprop(Flash_fwd_params &params,
float softmax_scale, float softmax_scale,
int window_size_left, int window_size_left,
int window_size_right, int window_size_right,
bool seqlenq_ngroups_swapped=false) { bool seqlenq_ngroups_swapped=false,
bool unpadded_lse=false) {
// Reset the parameters // Reset the parameters
params = {}; params = {};
...@@ -81,6 +82,11 @@ void set_params_fprop(Flash_fwd_params &params, ...@@ -81,6 +82,11 @@ void set_params_fprop(Flash_fwd_params &params,
params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d); params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);
params.seqused_k = static_cast<int *>(seqused_k); params.seqused_k = static_cast<int *>(seqused_k);
TORCH_CHECK(
bool(params.cu_seqlens_q) == bool(params.cu_seqlens_k),
"cu_seqlens_q and cu_seqlens_k must be both null or non-null"
);
// P = softmax(QK^T) // P = softmax(QK^T)
params.p_ptr = p_d; params.p_ptr = p_d;
...@@ -139,6 +145,8 @@ void set_params_fprop(Flash_fwd_params &params, ...@@ -139,6 +145,8 @@ void set_params_fprop(Flash_fwd_params &params,
#ifdef FLASHATTENTION_DISABLE_UNEVEN_K #ifdef FLASHATTENTION_DISABLE_UNEVEN_K
TORCH_CHECK(d == d_rounded, "This flash attention build does not support headdim not being a multiple of 32."); TORCH_CHECK(d == d_rounded, "This flash attention build does not support headdim not being a multiple of 32.");
#endif #endif
params.unpadded_lse = unpadded_lse;
} }
void set_params_dgrad(Flash_bwd_params &params, void set_params_dgrad(Flash_bwd_params &params,
...@@ -372,6 +380,154 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size ...@@ -372,6 +380,154 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p}; return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p};
} }
std::vector<at::Tensor>
mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
               c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens_q, // b+1
const at::Tensor &cu_seqlens_k, // b+1
c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
int max_seqlen_q,
const int max_seqlen_k,
const float softmax_scale,
bool is_causal) {
auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
TORCH_CHECK(is_sm90, "FlashAttention only supports Hopper GPUs or newer.");
auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
"FlashAttention only support fp16 and bf16 data type");
TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");
CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
CHECK_DEVICE(cu_seqlens_q);
CHECK_DEVICE(cu_seqlens_k);
TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
CHECK_CONTIGUOUS(cu_seqlens_q);
CHECK_CONTIGUOUS(cu_seqlens_k);
const auto sizes = q.sizes();
const int batch_size = cu_seqlens_q.numel() - 1;
int num_heads = sizes[1];
const int head_size_og = sizes[2];
const int num_heads_k = k.size(1);
int window_size_left = -1;
int window_size_right = -1;
if (is_causal) { window_size_right = 0; }
void *cu_seqlens_q_d = cu_seqlens_q.data_ptr();
const int total_q = q.sizes()[0];
TORCH_CHECK(batch_size > 0, "batch size must be positive");
TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
CHECK_SHAPE(q, total_q, num_heads, head_size_og);
const int total_k = k.size(0);
CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
if (seqused_k.has_value()){
auto seqused_k_ = seqused_k.value();
TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32");
TORCH_CHECK(seqused_k_.is_cuda(), "seqused_k must be on CUDA device");
TORCH_CHECK(seqused_k_.is_contiguous(), "seqused_k must be contiguous");
CHECK_SHAPE(seqused_k_, batch_size);
}
at::Tensor q_padded, k_padded, v_padded;
if (head_size_og % 8 != 0) {
q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
} else {
q_padded = q;
k_padded = k;
v_padded = v;
}
at::Tensor out;
if (out_.has_value()) {
out = out_.value();
TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
CHECK_DEVICE(out);
TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
CHECK_SHAPE(out, sizes[0], sizes[1], head_size_og);
if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
} else {
out = torch::empty_like(q_padded);
}
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
const int head_size = round_multiple(head_size_og, 8);
const int head_size_rounded = round_multiple(head_size, 32);
const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
auto opts = q.options();
auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat));
Flash_fwd_params params;
set_params_fprop(params,
batch_size,
max_seqlen_q, max_seqlen_k,
seqlen_q_rounded, seqlen_k_rounded,
num_heads, num_heads_k,
head_size, head_size_rounded,
q_padded, k_padded, v_padded, out,
cu_seqlens_q_d,
cu_seqlens_k.data_ptr(),
seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr,
/*p_d=*/nullptr,
softmax_lse.data_ptr(),
/*p_dropout=*/0.f,
softmax_scale,
window_size_left,
window_size_right,
/*seqlenq_ngroups_swapped=*/false,
/*unpadded_lse=*/true);
params.total_q = total_q;
params.total_k = total_k;
if (max_seqlen_k > 0) {
auto stream = at::cuda::getCurrentCUDAStream().stream();
run_mha_fwd(params, stream);
} else {
// If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
out.zero_();
softmax_lse.fill_(std::numeric_limits<float>::infinity());
}
at::Tensor out_padded = out;
if (head_size_og % 8 != 0) {
out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
if (out_.has_value()) { out_.value().copy_(out); }
}
return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse};
}
void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) { void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
// FP16_SWITCH(!params.is_bf16, [&] { // FP16_SWITCH(!params.is_bf16, [&] {
// HEADDIM_SWITCH(params.d, [&] { // HEADDIM_SWITCH(params.d, [&] {
...@@ -577,4 +733,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -577,4 +733,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "FlashAttention"; m.doc() = "FlashAttention";
m.def("fwd", &mha_fwd, "Forward pass"); m.def("fwd", &mha_fwd, "Forward pass");
m.def("bwd", &mha_bwd, "Backward pass"); m.def("bwd", &mha_bwd, "Backward pass");
m.def("varlen_fwd", &mha_varlen_fwd, "Forward pass (variable length)");
} }
...@@ -57,6 +57,83 @@ def _flash_attn_backward( ...@@ -57,6 +57,83 @@ def _flash_attn_backward(
) )
return dq, dk, dv, softmax_d return dq, dk, dv, softmax_d
def _flash_attn_varlen_forward(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
softmax_scale,
causal,
):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
out, q, k, v, out_padded, softmax_lse = flashattn_hopper_cuda.varlen_fwd(
q,
k,
v,
None,
cu_seqlens_q,
cu_seqlens_k,
None,
max_seqlen_q,
max_seqlen_k,
softmax_scale,
causal,
)
# if out.isnan().any() or softmax_lse.isnan().any():
# breakpoint()
return out, q, k, v, out_padded, softmax_lse
def _flash_attn_varlen_backward(
dout,
q,
k,
v,
out,
softmax_lse,
dq,
dk,
dv,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
softmax_scale,
causal,
):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
# dq, dk, dv are allocated by us so they should already be contiguous
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
(
dq,
dk,
dv,
softmax_d,
) = _get_fa_module().varlen_bwd(
dout,
q,
k,
v,
out,
softmax_lse,
dq,
dk,
dv,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
softmax_scale,
causal,
)
# if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
# breakpoint()
return dq, dk, dv, softmax_d
class FlashAttnFunc(torch.autograd.Function): class FlashAttnFunc(torch.autograd.Function):
@staticmethod @staticmethod
...@@ -105,6 +182,71 @@ class FlashAttnFunc(torch.autograd.Function): ...@@ -105,6 +182,71 @@ class FlashAttnFunc(torch.autograd.Function):
return dq, dk, dv, None, None return dq, dk, dv, None, None
class FlashAttnVarlenFunc(torch.autograd.Function):
@staticmethod
def forward(
ctx,
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
softmax_scale,
causal,
):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
softmax_scale,
causal=causal,
)
ctx.save_for_backward(
q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k
)
ctx.max_seqlen_q = max_seqlen_q
ctx.max_seqlen_k = max_seqlen_k
ctx.softmax_scale = softmax_scale
ctx.causal = causal
return out, softmax_lse
@staticmethod
def backward(ctx, dout, *args):
# TODO: Uncomment these when var-seq-len is supported in bwd kernel.
# q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors
# dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
# _flash_attn_varlen_backward(
# dout,
# q,
# k,
# v,
# out,
# softmax_lse,
# dq,
# dk,
# dv,
# cu_seqlens_q,
# cu_seqlens_k,
# ctx.max_seqlen_q,
# ctx.max_seqlen_k,
# ctx.softmax_scale,
# ctx.causal,
# )
# dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
# dk = dk[..., : dout.shape[-1]]
# dv = dv[..., : dout.shape[-1]]
# return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None
return None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
def flash_attn_func( def flash_attn_func(
q, q,
k, k,
...@@ -167,3 +309,62 @@ def flash_attn_func( ...@@ -167,3 +309,62 @@ def flash_attn_func(
softmax_scale, softmax_scale,
causal, causal,
) )
def flash_attn_varlen_func(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
softmax_scale=None,
causal=False,
):
"""
Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.
If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1 1 1 1 0
1 1 1 1 1
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
0 0
0 0
0 0
1 0
1 1
If the row of the mask is all zero, the output will be zero.
Arguments:
q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into q.
cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into kv.
max_seqlen_q: int. Maximum query sequence length in the batch.
max_seqlen_k: int. Maximum key sequence length in the batch.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
Return:
        out: (total_q, nheads, headdim).
        softmax_lse: (nheads, total_q), dtype torch.float32. The logsumexp of each row of the
            matrix QK^T * scaling (e.g., the log of the softmax normalization factor).
"""
return FlashAttnVarlenFunc.apply(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
softmax_scale,
causal,
)
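For reference, a minimal usage sketch of the new varlen forward (the import path matches the benchmark above; the sequence lengths, head count, and head dimension are illustrative):

import torch
from flash_attn_interface import flash_attn_varlen_func

# Pack two sequences of lengths 3 and 5 into a single (total_tokens, nheads, headdim) tensor.
lens = torch.tensor([3, 5], dtype=torch.int32)
cu_seqlens = torch.cat([torch.tensor([0], dtype=torch.int32),
                        torch.cumsum(lens, dim=0, dtype=torch.int32)]).cuda()   # [0, 3, 8]
total_tokens, nheads, headdim = int(lens.sum()), 8, 128
q = torch.randn(total_tokens, nheads, headdim, device="cuda", dtype=torch.float16)
k = torch.randn(total_tokens, nheads, headdim, device="cuda", dtype=torch.float16)
v = torch.randn(total_tokens, nheads, headdim, device="cuda", dtype=torch.float16)
out, softmax_lse = flash_attn_varlen_func(
    q, k, v, cu_seqlens, cu_seqlens,
    int(lens.max()), int(lens.max()),
    causal=True,
)
# out: (total_tokens, nheads, headdim); softmax_lse: (nheads, total_tokens)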
...@@ -24,11 +24,12 @@ namespace flash { ...@@ -24,11 +24,12 @@ namespace flash {
using namespace cute; using namespace cute;
template <typename Ktraits, bool Is_causal, typename TileScheduler> template <typename Ktraits, bool Is_causal, typename TileScheduler, typename Seqlen_traits>
__global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1) __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1)
compute_attn_ws(CUTE_GRID_CONSTANT typename CollectiveMainloopFwd<Ktraits, Is_causal>::Params const mainloop_params, compute_attn_ws(CUTE_GRID_CONSTANT typename CollectiveMainloopFwd<Ktraits, Is_causal, Seqlen_traits>::Params const mainloop_params,
CUTE_GRID_CONSTANT typename CollectiveEpilogueFwd<Ktraits>::Params const epilogue_params, CUTE_GRID_CONSTANT typename CollectiveEpilogueFwd<Ktraits, Seqlen_traits>::Params const epilogue_params,
CUTE_GRID_CONSTANT typename TileScheduler::Params const scheduler_params CUTE_GRID_CONSTANT typename TileScheduler::Params const scheduler_params,
Seqlen_traits seqlen_traits_q, Seqlen_traits seqlen_traits_k
) { ) {
using Element = typename Ktraits::Element; using Element = typename Ktraits::Element;
...@@ -46,8 +47,8 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, ...@@ -46,8 +47,8 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
// static constexpr int kBlockN = Ktraits::kBlockN; // static constexpr int kBlockN = Ktraits::kBlockN;
// constexpr int kHeadDim = Ktraits::kHeadDim; // constexpr int kHeadDim = Ktraits::kHeadDim;
using CollectiveMainloop = CollectiveMainloopFwd<Ktraits, Is_causal>; using CollectiveMainloop = CollectiveMainloopFwd<Ktraits, Is_causal, Seqlen_traits>;
using CollectiveEpilogue = CollectiveEpilogueFwd<Ktraits>; using CollectiveEpilogue = CollectiveEpilogueFwd<Ktraits, Seqlen_traits>;
using MainloopPipeline = typename Ktraits::MainloopPipeline; using MainloopPipeline = typename Ktraits::MainloopPipeline;
using PipelineParams = typename MainloopPipeline::Params; using PipelineParams = typename MainloopPipeline::Params;
...@@ -115,14 +116,21 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, ...@@ -115,14 +116,21 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
auto block_coord = work_tile_info.get_block_coord(scheduler_params); auto block_coord = work_tile_info.get_block_coord(scheduler_params);
auto [m_block, bidh, bidb] = block_coord; auto [m_block, bidh, bidb] = block_coord;
int n_block_max = collective_mainloop.get_n_block_max(mainloop_params, m_block); seqlen_traits_q.init(bidb);
seqlen_traits_k.init(bidb);
if (m_block * kBlockM >= seqlen_traits_q.actual_seq_len) {
continue;
}
int n_block_max = collective_mainloop.get_n_block_max(
mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k);
if (Is_causal && n_block_max <= 0) { if (Is_causal && n_block_max <= 0) {
scheduler.prefetch_next_work(scheduler_params, work_tile_info); scheduler.prefetch_next_work(scheduler_params, work_tile_info);
scheduler.broadcast_next_work(work_tile_info); scheduler.broadcast_next_work(work_tile_info);
continue; continue;
} }
collective_mainloop.load(mainloop_params, pipeline_k, pipeline_v, smem_pipe_write_k, smem_pipe_write_v, collective_mainloop.load(mainloop_params, pipeline_k, pipeline_v, smem_pipe_write_k, smem_pipe_write_v,
shared_storage, scheduler, scheduler_params, work_tile_info, block_coord, work_idx); shared_storage, scheduler, scheduler_params, work_tile_info, block_coord, work_idx,
seqlen_traits_q, seqlen_traits_k);
++work_idx; ++work_idx;
} }
collective_mainloop.load_tail(pipeline_k, pipeline_v, smem_pipe_write_k, smem_pipe_write_v); collective_mainloop.load_tail(pipeline_k, pipeline_v, smem_pipe_write_k, smem_pipe_write_v);
...@@ -154,17 +162,24 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, ...@@ -154,17 +162,24 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
auto block_coord = work_tile_info.get_block_coord(scheduler_params); auto block_coord = work_tile_info.get_block_coord(scheduler_params);
auto [m_block, bidh, bidb] = block_coord; auto [m_block, bidh, bidb] = block_coord;
int n_block_max = collective_mainloop.get_n_block_max(mainloop_params, m_block); seqlen_traits_q.init(bidb);
seqlen_traits_k.init(bidb);
if (m_block * kBlockM >= seqlen_traits_q.actual_seq_len) {
continue;
}
int n_block_max = collective_mainloop.get_n_block_max(
mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k);
if (Is_causal && n_block_max <= 0) { // We exit early and write 0 to gO and -inf to gLSE. if (Is_causal && n_block_max <= 0) { // We exit early and write 0 to gO and -inf to gLSE.
collective_epilogue.store_zero(epilogue_params, threadIdx.x - NumCopyThreads, block_coord); collective_epilogue.store_zero(epilogue_params, shared_storage, threadIdx.x - NumCopyThreads, block_coord, seqlen_traits_q);
continue; continue;
} }
collective_mainloop.mma(mainloop_params, pipeline_k, pipeline_v, smem_pipe_read_k, smem_pipe_read_v, collective_mainloop.mma(mainloop_params, pipeline_k, pipeline_v, smem_pipe_read_k, smem_pipe_read_v,
tOrO, softmax, n_block_max, threadIdx.x - NumCopyThreads, work_idx, m_block, shared_storage); tOrO, softmax, n_block_max, threadIdx.x - NumCopyThreads, work_idx, m_block, shared_storage,
seqlen_traits_q, seqlen_traits_k);
// tOrO, softmax, n_block_max, threadIdx.x - NumCopyThreads + (work_idx >> 30), work_idx, shared_storage); // tOrO, softmax, n_block_max, threadIdx.x - NumCopyThreads + (work_idx >> 30), work_idx, shared_storage);
collective_epilogue.store(epilogue_params, tOrO, softmax.row_sum, shared_storage, tiled_mma1, collective_epilogue.store(epilogue_params, tOrO, softmax.row_sum, shared_storage, tiled_mma1,
threadIdx.x - NumCopyThreads, block_coord); threadIdx.x - NumCopyThreads, block_coord, seqlen_traits_q);
++work_idx; ++work_idx;
} }
......
...@@ -14,41 +14,61 @@ ...@@ -14,41 +14,61 @@
#include "tile_scheduler.hpp" #include "tile_scheduler.hpp"
#include "flash_fwd_kernel.h" #include "flash_fwd_kernel.h"
#include "kernel_traits.h" #include "kernel_traits.h"
#include "seq_len.h"
#include "utils.h" #include "utils.h"
template<typename Kernel_traits, bool Is_causal> template<typename Kernel_traits, bool Is_causal, typename Seqlen_traits>
void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) { void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
using Element = typename Kernel_traits::Element; using Element = typename Kernel_traits::Element;
using TileShape_MNK = typename Kernel_traits::TileShape_MNK; using TileShape_MNK = typename Kernel_traits::TileShape_MNK;
using ClusterShape = typename Kernel_traits::ClusterShape_MNK; using ClusterShape = typename Kernel_traits::ClusterShape_MNK;
// print(typename Kernel_traits::SmemLayoutVt{}); printf("\n"); print(typename Kernel_traits::SmemLayoutVt_tmp{}); // print(typename Kernel_traits::SmemLayoutVt{}); printf("\n"); print(typename Kernel_traits::SmemLayoutVt_tmp{});
using CollectiveMainloop = flash::CollectiveMainloopFwd<Kernel_traits, Is_causal>; using CollectiveMainloop = flash::CollectiveMainloopFwd<Kernel_traits, Is_causal, Seqlen_traits>;
using CollectiveEpilogue = flash::CollectiveEpilogueFwd<Kernel_traits>; using CollectiveEpilogue = flash::CollectiveEpilogueFwd<Kernel_traits, Seqlen_traits>;
using Scheduler = std::conditional_t<!Is_causal, using Scheduler = std::conditional_t<
flash::StaticPersistentTileScheduler, Seqlen_traits::kUseVarSeqLen,
flash::DynamicPersistentTileScheduler<Kernel_traits::kNThreads - cutlass::NumThreadsPerWarpGroup>>; flash::SingleTileScheduler,
// flash::SingleTileScheduler>; std::conditional_t<!Is_causal,
flash::StaticPersistentTileScheduler,
flash::DynamicPersistentTileScheduler<Kernel_traits::kNThreads - cutlass::NumThreadsPerWarpGroup>
>>;
// using Scheduler = flash::SingleTileScheduler;
Seqlen_traits seqlen_traits_q(
params.total_q, params.seqlen_q, params.cu_seqlens_q);
Seqlen_traits seqlen_traits_k(
params.total_k, params.seqlen_k, params.cu_seqlens_k, params.seqused_k);
typename CollectiveMainloop::Params mainloop_params = typename CollectiveMainloop::Params mainloop_params =
CollectiveMainloop::to_underlying_arguments({ CollectiveMainloop::to_underlying_arguments({
static_cast<Element const*>(params.q_ptr), static_cast<Element const*>(params.q_ptr),
{params.seqlen_q, params.d, params.h, params.b}, // shape_Q seqlen_traits_q.get_gmem_layout(
{params.q_row_stride, _1{}, params.q_head_stride, params.q_batch_stride}, // stride_Q params.seqlen_q, params.d, params.h, params.b,
params.q_row_stride, params.q_head_stride, params.q_batch_stride
), // layout_Q
static_cast<Element const*>(params.k_ptr), static_cast<Element const*>(params.k_ptr),
{params.seqlen_k, params.d, params.h_k, params.b}, // shape_K seqlen_traits_k.get_gmem_layout(
{params.k_row_stride, _1{}, params.k_head_stride, params.k_batch_stride}, // stride_K params.seqlen_k, params.d, params.h_k, params.b,
params.k_row_stride, params.k_head_stride, params.k_batch_stride
), // layout_K
static_cast<Element const*>(params.v_ptr), static_cast<Element const*>(params.v_ptr),
{params.v_row_stride, _1{}, params.v_head_stride, params.v_batch_stride}, // stride_V seqlen_traits_k.get_gmem_layout(
params.seqlen_k, params.d, params.h_k, params.b,
params.v_row_stride, params.v_head_stride, params.v_batch_stride
), // layout_V
params.scale_softmax_log2 params.scale_softmax_log2
}); });
typename CollectiveEpilogue::Params epilogue_params = typename CollectiveEpilogue::Params epilogue_params =
CollectiveEpilogue::to_underlying_arguments({ CollectiveEpilogue::to_underlying_arguments({
static_cast<Element*>(params.o_ptr), static_cast<Element*>(params.o_ptr),
{params.seqlen_q, params.d, params.h, params.b}, // shape_O seqlen_traits_q.get_gmem_layout(
{params.o_row_stride, _1{}, params.o_head_stride, params.o_batch_stride}, // stride_O params.seqlen_q, params.d, params.h, params.b,
params.o_row_stride, params.o_head_stride, params.o_batch_stride
), // layout_O
static_cast<float*>(params.softmax_lse_ptr), static_cast<float*>(params.softmax_lse_ptr),
{_1{}, params.seqlen_q, params.h * params.seqlen_q}, // stride_LSE seqlen_traits_q.get_lse_gmem_layout(
params.seqlen_q, params.h, params.b
) // layout_LSE
}); });
int num_blocks_m = cutlass::ceil_div(params.seqlen_q, Kernel_traits::kBlockM); int num_blocks_m = cutlass::ceil_div(params.seqlen_q, Kernel_traits::kBlockM);
...@@ -58,7 +78,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) { ...@@ -58,7 +78,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
// Get the ptr to kernel function. // Get the ptr to kernel function.
void *kernel; void *kernel;
kernel = (void *)flash::compute_attn_ws<Kernel_traits, Is_causal, Scheduler>; kernel = (void *)flash::compute_attn_ws<Kernel_traits, Is_causal, Scheduler, Seqlen_traits>;
int smem_size = sizeof(typename Kernel_traits::SharedStorage); int smem_size = sizeof(typename Kernel_traits::SharedStorage);
// int smem_size_q = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_q)); // int smem_size_q = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_q));
// int smem_size_k = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_k)); // int smem_size_k = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_k));
...@@ -81,7 +101,9 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) { ...@@ -81,7 +101,9 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
dim3 block_dims(ctaSize); dim3 block_dims(ctaSize);
dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{})); dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream}; cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream};
cutlass::launch_kernel_on_cluster(launch_params, kernel, mainloop_params, epilogue_params, scheduler_params); cutlass::launch_kernel_on_cluster(
launch_params, kernel, mainloop_params, epilogue_params,
scheduler_params, seqlen_traits_q, seqlen_traits_k);
CHECK_CUDA_KERNEL_LAUNCH(); CHECK_CUDA_KERNEL_LAUNCH();
} }
...@@ -89,7 +111,12 @@ template<typename T> ...@@ -89,7 +111,12 @@ template<typename T>
void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) { void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 64; constexpr static int Headdim = 64;
BOOL_SWITCH(params.is_causal, Is_causal, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 192, 128, 16, 2, false, 1, T>, Is_causal>(params, stream); SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] {
run_flash_fwd<
Flash_fwd_kernel_traits<Headdim, 192, 128, 16, 2, false, 1, T>,
Is_causal, Seqlen_traits
>(params, stream);
});
}); });
} }
...@@ -97,9 +124,14 @@ template<typename T> ...@@ -97,9 +124,14 @@ template<typename T>
void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) { void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 128; constexpr static int Headdim = 128;
BOOL_SWITCH(params.is_causal, Is_causal, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] {
// Only use Cluster if number of tiles along seqlen_q is even SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] {
BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0, UseCluster, [&] { // Only use Cluster if number of tiles along seqlen_q is even and not Is_causal
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, Is_causal ? 128 : 176, 12, 2, false, !Is_causal && UseCluster ? 2 : 1, T>, Is_causal>(params, stream); BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] {
run_flash_fwd<
Flash_fwd_kernel_traits<Headdim, 128, Is_causal ? 128 : 176, 12, 2, false, UseCluster ? 2 : 1, T>,
Is_causal, Seqlen_traits
>(params, stream);
});
}); });
}); });
} }
...@@ -108,9 +140,14 @@ template<typename T> ...@@ -108,9 +140,14 @@ template<typename T>
void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) { void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 256; constexpr static int Headdim = 256;
BOOL_SWITCH(params.is_causal, Is_causal, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] {
// Only use Cluster if number of tiles along seqlen_q is even SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] {
BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0, UseCluster, [&] { // Only use Cluster if number of tiles along seqlen_q is even
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 80, 12, 2, false, !Is_causal && UseCluster ? 2 : 1, T>, Is_causal>(params, stream); BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] {
run_flash_fwd<
Flash_fwd_kernel_traits<Headdim, 128, 80, 12, 2, false, UseCluster ? 2 : 1, T>,
Is_causal, Seqlen_traits
>(params, stream);
});
}); });
}); });
} }
...@@ -21,7 +21,7 @@ namespace flash { ...@@ -21,7 +21,7 @@ namespace flash {
using namespace cute; using namespace cute;
template <typename Ktraits, bool Is_causal> template <typename Ktraits, bool Is_causal, typename Seqlen_traits>
struct CollectiveMainloopFwd { struct CollectiveMainloopFwd {
using Element = typename Ktraits::Element; using Element = typename Ktraits::Element;
...@@ -64,19 +64,24 @@ struct CollectiveMainloopFwd { ...@@ -64,19 +64,24 @@ struct CollectiveMainloopFwd {
    // decltype(tile_to_shape(SmemLayoutAtomVTMA{},
    //     make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
    using TMA_Q = decltype(make_tma_copy(
        GmemTiledCopyQ{},
        make_tensor(
            make_gmem_ptr(static_cast<Element const*>(nullptr)),
            repeat_like(typename Seqlen_traits::StrideT{}, int32_t(0)),
            typename Seqlen_traits::StrideT{}
        ),
        SmemLayoutQ{},
        select<0, 2>(TileShape_MNK{}),
        _1{}));  // no mcast for Q
    using TMA_KV = decltype(make_tma_copy(
        GmemTiledCopyKV{},
        make_tensor(
            make_gmem_ptr(static_cast<Element const*>(nullptr)),
            repeat_like(typename Seqlen_traits::StrideT{}, int32_t(0)),
            typename Seqlen_traits::StrideT{}
        ),
        take<0, 2>(SmemLayoutK{}),
        select<1, 2>(TileShape_MNK{}),
        size<0>(ClusterShape{})));  // mcast along M mode for this N load, if any
@@ -95,20 +100,19 @@ struct CollectiveMainloopFwd {
    // Host side kernel arguments
    struct Arguments {
        Element const* ptr_Q;
        typename Seqlen_traits::LayoutT layout_Q;
        Element const* ptr_K;
        typename Seqlen_traits::LayoutT layout_K;
        Element const* ptr_V;
        typename Seqlen_traits::LayoutT layout_V;
        float const softmax_scale_log2;
    };
    // Device side kernel params
    struct Params {
        typename Seqlen_traits::LayoutT layout_Q;
        typename Seqlen_traits::LayoutT layout_K;
        typename Seqlen_traits::LayoutT layout_V;
        cutlass::FastDivmod qhead_per_khead_divmod;
        TMA_Q tma_load_Q;
        TMA_KV tma_load_K, tma_load_V;
@@ -118,29 +122,29 @@ struct CollectiveMainloopFwd {
    static Params
    to_underlying_arguments(Arguments const& args) {
        Tensor mQ = make_tensor(make_gmem_ptr(args.ptr_Q), args.layout_Q);
        TMA_Q tma_load_Q = make_tma_copy(
            GmemTiledCopyQ{},
            mQ,
            SmemLayoutQ{},
            select<0, 2>(TileShape_MNK{}),
            _1{});  // no mcast for Q
        Tensor mK = make_tensor(make_gmem_ptr(args.ptr_K), args.layout_K);
        TMA_KV tma_load_K = make_tma_copy(
            GmemTiledCopyKV{},
            mK,
            SmemLayoutK{}(_, _, _0{}),
            select<1, 2>(TileShape_MNK{}),
            size<0>(ClusterShape{}));  // mcast along M mode for this N load, if any
        Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), args.layout_V);
        TMA_KV tma_load_V = make_tma_copy(
            GmemTiledCopyKV{},
            mV,
            SmemLayoutV{}(_, _, _0{}),
            select<1, 2>(TileShape_MNK{}),
            size<0>(ClusterShape{}));  // mcast along M mode for this N load, if any
        return {args.layout_Q, args.layout_K, args.layout_V,
                cutlass::FastDivmod(cute::ceil_div(get<2>(args.layout_Q.shape()), get<2>(args.layout_K.shape()))),
                tma_load_Q, tma_load_K, tma_load_V,
                args.softmax_scale_log2};
    }
@@ -154,11 +158,15 @@ struct CollectiveMainloopFwd {
    }
    CUTLASS_DEVICE
    int get_n_block_max(
        Params const& mainloop_params, int m_block,
        const Seqlen_traits& seqlen_traits_q,
        const Seqlen_traits& seqlen_traits_k
    ) {
        static constexpr int kBlockM = get<0>(TileShape_MNK{});
        static constexpr int kBlockN = get<1>(TileShape_MNK{});
        int const seqlen_q = seqlen_traits_q.actual_seq_len;
        int const seqlen_k = seqlen_traits_k.actual_seq_len;
        int n_block_max = cute::ceil_div(seqlen_k, kBlockN);
        if constexpr (Is_causal) {
            n_block_max = std::min(n_block_max,
@@ -179,16 +187,18 @@ struct CollectiveMainloopFwd {
        typename Scheduler::Params const& scheduler_params,
        typename Scheduler::WorkTileInfo& work_tile_info,
        cute::tuple<int32_t, int32_t, int32_t> block_coord,
        int work_idx,
        const Seqlen_traits& seqlen_traits_q,
        const Seqlen_traits& seqlen_traits_k
    ) {
        Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
        Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{});
        Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutV{});
        Tensor mQ = mainloop_params.tma_load_Q.get_tma_tensor(mainloop_params.layout_Q.shape());
        Tensor mK = mainloop_params.tma_load_K.get_tma_tensor(mainloop_params.layout_K.shape());
        Tensor mV = mainloop_params.tma_load_V.get_tma_tensor(mainloop_params.layout_V.shape());
        auto [m_block, bidh, bidb] = block_coord;
        int bidh_kv = mainloop_params.qhead_per_khead_divmod.divide(bidh);
@@ -197,9 +207,12 @@ struct CollectiveMainloopFwd {
        uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
        constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
        uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
        Tensor gQ = seqlen_traits_q.get_local_tile_tensor(
            mQ, select<0, 2>(TileShape_MNK{}), bidh, bidb)(_, _, m_block);  // (M, K)
        Tensor gK = seqlen_traits_k.get_local_tile_tensor(
            mK, select<1, 2>(TileShape_MNK{}), bidh_kv, bidb);  // (N, K, _)
        Tensor gV = seqlen_traits_k.get_local_tile_tensor(
            mV, select<1, 2>(TileShape_MNK{}), bidh_kv, bidb);  // (N, K, _)
        Tensor sQ_x = make_tensor(sQ.data(), make_layout(sQ.layout(), Layout<_1>{}));
        Tensor gQ_x = make_tensor(gQ.data(), make_layout(gQ.layout(), Layout<_1>{}));
@@ -218,7 +231,7 @@ struct CollectiveMainloopFwd {
            }
        }
        int n_block_max = get_n_block_max(mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k);
        int n_block = n_block_max - 1;
        int lane_predicate = cute::elect_one_sync();
@@ -331,7 +344,9 @@ struct CollectiveMainloopFwd {
        int thread_idx,
        int work_idx,
        int m_block,
        SharedStorage& shared_storage,
        const Seqlen_traits& seqlen_traits_q,
        const Seqlen_traits& seqlen_traits_k
    ) {
        static_assert(is_rmem<FrgTensorO>::value, "O tensor must be rmem resident.");
@@ -360,8 +375,8 @@ struct CollectiveMainloopFwd {
        };
        tiled_mma1.accumulate_ = GMMA::ScaleOut::Zero;
        int const seqlen_q = seqlen_traits_q.actual_seq_len;
        int const seqlen_k = seqlen_traits_k.actual_seq_len;
        int n_block = n_block_count - 1;
        cutlass::ConsumerToken barrier_token = static_cast<cutlass::BarrierStatus>(shared_storage.barrier_Q.try_wait(work_idx % 2));
@@ -483,4 +498,3 @@ struct CollectiveMainloopFwd {
};
} // namespace flash
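The mainloop now threads the Seqlen_traits objects through its entry points, while the query-to-KV head mapping via qhead_per_khead_divmod is unchanged. A minimal Python sketch of that head mapping (illustrative names, not the kernel code):

def kv_head_index(q_head, num_q_heads, num_kv_heads):
    # Mirrors qhead_per_khead_divmod.divide(bidh): each group of
    # ceil(num_q_heads / num_kv_heads) query heads shares one KV head (GQA/MQA).
    qheads_per_khead = -(-num_q_heads // num_kv_heads)  # ceiling division
    return q_head // qheads_per_khead

# 6 query heads over 2 KV heads: heads 0-2 read KV head 0, heads 3-5 read KV head 1.
assert [kv_head_index(h, 6, 2) for h in range(6)] == [0, 0, 0, 1, 1, 1]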
/******************************************************************************
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
******************************************************************************/
#pragma once
#include <cutlass/cutlass.h>
#include <cute/layout.hpp>
namespace flash {
static constexpr int kMaxTileSize = 128;
template <bool UseVarSeqLen> class SeqLenTraits {
public:
// Total number of queries / keys. Unpadded.
int sum_s = 0;
  // Cumulative sequence length offsets (cu_seqlens).
int *cu_seq_len = nullptr;
// actual seq len array.
int *seq_used = nullptr;
// seq len of the current batch.
int actual_seq_len = -1;
// Whether this is for fixed-seq-len or var-seq-len.
static constexpr bool kUseVarSeqLen = UseVarSeqLen;
using ShapeT = std::conditional_t<
UseVarSeqLen,
cute::Shape<int32_t, int32_t, int32_t>,
cute::Shape<int32_t, int32_t, int32_t, int32_t>
>;
using StrideT = std::conditional_t<
UseVarSeqLen,
cute::Shape<int64_t, _1, int64_t>,
cute::Shape<int64_t, _1, int64_t, int64_t>
>;
using LayoutT = cute::Layout<ShapeT, StrideT>;
using ShapeLseT = std::conditional_t<
UseVarSeqLen,
cute::Shape<int32_t, int32_t>,
cute::Shape<int32_t, int32_t, int32_t>
>;
using StrideLseT = std::conditional_t<
UseVarSeqLen,
cute::Shape<int64_t, _1>,
cute::Shape<int64_t, int64_t, _1>
>;
using LayoutLseT = cute::Layout<ShapeLseT, StrideLseT>;
CUTLASS_HOST SeqLenTraits() {}
CUTLASS_HOST SeqLenTraits(
int sum_s, int max_seq_len, int *cu_seq_len = nullptr, int *seq_used = nullptr):
sum_s(sum_s), cu_seq_len(cu_seq_len), seq_used(seq_used), actual_seq_len(max_seq_len) {}
// Returns the layout of a tensor in MKHB format in global memory.
// padded: only useful for var-seq-len for dq_accum and softmax_d.
CUTLASS_HOST_DEVICE auto get_gmem_layout(
int m, int k, int h, int b,
int64_t m_stride, int64_t h_stride, int64_t b_stride,
bool padded = false) const {
static_assert(!UseVarSeqLen, "Default implementation is for FixedSeqLen.");
return make_layout(make_shape(m, k, h, b),
make_stride(m_stride, cute::_1{}, h_stride, b_stride));
}
  // Returns the layout of the LSE tensor in BHM format in global memory.
  // padded: only useful for var-seq-len for dq_accum and softmax_d.
CUTLASS_HOST_DEVICE auto get_lse_gmem_layout(
int m, int h, int b, bool padded = false) const {
static_assert(!UseVarSeqLen, "Default implementation is for FixedSeqLen.");
return make_layout(make_shape(b, h, m),
make_stride(int64_t(h * m), int64_t(m), cute::_1()));
}
CUTLASS_DEVICE void init(int bidb) {}
template <typename MTensor, typename Shape>
CUTLASS_DEVICE auto get_local_tile_tensor(
const MTensor &m_tensor, const Shape &tile_shape,
int bidh, int bidb, bool padded = false) const {
auto g_tensor = local_tile(
m_tensor(_, _, bidh, bidb), tile_shape, make_coord(_, _0{}));
return g_tensor;
}
template <typename MTensor, typename Shape>
CUTLASS_DEVICE auto get_lse_local_tile_tensor(
const MTensor &m_tensor, const Shape &tile_shape,
int bidh, int bidb, bool padded = false) const {
auto g_tensor = local_tile(m_tensor(bidb, bidh, _), tile_shape, make_coord(_));
return g_tensor;
}
};
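The traits trade the fixed-length 4-D (seqlen, d, head, batch) layout for a packed 3-D (total_tokens, d, head) layout when variable sequence lengths are used. A rough Python sketch of the two global-memory shapes (plain tuples standing in for the cute types):

def qkv_gmem_shape(use_var_seq_len, seqlen, d, h, b, total_tokens=None):
    # Fixed-length: one padded (seqlen, d, head, batch) tensor per Q/K/V.
    # Var-length: all sequences packed along the first mode; the batch mode is
    # dropped and per-batch boundaries come from cu_seq_len instead.
    if use_var_seq_len:
        return (total_tokens, d, h)
    return (seqlen, d, h, b)

assert qkv_gmem_shape(False, 512, 128, 6, 9) == (512, 128, 6, 9)
assert qkv_gmem_shape(True, 512, 128, 6, 9, total_tokens=3000) == (3000, 128, 6)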
using FixedSeqLenTraits = SeqLenTraits<false>;
using VarSeqLenTraits = SeqLenTraits<true>;
// Returns the layout of a var-seq-len tensor in global memory, based on the
// packed total number of tokens (sum_s) rather than max_seq_len * batch_size.
// padded: only useful for var-seq-len for dq_accum and softmax_d.
// When padded is True, use sum_s + kMaxTileSize * b as the total number of rows.
template <>
CUTLASS_HOST_DEVICE auto VarSeqLenTraits::get_gmem_layout(
int m, int k, int h, int b,
int64_t m_stride, int64_t h_stride, int64_t b_stride,
bool padded) const {
return make_layout(
make_shape(sum_s + (padded ? kMaxTileSize * b : 0), k, h),
make_stride(m_stride, cute::_1{}, h_stride));
}
// padded: only useful for var-seq-len for dq_accum and softmax_d.
// When padded is True, use B_M + kMaxTileSize * B as the total B_M.
template <>
CUTLASS_HOST_DEVICE auto VarSeqLenTraits::get_lse_gmem_layout(
int m, int h, int b, bool padded) const {
return make_layout(
make_shape(h, sum_s + (padded ? kMaxTileSize * b : 0)),
make_stride(int64_t(sum_s + (padded ? kMaxTileSize * b : 0)), cute::_1()));
}
template <>
CUTLASS_DEVICE void VarSeqLenTraits::init(int bidb) {
actual_seq_len =
seq_used ? seq_used[bidb] : (cu_seq_len[bidb + 1] - cu_seq_len[bidb]);
}
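For the var-seq-len specialization, init() resolves the current batch's length either from an explicit seq_used array or from consecutive cu_seq_len offsets. The same lookup as a small Python sketch (illustrative only):

def actual_seq_len(bidb, cu_seq_len, seq_used=None):
    # seq_used[b], if provided, overrides the length implied by the cumulative
    # offsets; otherwise the length is the difference of adjacent offsets.
    if seq_used is not None:
        return seq_used[bidb]
    return cu_seq_len[bidb + 1] - cu_seq_len[bidb]

cu = [0, 5, 12, 20]  # three packed sequences of lengths 5, 7, 8
assert [actual_seq_len(b, cu) for b in range(3)] == [5, 7, 8]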
template <>
template <typename MTensor, typename Shape>
CUTLASS_DEVICE auto VarSeqLenTraits::get_local_tile_tensor(
const MTensor &m_tensor, const Shape &tile_shape,
int bidh, int bidb, bool padded) const {
auto g_offset = local_tile(
m_tensor(_, _, bidh),
cute::make_shape(1, get<1>(tile_shape)),
make_coord(cu_seq_len[bidb] + (padded ? kMaxTileSize * bidb : 0), _0{}));
auto g_sequence = make_tensor(
g_offset.data(),
make_layout(
cute::make_shape(actual_seq_len, get<1>(tile_shape)),
g_offset.stride()
));
auto g_tensor = local_tile(g_sequence, tile_shape, make_coord(_, _0{}));
return g_tensor;
}
template <>
template <typename MTensor, typename Shape>
CUTLASS_DEVICE auto VarSeqLenTraits::get_lse_local_tile_tensor(
const MTensor &m_tensor, const Shape &tile_shape,
int bidh, int bidb, bool padded) const {
auto g_offset = local_tile(
m_tensor(bidh, _), cute::make_shape(_1{}),
make_coord(cu_seq_len[bidb] + (padded ? kMaxTileSize * bidb : 0)));
auto g_sequence = make_tensor(
g_offset.data(),
make_layout(cute::make_shape(actual_seq_len), cute::make_shape(_1{})));
auto g_tensor = local_tile(g_sequence, tile_shape, make_coord(_));
return g_tensor;
}
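Both var-seq-len tile helpers start a batch's rows at cu_seq_len[bidb], shifted by kMaxTileSize * bidb when the padded buffers (dq_accum / softmax_d) are addressed. A short Python sketch of that row-offset arithmetic (illustrative):

K_MAX_TILE_SIZE = 128  # mirrors flash::kMaxTileSize

def batch_row_offset(bidb, cu_seq_len, padded=False):
    # Row at which batch `bidb` begins inside the packed (total_tokens, ...) tensor;
    # padded buffers reserve an extra kMaxTileSize rows per preceding batch.
    return cu_seq_len[bidb] + (K_MAX_TILE_SIZE * bidb if padded else 0)

cu = [0, 5, 12, 20]
assert batch_row_offset(2, cu) == 12
assert batch_row_offset(2, cu, padded=True) == 12 + 2 * 128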
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace flash
@@ -66,18 +66,14 @@
    } \
  }()
#define SEQLEN_SWITCH(USE_VAR_SEQ_LEN, NAME, ...)  \
  [&] {                                            \
    bool useSeqLen = USE_VAR_SEQ_LEN;              \
    if (useSeqLen) {                               \
      using NAME = flash::VarSeqLenTraits;         \
      return __VA_ARGS__();                        \
    } else {                                       \
      using NAME = flash::FixedSeqLenTraits;       \
      return __VA_ARGS__();                        \
    }                                              \
  }()
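SEQLEN_SWITCH picks the traits type purely from whether a cu_seqlens pointer was supplied. A Python analogue of that dispatch (the classes here are stand-ins, not the real bindings):

class FixedSeqLenTraits:  # stand-in for flash::FixedSeqLenTraits
    use_var_seq_len = False

class VarSeqLenTraits:  # stand-in for flash::VarSeqLenTraits
    use_var_seq_len = True

def seqlen_switch(cu_seqlens_q):
    # A non-null cu_seqlens selects the variable-length path, otherwise fixed-length.
    return VarSeqLenTraits if cu_seqlens_q is not None else FixedSeqLenTraits

assert seqlen_switch(None) is FixedSeqLenTraits
assert seqlen_switch([0, 5, 12]) is VarSeqLenTraits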
@@ -5,40 +5,12 @@ import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from flash_attn_interface import flash_attn_func, flash_attn_varlen_func
from tests.test_util import generate_random_padding_mask, generate_qkv, construct_local_mask, attention_ref
ABS_TOL = 5e-3
REL_TOL = 1e-1
def print_diffs(out, out_ref):
    out_1d = out.flatten()
    out_ref_1d = out_ref.flatten()
@@ -51,86 +23,6 @@ def print_diffs(out, out_ref):
            print(f"==== diff ==== {idx}, test: {e_o}, ref: {e_o_ref}")
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
# @pytest.mark.parametrize("dtype", [torch.bfloat16]) # @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
...@@ -142,10 +34,11 @@ def attention_ref( ...@@ -142,10 +34,11 @@ def attention_ref(
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80]) # @pytest.mark.parametrize('d', [56, 80])
@pytest.mark.parametrize("d", [64, 128, 256]) @pytest.mark.parametrize("d", [64, 128, 256])
# @pytest.mark.parametrize("d", [256]) # @pytest.mark.parametrize("d", [128])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"seqlen_q,seqlen_k", "seqlen_q,seqlen_k",
[ [
(257, 1),
(64, 128), (64, 128),
(128, 128), (128, 128),
(256, 256), (256, 256),
...@@ -175,8 +68,9 @@ def test_flash_attn_output( ...@@ -175,8 +68,9 @@ def test_flash_attn_output(
batch_size = 9 batch_size = 9
nheads = 6 nheads = 6
nheads_kv = 6 if mha_type == "mha" else (2 if mha_type == "gqa" else 1) nheads_kv = 6 if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
# batch_size = 1 # nheads_kv = 2
# nheads = 1 # batch_size = 9
# nheads = 6
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True) k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True)
v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True) v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True)
...@@ -244,9 +138,172 @@ def test_flash_attn_output( ...@@ -244,9 +138,172 @@ def test_flash_attn_output(
# Check that FlashAttention's numerical error is at most twice the numerical error # Check that FlashAttention's numerical error is at most twice the numerical error
# of a Pytorch implementation. # of a Pytorch implementation.
# breakpoint()
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
# if d <= 128: # if d <= 128:
# assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() # assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
# assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() # assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
# assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() # assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize('causal', [True])
# @pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [128])
@pytest.mark.parametrize("d", [64, 128, 256])
@pytest.mark.parametrize(
"seqlen_q,seqlen_k",
[
(1, 1),
(1, 3),
(2, 1),
(511, 1),
(3, 513),
(64, 128),
(113, 203),
(128, 128),
(128, 217),
(113, 211),
(108, 256),
(256, 512),
(384, 256),
(512, 256),
(640, 128),
(1024, 1024),
(1023, 1024),
(1024, 1023),
(2048, 2048),
],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
def test_flash_attn_varlen_output(
seqlen_q, seqlen_k, d, causal, mha_type, dtype
):
if (
max(seqlen_q, seqlen_k) >= 2048
and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
):
pytest.skip() # Reference implementation OOM
device = "cuda"
# set seed
torch.random.manual_seed(0)
# batch_size = 1
# nheads = 1
batch_size = 9
nheads = 6
nheads_kv = 6 if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
k = torch.randn(
batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True
)
v = torch.randn(
batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True
)
query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
# key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode='full')
(
q_unpad,
k_unpad,
v_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q,
k,
v,
output_pad_fn,
dq_pad_fn,
dk_pad_fn,
) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
# print("cu_seqlens_q: ", cu_seqlens_q)
# print("cu_seqlens_k: ", cu_seqlens_k)
# print("q_unpad, shape: ", q_unpad.shape)
# print("k_unpad, shape: ", k_unpad.shape)
# print("v_unpad, shape: ", v_unpad.shape)
out_unpad, sm_lse = flash_attn_varlen_func(
q_unpad,
k_unpad,
v_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
causal=causal,
)
out = output_pad_fn(out_unpad)
dropout_mask = None
out_ref, attn_ref = attention_ref(
q,
k,
v,
query_padding_mask,
key_padding_mask,
causal=causal,
)
out_pt, attn_pt = attention_ref(
q,
k,
v,
query_padding_mask,
key_padding_mask,
causal=causal,
upcast=False,
reorder_ops=True,
)
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
# g = torch.randn_like(out)
# if d <= 128:
# (
# dq_unpad,
# dk_unpad,
# dv_unpad,
# ) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g)
# dk = dk_pad_fn(dk_unpad)
# dv = dk_pad_fn(dv_unpad)
# (
# dq_ref,
# dk_ref,
# dv_ref,
# ) = torch.autograd.grad(out_ref, (q, k, v), g)
# (
# dq_pt,
# dk_pt,
# dv_pt,
# ) = torch.autograd.grad(out_pt, (q, k, v), g)
# dq = dq_pad_fn(dq_unpad)
# print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
# print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
# print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
# print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
# print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
# print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
# print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
# print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
# print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
# print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
# print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
# print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
# Check that FlashAttention's numerical error is at most twice the numerical error
# of a Pytorch implementation.
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
# if d <= 128:
# assert (dq - dq_ref).abs().max().item() < 1e-4 or (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item()
# assert (dk - dk_ref).abs().max().item() < 1e-4 or (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item()
# assert (dk - dk_ref).abs().max().item() < 1e-4 or (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item()
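For reference, the varlen entry point exercised above consumes packed (total_tokens, nheads, headdim) tensors plus int32 cumulative offsets. A compact standalone usage sketch with hand-built lengths (assumes a Hopper GPU with this interface installed; shapes follow the test):

import torch
import torch.nn.functional as F
from flash_attn_interface import flash_attn_varlen_func

lens_q = torch.tensor([100, 37, 256], dtype=torch.int32)
lens_k = torch.tensor([128, 64, 300], dtype=torch.int32)
cu_q = F.pad(lens_q.cumsum(0, dtype=torch.int32), (1, 0)).cuda()  # [0, 100, 137, 393]
cu_k = F.pad(lens_k.cumsum(0, dtype=torch.int32), (1, 0)).cuda()
nheads, d = 6, 128
q = torch.randn(int(lens_q.sum()), nheads, d, device="cuda", dtype=torch.float16)
k = torch.randn(int(lens_k.sum()), nheads, d, device="cuda", dtype=torch.float16)
v = torch.randn_like(k)
out, lse = flash_attn_varlen_func(
    q, k, v, cu_q, cu_k, int(lens_q.max()), int(lens_k.max()), causal=False
)
assert out.shape == q.shape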
@@ -15,6 +15,7 @@
#endif
#include <cute/tensor.hpp>
#include <cute/atom/copy_atom.hpp>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
@@ -228,4 +229,93 @@ __forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layou
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int NumCopyThreads, typename ElemO, typename TMACopyO, typename LayoutO,
typename TileShapeO, typename SMemO, typename SeqLenTraits>
__forceinline__ __device__ void write_tma(
ElemO* O, const TMACopyO& tma_store_O,
const LayoutO& layout_O, const TileShapeO& tile_shape_O,
const SMemO& sO, int m_block, int bidh, int bidb,
const SeqLenTraits& seqlen_traits_o, int write_warp_idx) {
Tensor mO = tma_store_O.get_tma_tensor(layout_O.shape());
Tensor gO = seqlen_traits_o.get_local_tile_tensor(
mO, tile_shape_O, bidh, bidb
)(_, _, m_block); // (M, K)
auto block_tma_O = tma_store_O.get_slice(_0{});
Tensor tOgO = block_tma_O.partition_D(gO); // (TMA, TMA_M, TMA_K)
Tensor tOsO = block_tma_O.partition_S(sO); // (TMA, TMA_M, TMA_K)
int const lane_predicate = cute::elect_one_sync();
int const warp_idx = cutlass::canonical_warp_idx_sync();
if (warp_idx == write_warp_idx && lane_predicate) {
cute::copy(tma_store_O, tOsO, tOgO);
tma_store_arrive();
}
// Note: no wait here.
// tma_store_wait<0>();
}
template <int NumCopyThreads, typename ElemO, typename TiledCopyO, typename LayoutO,
typename TileShapeO, typename SMemO, typename SeqLenTraits>
__forceinline__ __device__ void write_tiled(
ElemO* O, const TiledCopyO& tiled_copy_O,
const LayoutO& layout_O, const TileShapeO& tile_shape_O,
const SMemO& sO, int m_block, int bidh, int bidb,
const SeqLenTraits& seqlen_traits_o) {
Tensor mO = make_tensor(make_gmem_ptr(O), layout_O);
Tensor gO = seqlen_traits_o.get_local_tile_tensor(
mO, tile_shape_O, bidh, bidb
)(_, _, m_block); // (M, K)
ThrCopy thr_copy_O = tiled_copy_O.get_slice(threadIdx.x - NumCopyThreads);
Tensor tOgO = thr_copy_O.partition_D(gO); // (CPY,CPY_M,CPY_K,k)
Tensor tOsO = thr_copy_O.partition_S(sO); // (CPY,CPY_M,CPY_K)
// Prepare for TiledCopy.
// Grouping is needed because cute::copy_if() does group_modes<1, R> for src and dst.
// After grouping, the first dim is number of elements to read together.
Tensor tOsOFlatten = cute::flatten(tOsO);
Tensor tOsOGroup = cute::group_modes<1, rank(tOsOFlatten)>(tOsOFlatten);
Tensor tOgOFlatten = cute::flatten(tOgO);
Tensor tOgOGroup = cute::group_modes<1, rank(tOgOFlatten)>(tOgOFlatten);
// Get thread coords to global index mapping.
Tensor gOCounting = cute::make_identity_tensor(gO.shape());
Tensor tSgOCounting = thr_copy_O.partition_D(gOCounting);
Tensor tSgOCountingFlatten = cute::flatten(tSgOCounting);
Tensor tSgOCountingGrouped =
cute::group_modes<1, rank(tSgOCountingFlatten)>(tSgOCountingFlatten);
// Write out to GMEM.
const int kNumMsPerTile = get<0>(tile_shape_O);
int cta_m = std::min(
seqlen_traits_o.actual_seq_len - m_block * kNumMsPerTile, kNumMsPerTile
);
if (cta_m == kNumMsPerTile) {
copy(tiled_copy_O, tOsOGroup, tOgOGroup);
} else {
auto predicate_fn = [&](auto coords) {
auto s_coords = tSgOCountingGrouped(_0{}, coords);
return elem_less(get<0>(s_coords), cta_m);
};
copy_if(tiled_copy_O, predicate_fn, tOsOGroup, tOgOGroup);
}
}
template <bool IsTMACopy, int NumCopyThreads, typename ElemO,
typename TMACopyO, typename TiledCopyO, typename LayoutO,
typename TileShapeO, typename SMemO, typename SeqLenTraits>
__forceinline__ __device__ void write_O(
ElemO* O, const TMACopyO& tma_copy_O, const TiledCopyO& tiled_copy_O,
const LayoutO& layout_O, const TileShapeO& tile_shape_O,
const SMemO& sO, int m_block, int bidh, int bidb,
const SeqLenTraits& seqlen_traits_o, int write_warp_idx) {
if constexpr (IsTMACopy) {
write_tma<NumCopyThreads>(O, tma_copy_O, layout_O, tile_shape_O, sO, m_block, bidh, bidb, seqlen_traits_o, write_warp_idx);
} else {
write_tiled<NumCopyThreads>(O, tiled_copy_O, layout_O, tile_shape_O, sO, m_block, bidh, bidb, seqlen_traits_o);
}
}
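write_tiled switches to a predicated copy whenever the last query tile of a (possibly shorter) sequence is only partially valid; cta_m above counts the valid rows. The same bookkeeping as a tiny Python sketch (illustrative):

def valid_rows_in_tile(actual_seq_len, m_block, tile_m):
    # Rows of tile `m_block` that belong to the sequence; the predicated
    # copy_if path is taken whenever this is smaller than tile_m.
    return min(actual_seq_len - m_block * tile_m, tile_m)

assert valid_rows_in_tile(300, 0, 128) == 128  # full tile, plain copy
assert valid_rows_in_tile(300, 2, 128) == 44   # ragged tail, copy_if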
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace flash
import math
import torch
from einops import rearrange, repeat
from flash_attn.bert_padding import pad_input, unpad_input
def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"):
assert mode in ["full", "random", "third"]
if mode == "full":
lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)
elif mode == "random":
lengths = torch.randint(
max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device
)
elif mode == "third":
lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device)
padding_mask = (
repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths
)
return padding_mask
def generate_qkv(
q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False
):
"""
Arguments:
q: (batch_size, seqlen_q, nheads, d)
k: (batch_size, seqlen_k, nheads_k, d)
v: (batch_size, seqlen_k, nheads_k, d)
query_padding_mask: (batch_size, seqlen), bool
key_padding_mask: (batch_size, seqlen), bool
"""
assert not (kvpacked and qkvpacked)
batch_size, seqlen_q, nheads, d = q.shape
_, seqlen_k, nheads_k, _ = k.shape
assert k.shape == (batch_size, seqlen_k, nheads_k, d)
assert v.shape == (batch_size, seqlen_k, nheads_k, d)
if query_padding_mask is not None:
q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, query_padding_mask)
output_pad_fn = lambda output_unpad: pad_input(
output_unpad, indices_q, batch_size, seqlen_q
)
else:
q_unpad = rearrange(q, "b s h d -> (b s) h d")
cu_seqlens_q = torch.arange(
0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device
)
max_seqlen_q = seqlen_q
output_pad_fn = lambda output_unpad: rearrange(
output_unpad, "(b s) h d -> b s h d", b=batch_size
)
if key_padding_mask is not None:
k_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
else:
k_unpad = rearrange(k, "b s h d -> (b s) h d")
v_unpad = rearrange(v, "b s h d -> (b s) h d")
cu_seqlens_k = torch.arange(
0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device
)
max_seqlen_k = seqlen_k
if qkvpacked:
assert (query_padding_mask == key_padding_mask).all()
assert nheads == nheads_k
qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
qkv = torch.stack([q, k, v], dim=2)
if query_padding_mask is not None:
dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q)
else:
dqkv_pad_fn = lambda dqkv_unpad: rearrange(
dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
)
return (
qkv_unpad.detach().requires_grad_(),
cu_seqlens_q,
max_seqlen_q,
qkv.detach().requires_grad_(),
output_pad_fn,
dqkv_pad_fn,
)
elif kvpacked:
kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
kv = torch.stack([k, v], dim=2)
dq_pad_fn = output_pad_fn
if key_padding_mask is not None:
dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k)
else:
dkv_pad_fn = lambda dkv_unpad: rearrange(
dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
)
return (
q_unpad.detach().requires_grad_(),
kv_unpad.detach().requires_grad_(),
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q.detach().requires_grad_(),
kv.detach().requires_grad_(),
output_pad_fn,
dq_pad_fn,
dkv_pad_fn,
)
else:
dq_pad_fn = output_pad_fn
if key_padding_mask is not None:
dk_pad_fn = lambda dk_unpad: pad_input(dk_unpad, indices_k, batch_size, seqlen_k)
else:
dk_pad_fn = lambda dk_unpad: rearrange(dk_unpad, "(b s) h d -> b s h d", b=batch_size)
return (
q_unpad.detach().requires_grad_(),
k_unpad.detach().requires_grad_(),
v_unpad.detach().requires_grad_(),
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q.detach().requires_grad_(),
k.detach().requires_grad_(),
v.detach().requires_grad_(),
output_pad_fn,
dq_pad_fn,
dk_pad_fn,
)
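With kvpacked=False, generate_qkv returns the packed tensors, the cumulative offsets, and the pad/unpad closures used by the varlen test. A short usage sketch, assuming the helpers above are importable and CUDA is available:

import torch

torch.random.manual_seed(0)
b, sq, sk, h, d = 2, 64, 128, 6, 128
q = torch.randn(b, sq, h, d, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn(b, sk, h, d, device="cuda", dtype=torch.float16, requires_grad=True)
v = torch.randn(b, sk, h, d, device="cuda", dtype=torch.float16, requires_grad=True)
qm = generate_random_padding_mask(sq, b, "cuda", mode="random")
km = generate_random_padding_mask(sk, b, "cuda", mode="random")
(q_unpad, k_unpad, v_unpad, cu_q, cu_k, max_q, max_k,
 q, k, v, output_pad_fn, dq_pad_fn, dk_pad_fn) = generate_qkv(q, k, v, qm, km)
assert q_unpad.shape[0] == int(qm.sum())  # one row per unmasked query token
assert int(cu_q[-1]) == q_unpad.shape[0]  # offsets end at the packed length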
def construct_local_mask(
seqlen_q,
seqlen_k,
window_size=(-1, -1), # -1 means infinite window size
query_padding_mask=None,
key_padding_mask=None,
device=None,
key_leftpad=None,
):
row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
if key_leftpad is not None:
key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1")
col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0])
col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)
sk = (
seqlen_k
if key_padding_mask is None
else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
)
sq = (
seqlen_q
if query_padding_mask is None
else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
)
if window_size[0] < 0:
return col_idx > row_idx + sk - sq + window_size[1]
else:
sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
return torch.logical_or(
col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),
col_idx < row_idx + sk - sq - window_size[0],
)
def attention_ref(
q,
k,
v,
query_padding_mask=None,
key_padding_mask=None,
attn_bias=None,
dropout_p=0.0,
dropout_mask=None,
causal=False,
window_size=(-1, -1), # -1 means infinite window size
softcap=0.0,
upcast=True,
reorder_ops=False,
key_leftpad=None,
):
"""
Arguments:
q: (batch_size, seqlen_q, nheads, head_dim)
k: (batch_size, seqlen_k, nheads_k, head_dim)
v: (batch_size, seqlen_k, nheads_k, head_dim)
query_padding_mask: (batch_size, seqlen_q)
key_padding_mask: (batch_size, seqlen_k)
attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
dropout_p: float
dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
causal: whether to apply causal masking
window_size: (int, int), left and right window size
upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
output back to fp16/bf16.
reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
without changing the math. This is to estimate the numerical error from operation
reordering.
Output:
output: (batch_size, seqlen_q, nheads, head_dim)
attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout
"""
if causal:
window_size = (window_size[0], 0)
dtype_og = q.dtype
if upcast:
q, k, v = q.float(), k.float(), v.float()
seqlen_q, seqlen_k = q.shape[1], k.shape[1]
k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
d = q.shape[-1]
if not reorder_ops:
scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
else:
scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d))
if softcap > 0:
scores /= softcap
scores = scores.tanh()
scores *= softcap
if key_padding_mask is not None:
scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
if window_size[0] >= 0 or window_size[1] >= 0:
local_mask = construct_local_mask(
seqlen_q,
seqlen_k,
window_size,
query_padding_mask,
key_padding_mask,
q.device,
key_leftpad=key_leftpad,
)
scores.masked_fill_(local_mask, float("-inf"))
if attn_bias is not None:
scores = scores + attn_bias
attention = torch.softmax(scores, dim=-1).to(v.dtype)
# Some rows might be completely masked out so we fill them with zero instead of NaN
if window_size[0] >= 0 or window_size[1] >= 0:
attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0)
# We want to mask here so that the attention matrix doesn't have any NaNs
# Otherwise we'll get NaN in dV
if query_padding_mask is not None:
attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
dropout_scaling = 1.0 / (1 - dropout_p)
# attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling
# output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
if dropout_mask is not None:
attention_drop = attention.masked_fill(~dropout_mask, 0.0)
else:
attention_drop = attention
output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
if query_padding_mask is not None:
output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
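The tests accept the kernel output when its error against the fp32 reference is at most twice the error of a same-precision PyTorch reference. That acceptance rule, factored into a small helper (illustrative, not part of the test suite):

def within_fa_tolerance(out, out_ref, out_pt):
    # out: kernel output; out_ref: upcast (fp32) reference; out_pt: same-dtype
    # PyTorch reference. Mirrors the asserts in the tests above.
    return (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()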