Commit 9e5e8bc9 authored by Tri Dao

Change causal mask to be aligned to bottom-right instead of top-left

parent e07aa036
@@ -136,6 +136,32 @@ flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False)
```python
flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False)
```
+## Changes in v2.1 (compared to v2.0)
+If seqlen_q != seqlen_k and causal=True, the causal mask is aligned to the
+bottom right corner of the attention matrix, instead of the top-left corner.
+For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
+v2.0:
+1 0 0 0 0
+1 1 0 0 0
+v2.1:
+1 1 1 1 0
+1 1 1 1 1
+If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
+v2.0:
+1 0
+1 1
+1 1
+1 1
+1 1
+v2.1:
+0 0
+0 0
+0 0
+1 0
+1 1
+If the row of the mask is all zero, the output will be zero.
## Performance
......
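As a reference sketch (not part of this commit), the bottom-right aligned mask described in the README hunk above can be built in a few lines of plain PyTorch; the helper name `causal_mask_bottom_right` is hypothetical. It reproduces both examples:

```python
import torch

def causal_mask_bottom_right(seqlen_q: int, seqlen_k: int) -> torch.Tensor:
    # Query row i may attend to key column j iff j <= i + seqlen_k - seqlen_q,
    # i.e. the mask is aligned to the bottom-right corner of the attention matrix.
    row = torch.arange(seqlen_q).unsqueeze(1)  # (seqlen_q, 1)
    col = torch.arange(seqlen_k).unsqueeze(0)  # (1, seqlen_k)
    return (col <= row + seqlen_k - seqlen_q).int()

print(causal_mask_bottom_right(2, 5))
# tensor([[1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1]], dtype=torch.int32)
print(causal_mask_bottom_right(5, 2))
# tensor([[0, 0],
#         [0, 0],
#         [0, 0],
#         [1, 0],
#         [1, 1]], dtype=torch.int32)
```

The rule is simply that query row i may attend to key column j whenever j <= i + seqlen_k - seqlen_q, which reduces to the familiar j <= i when seqlen_q == seqlen_k.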
@@ -15,12 +15,7 @@ from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
# from triton.ops.flash_attention import attention as attention_triton
-try:
-from fav2 import flash_attn_qkvpacked_func as fav2_qkvpacked_func
-from fav2 import flash_attn_kvpacked_func as fav2_kvpacked_func
-except ImportError:
-fav2_qkvpacked_func = None
-fav2_kvpacked_func = None
+from flash_attn import flash_attn_qkvpacked_func, flash_attn_kvpacked_func
try:
from flash_attn.fused_softmax import scaled_upper_triang_masked_softmax
@@ -80,8 +75,8 @@ def attention_megatron(qkv):
torch.manual_seed(0)
repeats = 30
-batch_size = 2
-seqlen = 8192
+batch_size = 8
+seqlen = 2048
nheads = 12
headdim = 128
# nheads = 24
@@ -90,8 +85,8 @@ headdim = 128
# seqlen = 512
# nheads = 8
# headdim = 128
-dropout_p = 0.1
-causal = False
+dropout_p = 0.0
+causal = True
dtype = torch.float16
device = 'cuda'
@@ -100,20 +95,20 @@ qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=d
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
device=qkv.device)
-# qkv_unpad = rearrange(qkv, 'b s ... -> (b s) ...').detach().requires_grad_(True)
+qkv_unpad = rearrange(qkv, 'b s ... -> (b s) ...').detach().requires_grad_(True)
# benchmark_all(flash_attn_varlen_qkvpacked_func, qkv_unpad,
# cu_seqlens, seqlen, dropout_p, causal=causal, repeats=repeats, desc='FlashAttention')
# pytorch_profiler(flash_attn_varlen_qkvpacked_func, qkv_unpad,
# cu_seqlens, seqlen, dropout_p, causal=causal, backward=True)
-# if fav2_qkvpacked_func is not None:
-# benchmark_all(fav2_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats, desc='Fav2')
-# pytorch_profiler(fav2_qkvpacked_func, qkv, dropout_p, causal=causal, backward=True)
+benchmark_forward(flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats, desc='Fav2')
+pytorch_profiler(flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, backward=False)
# for dropout_p in [0.1, 0.0]:
# for causal in [False, True]:
# print(f"### {dropout_p = }, {causal = } ###")
# pytorch_profiler(fav2_qkvpacked_func, qkv, dropout_p, causal=causal, backward=True)
# nheads_k = 2
# q = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
# kv = torch.randn(batch_size, seqlen, 2, nheads_k, headdim, device=device, dtype=dtype,
@@ -151,6 +146,7 @@ cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch
flops = 4 * batch_size * seqlen ** 2 * nheads * headdim
ideal_a100_time = flops / 312 / 1e9
print(f"Ideal A100 fwd time: {ideal_a100_time:.3f}ms, bwd time: {ideal_a100_time * 2.5:.3f}ms")
+exit(0)
def time_fwd_bwd(func, *args, **kwargs):
......
@@ -32,8 +32,8 @@ struct BlockInfo {
const int sum_s_q;
const int sum_s_k;
-const uint32_t actual_seqlen_q;
-const uint32_t actual_seqlen_k;
+const int actual_seqlen_q;
+const int actual_seqlen_k;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
......
@@ -659,46 +659,14 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
tdQgdQaccum.data() = tdQgdQaccum.data() + kBlockM * params.d_rounded;
int m_block = m_block_max - 1;
-int m_block_min = !Is_causal ? 0 : (n_block * kBlockN - int(binfo.actual_seqlen_k - binfo.actual_seqlen_q)) / kBlockM;
-m_block_min = m_block_min < 0 ? 0 : m_block_min;
-// We might need to exit early and write 0 to dK and dV.
-// Otherwise we get wrong result for the case where we don't enter the for loop.
-// And we might read OOB elements from gQ and gdO.
-// TODO: what if we're not parallelizing, do we need to compute dot_do_o?
-if (Is_causal && m_block < m_block_min) {
-const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb)
-+ n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride;
-const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb)
-+ n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride;
-Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dk_ptr) + row_offset_dk),
-Shape<Int<kBlockN>, Int<kHeadDim>>{},
-make_stride(params.dk_row_stride, _1{}));
-Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dv_ptr) + row_offset_dv),
-Shape<Int<kBlockN>, Int<kHeadDim>>{},
-make_stride(params.dv_row_stride, _1{}));
-typename Kernel_traits::GmemTiledCopydKV gmem_tiled_copy_dKV;
-auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx);
-Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK);
-Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV);
-Tensor tdKrdK = make_tensor<Element>(shape(tdKgdK));
-Tensor tdVrdV = make_tensor<Element>(shape(tdVgdV));
-clear(tdKrdK);
-clear(tdVrdV);
-Tensor cdKV = make_identity_tensor(make_shape(size<0>(gdK), size<1>(gdK))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
-Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
-Tensor tdKVpdKV = make_tensor<bool>(make_shape(size<2>(tdKgdK)));
-#pragma unroll
-for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; }
-// Clear_OOB_K must be false since we don't want to write zeros to gmem
-flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
-gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
-);
-flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
-gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
-);
-return;
-}
+int m_block_min = !Is_causal ? 0 : std::max(0, (n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k) / kBlockM);
+// We're guaranteed that m_block_min <= m_block:
+// We checked earlier that n_block * kBlockN < actual_seqlen_k, so in the causal case,
+// n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k < actual_seqlen_q.
+// So m_block_min <= (actual_seqlen_q - 1) / kBlockM.
+// Recall that m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM) = (actual_seqlen_q + kBlockM - 1) / kBlockM.
+// So m_block_max - 1 = (actual_seqlen_q - 1) / kBlockM.
+// We conclude that m_block_min <= m_block, so we will always have at least 1 iteration of the for loop.
if (Double_buffer && m_block % 2 == 1) { // Double buffer for sQ
tQsQ.data() = tQsQ.data() + size(sQ);
@@ -743,7 +711,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
Tensor lse = make_tensor<ElementAccum>(Shape<Int<decltype(size(taccScS_row))::value>>{});
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) {
-// Using uint32_t row makes it 10us slower on d=128, not sure why.
const int row = get<0>(taccScS_row(mi));
lse(mi) = Is_even_MN || row < binfo.actual_seqlen_q - m_block * kBlockM ? gLSE(row) : 0;
}
@@ -824,11 +791,11 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
// TD [2023-08-16]: We need the 2nd condition because if seqlen_q is long and seqlen_k is short
// (e.g., 256 and 2), the 2nd block of seqlen_q (from 128 to 255), we're not doing causal masking.
// But we still want to mask out elements beyond actual_seqlen_k.
-if (m_block * kBlockM < (n_block + 1) * kBlockN
+if (m_block * kBlockM < (n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k
|| (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k)) {
flash::apply_mask_causal(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
-binfo.actual_seqlen_q, binfo.actual_seqlen_k,
-m_block * kBlockM + get<0>(taccScS_row(0)),
+binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)),
+binfo.actual_seqlen_q,
// binfo.actual_seqlen_k, m_block * kBlockM + (tidx / 32) % AtomLayoutMS * 16 + (tidx % 32) / 4,
AtomLayoutMS * 16);
}
@@ -837,11 +804,11 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
// Compute the exponential value.
flash::scale_apply_exp2</*scale_max=*/false>(scores, lse, params.scale_softmax_log2);
if (Is_dropout) {
-uint32_t warp_id = tidx / 32;
-uint32_t block_row_idx = m_block * (kBlockM / 16) + warp_id % AtomLayoutMS;
+int warp_id = tidx / 32;
+int block_row_idx = m_block * (kBlockM / 16) + warp_id % AtomLayoutMS;
// Need col to be multiples of 32, since we're doing dropout with block of 16 x 32
static_assert(MMA_N_SdP % 2 == 0);
-uint32_t block_col_idx = n_block * (kBlockN / 32) + (warp_id / AtomLayoutMS) * (MMA_N_SdP / 2);
+int block_col_idx = n_block * (kBlockN / 32) + (warp_id / AtomLayoutMS) * (MMA_N_SdP / 2);
Tensor scores_dropped = make_tensor(scores.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMmaSdP>(scores.layout()));
flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
scores_dropped, params.p_dropout_in_uint8_t, seed, offset,
@@ -1341,7 +1308,6 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
Tensor lse = make_tensor<ElementAccum>(Shape<Int<decltype(size(taccScS_row))::value>>{});
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) {
-// Using uint32_t row makes it 10us slower on d=128, not sure why.
const int row = get<0>(taccScS_row(mi));
lse(mi) = row < binfo.actual_seqlen_q - m_block * kBlockM ? gLSE(row) : 0;
}
@@ -1379,18 +1345,19 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
// the corresponding values of K would be 0, so the result would still be correct.
if (Is_causal && m_block * kBlockM < (n_block + 1) * kBlockN) {
flash::apply_mask_causal(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
-binfo.actual_seqlen_q, binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)),
+binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)),
// binfo.actual_seqlen_k, m_block * kBlockM + (tidx / 32) % AtomLayoutMS * 16 + (tidx % 32) / 4,
+binfo.actual_seqlen_q,
AtomLayoutMS * 16);
}
// Compute the exponential value.
flash::scale_apply_exp2</*scale_max=*/false>(scores, lse, params.scale_softmax_log2);
if (Is_dropout) {
-uint32_t warp_id = tidx / 32;
-uint32_t block_row_idx = m_block * (kBlockM / 16) + warp_id % AtomLayoutMS;
+int warp_id = tidx / 32;
+int block_row_idx = m_block * (kBlockM / 16) + warp_id % AtomLayoutMS;
// Need col to be multiples of 32, since we're doing dropout with block of 16 x 32
static_assert(MMA_N_SdP % 2 == 0);
-uint32_t block_col_idx = n_block * (kBlockN / 32) + (warp_id / AtomLayoutMS) * (MMA_N_SdP / 2);
+int block_col_idx = n_block * (kBlockN / 32) + (warp_id / AtomLayoutMS) * (MMA_N_SdP / 2);
Tensor scores_dropped = make_tensor(scores.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMmaSdP>(scores.layout()));
flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
scores_dropped, params.p_dropout_in_uint8_t, seed, offset,
......
@@ -118,7 +118,7 @@ inline __device__ void write_softmax_to_gmem(
////////////////////////////////////////////////////////////////////////////////////////////////////
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax, typename Params>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {
using Element = typename Kernel_traits::Element;
@@ -130,8 +130,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// The thread index.
const int tidx = threadIdx.x;
-// The global block index.
-const int block_id = blockIdx.x + blockIdx.y * gridDim.x + gridDim.x * gridDim.y * blockIdx.z;
constexpr int kBlockM = Kernel_traits::kBlockM;
constexpr int kBlockN = Kernel_traits::kBlockN;
@@ -139,16 +137,60 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
constexpr int kNWarps = Kernel_traits::kNWarps;
constexpr int MMA_M = kBlockM / decltype(size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value;
-const BlockInfo</*Varlen=*/!Is_even_N> binfo(params, bidb);
+const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return;
int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN);
if (Is_causal) {
-n_block_max = std::min(n_block_max, cute::ceil_div(
-(m_block + 1) * kBlockM + int(binfo.actual_seqlen_k - binfo.actual_seqlen_q), kBlockN));
+n_block_max = std::min(n_block_max,
+cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q, kBlockN));
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
// printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max);
// }
+// We exit early and write 0 to gO and gLSE.
+// Otherwise we might read OOB elements from gK and gV.
+if (n_block_max <= 0) {
+// Save seed and offset for backward. If we don't have this here, the 0-th thread block might
+// exit early and no one saves the rng state.
+if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) {
+auto seeds = at::cuda::philox::unpack(params.philox_args);
+params.rng_state[0] = std::get<0>(seeds);
+params.rng_state[1] = std::get<1>(seeds);
+}
+const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
++ m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
+const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
+Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
+Shape<Int<kBlockM>, Int<kHeadDim>>{},
+make_stride(params.o_row_stride, _1{}));
+Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
+Shape<Int<kBlockM>>{}, Stride<_1>{});
+typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
+auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
+Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
+Tensor tOrO = make_tensor<Element>(shape(tOgO));
+clear(tOrO);
+// Construct identity layout for sO
+Tensor cO = make_identity_tensor(make_shape(size<0>(gO), size<1>(gO))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
+// Repeat the partitioning with identity layouts
+Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
+Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
+if (!Is_even_K) {
+#pragma unroll
+for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
+}
+// Clear_OOB_K must be false since we don't want to write zeros to gmem
+flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
+);
+#pragma unroll
+for (int m = 0; m < size<1>(tOgO); ++m) {
+const int row = get<0>(tOcO(0, m, 0));
+if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSE(row) = INFINITY; }
+}
+return;
+}
}
// We iterate over the blocks in reverse order. This is because the last block is the only one
@@ -275,7 +317,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
Tensor tQrQ = make_fragment_like(tQgQ);
// We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
-flash::copy</*Is_even_MN=*/false, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
+flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
binfo.actual_seqlen_q - m_block * kBlockM);
if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); }
@@ -298,7 +340,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
int n_block = n_block_max - 1;
// We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
-flash::copy<Is_even_N, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
+flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
binfo.actual_seqlen_k - n_block * kBlockN);
cute::cp_async_fence();
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); }
@@ -317,7 +359,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
unsigned long long offset = std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32;
// Save seed and offset for backward.
-if (block_id == 0 && tidx == 0) {
+if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) {
params.rng_state[0] = seed;
params.rng_state[1] = std::get<1>(seeds);
}
@@ -330,7 +372,11 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
// We will have at least 1 "masking" iteration.
-constexpr int n_masking_steps = Is_causal ? cute::ceil_div(kBlockM, kBlockN) : 1;
+// If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
+// mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
+constexpr int n_masking_steps = !Is_causal
+? 1
+: (Is_even_MN ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1);
#pragma unroll
for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) {
Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N)
@@ -344,7 +390,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
} else {
// Clear the smem tiles to account for predicated off loads
-flash::copy<Is_even_N, Is_even_K, /*Clear_OOB_MN=*/true>(
+flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
);
}
@@ -363,7 +409,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul
// can produce Inf / NaN.
if (!Is_causal) {
-if (!Is_even_N) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); }
+if (!Is_even_MN) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); }
} else {
// Tensor caccS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{}); // (BLK_M,BLK_N) -> (blk_m,blk_n)
// Tensor taccScS = thr_mma.partition_C(caccS); // (MMA,MMA_M,MMA_N)
@@ -376,9 +422,10 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// Idk why it's get<1> and not get<0> of the stride.
// if (cute::thread0()) { print(idx_row.layout()); print(stride<1>(idx_row)); printf("stride = %d \n", get<1>(stride<1>(idx_row))); }
// I can't get the stride from idx_row
-flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_q, binfo.actual_seqlen_k,
+flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k,
// m_block * kBlockM + get<0>(idx_row(0)),
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
+binfo.actual_seqlen_q,
kNWarps * 16);
// m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16);
// m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16);
@@ -405,8 +452,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
// if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
-uint32_t block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
-uint32_t block_col_idx = n_block * (kBlockN / 32);
+int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
+int block_col_idx = n_block * (kBlockN / 32);
if (Return_softmax) {
Tensor tOrP_copy = make_fragment_like(tOrP);
cute::copy(tOrP, tOrP_copy);
@@ -468,8 +515,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
// if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
-uint32_t block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
-uint32_t block_col_idx = n_block * (kBlockN / 32);
+int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
+int block_col_idx = n_block * (kBlockN / 32);
if (Return_softmax) {
Tensor tOrP_copy = make_fragment_like(tOrP);
cute::copy(tOrP, tOrP_copy);
@@ -563,14 +610,14 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
}
// Clear_OOB_K must be false since we don't want to write zeros to gmem
-flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax, typename Params>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
inline __device__ void compute_attn(const Params &params) {
const int m_block = blockIdx.x;
// The block index for the batch.
@@ -586,7 +633,7 @@ inline __device__ void compute_attn(const Params &params) {
// the attention matrix. This way, as long as we have the batch, head, and the location of
// the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern.
-flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_even_N, Is_even_K, Return_softmax>(params, bidb, bidh, m_block);
+flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, Return_softmax>(params, bidb, bidh, m_block);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
......
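The forward-kernel hunk above recomputes `n_block_max` for the bottom-right aligned mask and adds an early-exit path when a causal query block sees no key blocks at all. Below is a plain-Python sketch of that block-range arithmetic, not part of the commit: `ceil_div`, the helper `n_block_range`, and the 128 x 128 tile sizes are illustrative stand-ins for `cute::ceil_div` and the kernel traits.

```python
# Mirrors the n_block_max computation in compute_attn_1rowblock (see the hunk above),
# using plain Python ints. ceil_div stands in for cute::ceil_div; kBlockM / kBlockN
# are assumed tile sizes, not values read from the kernel traits.
def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

def n_block_range(m_block: int, seqlen_q: int, seqlen_k: int,
                  kBlockM: int = 128, kBlockN: int = 128, causal: bool = True) -> int:
    n_block_max = ceil_div(seqlen_k, kBlockN)
    if causal:
        # Bottom-right alignment: query row i only sees keys up to i + seqlen_k - seqlen_q.
        n_block_max = min(n_block_max,
                          ceil_div((m_block + 1) * kBlockM + seqlen_k - seqlen_q, kBlockN))
    # The kernel checks n_block_max <= 0 and takes the new early-exit path
    # (zero output rows, LSE written as infinity); clamp at 0 here for clarity.
    return max(n_block_max, 0)

# seqlen_q = 512, seqlen_k = 256, causal: the last query block (rows 384-511) visits
# key blocks 0 and 1, while the first query block (rows 0-127) is fully masked.
print(n_block_range(3, 512, 256))  # 2
print(n_block_range(0, 512, 256))  # 0 -> the kernel would exit early for this block
```

With seqlen_q = 512 and seqlen_k = 256, the first query block is entirely masked under bottom-right alignment, which is exactly the case the new `if (n_block_max <= 0)` branch handles by writing zero output rows and an LSE of infinity.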
@@ -10,9 +10,9 @@
#include "flash.h"
#include "flash_fwd_kernel.h"
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Return_softmax>
__global__ void flash_fwd_kernel(Flash_fwd_params params) {
-flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_even_N, Is_even_K, Return_softmax>(params);
+flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, Return_softmax>(params);
}
template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
@@ -26,17 +26,15 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
dim3 grid(num_m_block, params.b, params.h);
-// We also use is_even_N to set Unpadded in the BlockInfo constructor, so we need to check
-// for cu_seqlens_q as well.
-const bool is_even_N = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0;
+const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
const bool is_even_K = params.d == Kernel_traits::kHeadDim;
const bool return_softmax = params.p_ptr != nullptr;
-BOOL_SWITCH(is_even_N, IsEvenNConst, [&] {
+BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
// Will only return softmax if dropout, to reduce compilation time.
-auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
-// auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, true, ReturnSoftmaxConst && Is_dropout>;
+auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenMNConst, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
+// auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenMNConst, true, ReturnSoftmaxConst && Is_dropout>;
if (smem_size >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
......
@@ -117,18 +117,18 @@ inline __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tens
}
template <typename Engine, typename Layout>
-inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const uint32_t max_seqlen_k,
-const uint32_t col_idx_offset_ = 0) {
+inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_seqlen_k,
+const int col_idx_offset_ = 0) {
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
static_assert(Layout::rank == 2, "Only support 2D Tensor");
-const uint32_t lane_id = threadIdx.x % 32;
-const uint32_t col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
+const int lane_id = threadIdx.x % 32;
+const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
#pragma unroll
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
-const uint32_t col_idx_base = col_idx_offset + nj * 8;
+const int col_idx_base = col_idx_offset + nj * 8;
#pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) {
-const uint32_t col_idx = col_idx_base + j;
+const int col_idx = col_idx_base + j;
if (col_idx >= max_seqlen_k) {
// Without the "make_coord" we get wrong results
#pragma unroll
@@ -141,28 +141,28 @@ inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const uint32_t
}
template <typename Engine, typename Layout>
-inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const uint32_t col_idx_offset_,
-const uint32_t max_seqlen_q, const uint32_t max_seqlen_k,
-const uint32_t row_idx_offset_, const uint32_t warp_row_stride) {
+inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
+const int max_seqlen_k, const int row_idx_offset_,
+const int max_seqlen_q, const int warp_row_stride) {
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
static_assert(Layout::rank == 2, "Only support 2D Tensor");
-const uint32_t lane_id = threadIdx.x % 32;
-// const uint32_t row_idx_offset = row_idx_offset_ + lane_id / 4;
-const uint32_t row_idx_offset = row_idx_offset_;
-const uint32_t col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
+const int lane_id = threadIdx.x % 32;
+// const int row_idx_offset = row_idx_offset_ + lane_id / 4;
+const int row_idx_offset = row_idx_offset_;
+const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
#pragma unroll
for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
-const uint32_t row_idx_base = row_idx_offset + mi * warp_row_stride;
+const int row_idx_base = row_idx_offset + mi * warp_row_stride;
#pragma unroll
for (int i = 0; i < size<0, 0>(tensor); ++i) {
-const uint32_t row_idx = row_idx_base + i * 8;
-const uint32_t col_idx_limit = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q);
+const int row_idx = row_idx_base + i * 8;
+const int col_idx_limit = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q);
#pragma unroll
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
-const uint32_t col_idx_base = col_idx_offset + nj * 8;
+const int col_idx_base = col_idx_offset + nj * 8;
#pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) {
-const uint32_t col_idx = col_idx_base + j;
+const int col_idx = col_idx_base + j;
if (col_idx >= col_idx_limit) {
tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
}
@@ -180,7 +180,7 @@ inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const u
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
inline __device__ void apply_mask_causal_w_idx(
Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &idx_rowcol,
-const uint32_t col_idx_offset_, const uint32_t max_seqlen_k, const uint32_t row_idx_offset_)
+const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset_)
{
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
@@ -189,7 +189,7 @@ inline __device__ void apply_mask_causal_w_idx(
CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
-const uint32_t col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset_ + get<0>(idx_rowcol(mi, 0)));
+const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset_ + get<0>(idx_rowcol(mi, 0)));
#pragma unroll
for (int ni = 0; ni < size<1, 1>(tensor); ++ni) {
if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) {
@@ -207,8 +207,8 @@ inline __device__ void apply_mask_causal_w_idx(
template <bool encode_dropout_in_sign_bit=false, typename Engine, typename Layout>
inline __device__ void apply_dropout(Tensor<Engine, Layout> &tensor, uint8_t p_dropout_in_uint8_t,
unsigned long long seed, unsigned long long offset,
-uint32_t block_row_start, uint32_t block_col_start,
-uint32_t block_row_stride) {
+int block_row_start, int block_col_start,
+int block_row_stride) {
// tensor has shape (8, MMA_M, MMA_N / 2)
using T = typename Engine::value_type;
auto encode_dropout = [](bool keep, T val) {
......
__version__ = "2.0.9" __version__ = "2.1.0"
from flash_attn.flash_attn_interface import ( from flash_attn.flash_attn_interface import (
flash_attn_func, flash_attn_func,
......
@@ -528,6 +528,18 @@ def flash_attn_kvpacked_func(
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
+If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
+For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
+1 1 1 1 0
+1 1 1 1 1
+If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
+0 0
+0 0
+0 0
+1 0
+1 1
+If the row of the mask is all zero, the output will be zero.
Arguments:
q: (batch_size, seqlen, nheads, headdim)
kv: (batch_size, seqlen, 2, nheads_k, headdim)
@@ -559,6 +571,18 @@ def flash_attn_func(
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
+If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
+For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
+1 1 1 1 0
+1 1 1 1 1
+If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
+0 0
+0 0
+0 0
+1 0
+1 1
+If the row of the mask is all zero, the output will be zero.
Arguments:
q: (batch_size, seqlen, nheads, headdim)
k: (batch_size, seqlen, nheads_k, headdim)
@@ -645,6 +669,18 @@ def flash_attn_varlen_kvpacked_func(
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
+If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
+For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
+1 1 1 1 0
+1 1 1 1 1
+If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
+0 0
+0 0
+0 0
+1 0
+1 1
+If the row of the mask is all zero, the output will be zero.
Arguments:
q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch.
@@ -703,6 +739,18 @@ def flash_attn_varlen_func(
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
+If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
+For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
+1 1 1 1 0
+1 1 1 1 1
+If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
+0 0
+0 0
+0 0
+1 0
+1 1
+If the row of the mask is all zero, the output will be zero.
Arguments:
q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
......
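The docstring additions above describe how `flash_attn_func` and its variants interpret `causal=True` when `seqlen_q != seqlen_k`. As a small sketch of the intended semantics (not part of the commit; it assumes flash-attn >= 2.1 and a CUDA device, and the shapes are arbitrary illustration choices), the call can be compared against a plain PyTorch reference that applies the same bottom-right aligned mask:

```python
import math
import torch
from flash_attn import flash_attn_func  # exported as shown in the diff above

batch, seqlen_q, seqlen_k, nheads, headdim = 2, 2, 5, 4, 64
device, dtype = "cuda", torch.float16
q = torch.randn(batch, seqlen_q, nheads, headdim, device=device, dtype=dtype)
k = torch.randn(batch, seqlen_k, nheads, headdim, device=device, dtype=dtype)
v = torch.randn(batch, seqlen_k, nheads, headdim, device=device, dtype=dtype)

out = flash_attn_func(q, k, v, causal=True)

# Reference: softmax(q @ k^T / sqrt(d)) @ v with the bottom-right aligned causal mask.
row = torch.arange(seqlen_q, device=device).unsqueeze(1)
col = torch.arange(seqlen_k, device=device).unsqueeze(0)
keep = col <= row + seqlen_k - seqlen_q            # (seqlen_q, seqlen_k), True = keep
scores = torch.einsum("bqhd,bkhd->bhqk", q.float(), k.float()) / math.sqrt(headdim)
scores = scores.masked_fill(~keep, float("-inf"))
ref = torch.einsum("bhqk,bkhd->bqhd", scores.softmax(dim=-1), v.float())

print((out.float() - ref).abs().max())  # expected to be small (fp16 rounding)
```

This sketch deliberately uses seqlen_q < seqlen_k: when seqlen_q > seqlen_k, the first seqlen_q - seqlen_k query rows are fully masked and, per the docstrings above, flash-attn returns zeros for those rows, whereas a naive softmax reference would produce NaN there.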
(One file's diff in this commit is collapsed and not shown.)
@@ -89,7 +89,7 @@ RUN pip install flash-attn==2.0.9
# Install CUDA extensions for cross-entropy, fused dense, layer norm
RUN git clone https://github.com/HazyResearch/flash-attention \
-&& cd flash-attention && git checkout v2.0.9 \
+&& cd flash-attention && git checkout v2.1.0 \
&& cd csrc/fused_softmax && pip install . && cd ../../ \
&& cd csrc/rotary && pip install . && cd ../../ \
&& cd csrc/xentropy && pip install . && cd ../../ \
......