Commit 3524e13c authored by Tri Dao

Update to Cutlass 3.1

parent 364a5b4a
- Subproject commit c4f6b8c6bc94ff69048492fb34df0dfaf1983933
+ Subproject commit 6f47420213f757831fae65c686aa471749fa8d60
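Note: nearly every kernel change below follows one pattern from the CuTe API in Cutlass 3.1: cute::copy is now called with the TiledCopy object itself, while the per-thread ThrCopy returned by get_thread_slice(tidx) is kept only for partitioning and retiling tensors. A minimal sketch of that pattern, with illustrative names (load_tile, src, dst, tidx are assumptions for this note, not identifiers from the diff):

// Sketch only, assuming Cutlass 3.1 headers are available; shows the
// tiled-copy / thread-copy split used throughout this commit.
#include <cute/tensor.hpp>

template <typename TiledCopy, typename SrcTensor, typename DstTensor>
inline __device__ void load_tile(TiledCopy tiled_copy, SrcTensor const &src, DstTensor &dst, int tidx) {
    // The thread slice is only used to carve out this thread's pieces of the tensors...
    auto thr_copy = tiled_copy.get_thread_slice(tidx);
    auto tSrc = thr_copy.partition_S(src);  // this thread's partition of the source
    auto tDst = thr_copy.partition_D(dst);  // this thread's partition of the destination
    // ...while the TiledCopy itself is what cute::copy now takes as its first argument.
    cute::copy(tiled_copy, tSrc, tDst);
}

The same split is why flash::gemm and flash::gemm_A_in_regs below now take both the tiled copies (forwarded to cute::copy) and the thread copies (used for retile_D), and why the flash::copy parameter is renamed from thr_copy to tiled_copy.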
@@ -77,7 +77,7 @@ inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, T
         flash::reduce_sum(scores, scores_sum);
     } else {
         Tensor scores_max_prev = make_fragment_like(scores_max);
-        copy(scores_max, scores_max_prev);
+        cute::copy(scores_max, scores_max_prev);
         flash::template reduce_max</*zero_init=*/false>(scores, scores_max);
         // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
         Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
@@ -103,7 +103,7 @@ inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, T
 template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename TiledCopy>
 inline __device__ void write_softmax_to_gmem(
-    Tensor<Engine0, Layout0> const &tOrP, Tensor<Engine1, Layout1> &tPgP, TiledCopy gmem_thr_copy_P
+    Tensor<Engine0, Layout0> const &tOrP, Tensor<Engine1, Layout1> &tPgP, TiledCopy gmem_tiled_copy_P
 ) {
     // Reshape tOrP from (8, MMA_M, MMA_N) to (8, MMA_M * MMA_N)
     Layout l = tOrP.layout();
@@ -112,7 +112,7 @@ inline __device__ void write_softmax_to_gmem(
     CUTE_STATIC_ASSERT_V(size<1>(tPrP) == size<1>(tPgP));
     #pragma unroll
     for (int mi = 0; mi < size<1>(tPrP); ++mi) {
-        copy(gmem_thr_copy_P, tPrP(_, mi), tPgP(_, mi, 0));
+        cute::copy(gmem_tiled_copy_P, tPrP(_, mi), tPgP(_, mi, 0));
     }
 };
@@ -186,8 +186,10 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
     Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
-    auto gmem_thr_copy_QKV = typename Kernel_traits::GmemTiledCopyQKV{}.get_thread_slice(tidx);
-    auto gmem_thr_copy_P = typename Kernel_traits::GmemTiledCopyP{}.get_thread_slice(tidx);
+    typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
+    auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
+    typename Kernel_traits::GmemTiledCopyP gmem_tiled_copy_P;
+    auto gmem_thr_copy_P = gmem_tiled_copy_P.get_thread_slice(tidx);
     Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
     Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
@@ -209,16 +211,19 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     // Copy Atom retiling
     //
-    auto smem_thr_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx);
+    auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
+    auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
     // auto smem_thr_copy_Q = make_tiled_copy_A_warpcontiguousM<MMA_M>(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx);
     // if (cute::thread0()) {smem_thr_copy_Q.print_all();}
     Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
     // if (cute::thread0()) {print(tSsQ.layout()); printf("\n");}
-    auto smem_thr_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx);
+    auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
+    auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
     Tensor tSsK = smem_thr_copy_K.partition_S(sK);
-    auto smem_thr_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma).get_thread_slice(tidx);
+    auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
+    auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
     Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
     // TODO: this might need to change if we change the mma instruction in SM70
@@ -269,7 +274,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     Tensor tQrQ = make_fragment_like(tQgQ);
     // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
-    flash::copy</*Is_even_MN=*/false, Is_even_K>(gmem_thr_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
+    flash::copy</*Is_even_MN=*/false, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
                                                  binfo.actual_seqlen_q - m_block * kBlockM);
     if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); }
@@ -286,13 +291,13 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         __syncthreads();
         Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
         CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M
-        copy(smem_thr_copy_Q, tSsQ, tSrQ_copy_view);
+        cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
         __syncthreads();
     }
     int n_block = n_block_max - 1;
     // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
-    flash::copy<Is_even_N, Is_even_K>(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
+    flash::copy<Is_even_N, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
                                       binfo.actual_seqlen_k - n_block * kBlockN);
     cute::cp_async_fence();
     // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); }
@@ -303,7 +308,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         __syncthreads();
         Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
         CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M
-        copy(smem_thr_copy_Q, tSsQ, tSrQ_copy_view);
+        cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
     }
     auto seeds = at::cuda::philox::unpack(params.philox_args);
@@ -335,17 +340,18 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         // Advance gV
         if (masking_step > 0) {
             tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
-            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
+            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
         } else {
             // Clear the smem tiles to account for predicated off loads
             flash::copy<Is_even_N, Is_even_K, /*Clear_OOB_MN=*/true>(
-                gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
+                gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
             );
         }
         cute::cp_async_fence();
         flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
-            acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K
+            acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
+            smem_thr_copy_Q, smem_thr_copy_K
         );
         // if (cute::thread0()) { print(acc_s); }
@@ -382,7 +388,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         if (n_block > 0) {
             // Advance gK
             tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
-            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
+            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
             // This cp_async_fence needs to be in the if block, otherwise the synchronization
             // isn't right and we get race conditions.
             cute::cp_async_fence();
@@ -402,12 +408,12 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         uint32_t block_col_idx = n_block * (kBlockN / 32);
         if (Return_softmax) {
             Tensor tOrP_copy = make_fragment_like(tOrP);
-            copy(tOrP, tOrP_copy);
+            cute::copy(tOrP, tOrP_copy);
             flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
                 tOrP_copy, params.p_dropout_in_uint8_t, seed, offset,
                 block_row_idx, block_col_idx, kNWarps
             );
-            flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P);
+            flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_tiled_copy_P);
             tPgP.data() = tPgP.data() + (-kBlockN);
         }
         if (Is_dropout) {
@@ -416,7 +422,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         }
         // if (cute::thread0()) { print(tOrP); }
-        flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_thr_copy_V);
+        flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
         // if (cute::thread0()) { print(scores); }
         // This check is at the end of the loop since we always have at least 1 iteration
@@ -434,11 +440,12 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         __syncthreads();
         // Advance gV
         tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
-        flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
+        flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
         cute::cp_async_fence();
         flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
-            acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K
+            acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
+            smem_thr_copy_Q, smem_thr_copy_K
         );
         flash::cp_async_wait<0>();
@@ -446,7 +453,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         if (n_block > 0) {
             // Advance gK
             tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
-            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
+            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
             // This cp_async_fence needs to be in the if block, otherwise the synchronization
             // isn't right and we get race conditions.
             cute::cp_async_fence();
@@ -464,12 +471,12 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         uint32_t block_col_idx = n_block * (kBlockN / 32);
         if (Return_softmax) {
             Tensor tOrP_copy = make_fragment_like(tOrP);
-            copy(tOrP, tOrP_copy);
+            cute::copy(tOrP, tOrP_copy);
             flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
                 tOrP_copy, params.p_dropout_in_uint8_t, seed, offset,
                 block_row_idx, block_col_idx, kNWarps
             );
-            flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P);
+            flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_tiled_copy_P);
             tPgP.data() = tPgP.data() + (-kBlockN);
         }
         if (Is_dropout) {
@@ -477,7 +484,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
                                  block_row_idx, block_col_idx, kNWarps);
         }
-        flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_thr_copy_V);
+        flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
     }
     // Epilogue
@@ -501,7 +508,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     Tensor rO = flash::convert_type<Element>(acc_o);
     Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N)
     // Partition sO to match the accumulator partitioning
-    auto smem_thr_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx);
+    auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma);
+    auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx);
     // auto smem_thr_copy_O = make_tiled_copy_C_warpcontiguousM<MMA_M>(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx);
     Tensor taccOrO = smem_thr_copy_O.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N)
     Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N)
@@ -509,7 +517,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     // sO has the same size as sQ, so we don't need to sync here.
     if (Kernel_traits::Share_Q_K_smem) { __syncthreads(); }
-    copy(smem_thr_copy_O, taccOrO, taccOsO);
+    cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
     const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
         + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
@@ -520,14 +528,15 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
                               Shape<Int<kBlockM>>{}, Stride<_1>{});
-    auto gmem_thr_copy_O = typename Kernel_traits::GmemTiledCopyO{}.get_thread_slice(tidx);
+    typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
+    auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
     Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N)
     Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
     __syncthreads();
     Tensor tOrO = make_tensor<Element>(shape(tOgO));
-    copy(gmem_thr_copy_O, tOsO, tOrO);
+    cute::copy(gmem_tiled_copy_O, tOsO, tOrO);
     Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
     Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K)
@@ -554,7 +563,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     }
     // Clear_OOB_K must be false since we don't want to write zeros to gmem
     flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
-        gmem_thr_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
+        gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
     );
 }
@@ -173,10 +173,12 @@ static __device__ inline T run(T x, Operator &op) {
 template<bool A_in_regs=false, bool B_in_regs=false, typename Tensor0, typename Tensor1,
          typename Tensor2, typename Tensor3, typename Tensor4,
-         typename TiledMma, typename TiledCopy0, typename TiledCopy1>
+         typename TiledMma, typename TiledCopyA, typename TiledCopyB,
+         typename ThrCopyA, typename ThrCopyB>
 inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA,
                             Tensor4 const& tCsB, TiledMma tiled_mma,
-                            TiledCopy0 smem_thr_copy_A, TiledCopy1 smem_thr_copy_B) {
+                            TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B,
+                            ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) {
     CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M
     CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N
     CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K
@@ -184,13 +186,13 @@ inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3
     CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M
     Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
     CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N
-    if (!A_in_regs) { copy(smem_thr_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); }
-    if (!B_in_regs) { copy(smem_thr_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); }
+    if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); }
+    if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); }
     #pragma unroll
     for (int i = 0; i < size<2>(tCrA); ++i) {
         if (i < size<2>(tCrA) - 1) {
-            if (!A_in_regs) { copy(smem_thr_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); }
-            if (!B_in_regs) { copy(smem_thr_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); }
+            if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); }
+            if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); }
         }
         cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
     }
@@ -199,19 +201,20 @@ inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 template<typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,
-         typename TiledMma, typename TiledCopy>
+         typename TiledMma, typename TiledCopy, typename ThrCopy>
 inline __device__ void gemm_A_in_regs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
-                                      TiledMma tiled_mma, TiledCopy smem_thr_copy_B) {
+                                      TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
+                                      ThrCopy smem_thr_copy_B) {
     CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M
     CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N
     CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K
     Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
     CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N
-    copy(smem_thr_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{}));
+    cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{}));
     #pragma unroll
     for (int i = 0; i < size<2>(tCrA); ++i) {
         if (i < size<2>(tCrA) - 1) {
-            copy(smem_thr_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1));
+            cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1));
         }
         cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
     }
@@ -319,7 +322,7 @@ void cp_async_wait() {
 template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
           typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
           typename Engine2, typename Layout2, typename Engine3, typename Layout3>
-inline __device__ void copy(TiledCopy thr_copy, Tensor<Engine0, Layout0> const &S,
+inline __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
                             Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
                             Tensor<Engine3, Layout3> const &predicate_K, int max_MN=0) {
     CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
@@ -335,13 +338,13 @@ inline __device__ void copy(TiledCopy thr_copy, Tensor<Engine0, Layout0> const &
             #pragma unroll
             for (int k = 0; k < size<2>(S); ++k) {
                 if (Is_even_K || predicate_K(k)) {
-                    copy(thr_copy, S(_, m, k), D(_, m, k));
+                    cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
                 } else if (Clear_OOB_K) {
-                    clear(D(_, m, k));
+                    cute::clear(D(_, m, k));
                 }
             }
         } else if (Clear_OOB_MN) {
-            clear(D(_, m, _));
+            cute::clear(D(_, m, _));
         }
     }
     // TD [2023-04-13]: Strange that the code below can cause race condition.
@@ -350,7 +353,7 @@ inline __device__ void copy(TiledCopy thr_copy, Tensor<Engine0, Layout0> const &
     // #pragma unroll
     // for (int m = 0; m < size<1>(S); ++m) {
     //     if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
-    //         copy(thr_copy, S(_, m, _), D(_, m, _));
+    //         copy(tiled_copy, S(_, m, _), D(_, m, _));
     //     } else if (Clear_OOB_MN) {
     //         clear(D(_, m, _));
     //     }
@@ -362,7 +365,7 @@ inline __device__ void copy(TiledCopy thr_copy, Tensor<Engine0, Layout0> const &
     // #pragma unroll
    // for (int m = 0; m < size<1>(S); ++m) {
     //     if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
-    //         copy(thr_copy, S(_, m, k), D(_, m, k));
+    //         copy(tiled_copy, S(_, m, k), D(_, m, k));
     //     } else if (Clear_OOB_MN) {
     //         clear(D(_, m, k));
     //     }
@@ -783,13 +783,13 @@ def test_flash_attn_varlen_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_
     assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
-# @pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
+@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
-@pytest.mark.parametrize('dtype', [torch.float16])
+# @pytest.mark.parametrize('dtype', [torch.float16])
-# @pytest.mark.parametrize('causal', [False, True])
+@pytest.mark.parametrize('causal', [False, True])
-@pytest.mark.parametrize('causal', [False])
+# @pytest.mark.parametrize('causal', [False])
 # @pytest.mark.parametrize('d', [32, 56, 64, 80, 96, 128])
-# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
+@pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
-@pytest.mark.parametrize('d', [128])
+# @pytest.mark.parametrize('d', [128])
 # @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
 # @pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048])
 @pytest.mark.parametrize('seqlen', [128])