Commit 8d1b169e authored by Tri Dao's avatar Tri Dao
Browse files

Simplify SmemLayoutVtransposed in kernel_traits.h

parent c9861a03
...@@ -975,9 +975,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in ...@@ -975,9 +975,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
// Layout p_l = tPrP.layout(); // Layout p_l = tPrP.layout();
// Tensor tdVrPt = make_tensor(tPrP.data(), make_layout(get<0>(p_l), get<2>(p_l), get<1>(p_l))); // Tensor tdVrPt = make_tensor(tPrP.data(), make_layout(get<0>(p_l), get<2>(p_l), get<1>(p_l)));
// flash::gemm_A_in_regs(acc_dv, tdVrPt, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_thr_copy_QdOt); // flash::gemm_rs(acc_dv, tdVrPt, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_thr_copy_QdOt);
// Tensor tdKrdSt = make_tensor(tdSrdS.data(), tdVrPt.layout()); // Tensor tdKrdSt = make_tensor(tdSrdS.data(), tdVrPt.layout());
// flash::gemm_A_in_regs(acc_dk, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_thr_copy_QdOt); // flash::gemm_rs(acc_dk, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_thr_copy_QdOt);
flash::gemm(acc_dv, tdVrPt, tdVrdO, tdVsPt, tdVsdOt, tiled_mma_dkv, flash::gemm(acc_dv, tdVrPt, tdVrdO, tdVsPt, tdVsdOt, tiled_mma_dkv,
smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt); smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt);
// if (cute::thread0() && n_block == 0 && m_block == 0) { print(tdVrPt); } // if (cute::thread0() && n_block == 0 && m_block == 0) { print(tdVrPt); }
......
...@@ -369,11 +369,11 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi ...@@ -369,11 +369,11 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
if (Has_alibi) { if (Has_alibi) {
flash::apply_alibi<Is_causal>( flash::apply_alibi<Is_causal>(
scores, scores,
n_block * kBlockN, n_block * kBlockN,
binfo.actual_seqlen_k, binfo.actual_seqlen_k,
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
binfo.actual_seqlen_q, binfo.actual_seqlen_q,
kNWarps * 16, kNWarps * 16,
alibi_slope alibi_slope
); );
...@@ -444,7 +444,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi ...@@ -444,7 +444,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
} }
// if (cute::thread0()) { print(tOrP); } // if (cute::thread0()) { print(tOrP); }
flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
// if (cute::thread0()) { print(scores); } // if (cute::thread0()) { print(scores); }
// This check is at the end of the loop since we always have at least 1 iteration // This check is at the end of the loop since we always have at least 1 iteration
...@@ -483,19 +483,19 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi ...@@ -483,19 +483,19 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
if (Has_alibi) { if (Has_alibi) {
flash::apply_alibi<Is_causal>( flash::apply_alibi<Is_causal>(
scores, scores,
n_block * kBlockN, n_block * kBlockN,
binfo.actual_seqlen_k, binfo.actual_seqlen_k,
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
binfo.actual_seqlen_q, binfo.actual_seqlen_q,
kNWarps * 16, kNWarps * 16,
alibi_slope alibi_slope
); );
} }
if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) { if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) {
flash::apply_mask_local( flash::apply_mask_local(
scores, n_block * kBlockN, binfo.actual_seqlen_k, scores, n_block * kBlockN, binfo.actual_seqlen_k,
...@@ -528,7 +528,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi ...@@ -528,7 +528,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
block_row_idx, block_col_idx, kNWarps); block_row_idx, block_col_idx, kNWarps);
} }
flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
} }
// Epilogue // Epilogue
...@@ -977,11 +977,11 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons ...@@ -977,11 +977,11 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
if (Has_alibi) { if (Has_alibi) {
flash::apply_alibi<Is_causal>( flash::apply_alibi<Is_causal>(
scores, scores,
n_block * kBlockN, n_block * kBlockN,
binfo.actual_seqlen_k, binfo.actual_seqlen_k,
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
binfo.actual_seqlen_q, binfo.actual_seqlen_q,
kNWarps * 16, kNWarps * 16,
alibi_slope alibi_slope
); );
...@@ -1027,7 +1027,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons ...@@ -1027,7 +1027,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
// if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout())); Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
// if (cute::thread0()) { print(scores); } // if (cute::thread0()) { print(scores); }
// This check is at the end of the loop since we always have at least 1 iteration // This check is at the end of the loop since we always have at least 1 iteration
...@@ -1069,11 +1069,11 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons ...@@ -1069,11 +1069,11 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
if (Has_alibi) { if (Has_alibi) {
flash::apply_alibi<Is_causal>( flash::apply_alibi<Is_causal>(
scores, scores,
n_block * kBlockN, n_block * kBlockN,
binfo.actual_seqlen_k, binfo.actual_seqlen_k,
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
binfo.actual_seqlen_q, binfo.actual_seqlen_q,
kNWarps * 16, kNWarps * 16,
alibi_slope alibi_slope
); );
...@@ -1094,7 +1094,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons ...@@ -1094,7 +1094,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
// if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout())); Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
} }
// Epilogue // Epilogue
......
...@@ -91,20 +91,10 @@ struct Flash_fwd_kernel_traits : public Base { ...@@ -91,20 +91,10 @@ struct Flash_fwd_kernel_traits : public Base {
SmemLayoutAtomQ{}, SmemLayoutAtomQ{},
Shape<Int<kBlockN>, Int<kHeadDim>>{})); Shape<Int<kBlockN>, Int<kHeadDim>>{}));
// This has to be kBlockN and not 8, otherwise we get wrong results for d=128 // https://github.com/ColfaxResearch/cutlass-kernels/blob/a222587e6d59b93ba704853d3946fb686d8b8892/src/fmha/fmha_forward.cu#L434
using SmemLayoutAtomVtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>, using SmemLayoutVtransposed = decltype(
Stride<_1, Int<kBlockKSmem>>>; composition(SmemLayoutKV{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockN>>{}, GenRowMajor{})));
using SmemLayoutAtomVtransposed = decltype( using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{}));
composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomVtransposedNoSwizzle{}));
using SmemLayoutVtransposed = decltype(tile_to_shape(
SmemLayoutAtomVtransposed{},
Shape<Int<kHeadDim>, Int<kBlockN>>{}));
// Maybe the VtransposeNoSwizzle just needs to have the right shape
// And the strides don't matter?
using SmemLayoutVtransposedNoSwizzle = decltype(tile_to_shape(
SmemLayoutAtomVtransposedNoSwizzle{},
Shape<Int<kHeadDim>, Int<kBlockN>>{}));
// using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn());
using SmemLayoutAtomO = decltype( using SmemLayoutAtomO = decltype(
composition(Swizzle<kSwizzle, 3, 3>{}, composition(Swizzle<kSwizzle, 3, 3>{},
...@@ -247,19 +237,9 @@ struct Flash_bwd_kernel_traits : public Base { ...@@ -247,19 +237,9 @@ struct Flash_bwd_kernel_traits : public Base {
SmemLayoutAtomKV{}, SmemLayoutAtomKV{},
make_shape(Int<kBlockN>{}, Int<kHeadDim>{}))); make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));
using SmemLayoutAtomKtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>, using SmemLayoutKtransposed = decltype(
Stride<_1, Int<kBlockKSmem>>>; composition(SmemLayoutKV{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockN>>{}, GenRowMajor{})));
using SmemLayoutAtomKtransposed = decltype( using SmemLayoutKtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutKtransposed{}));
composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomKtransposedNoSwizzle{}));
using SmemLayoutKtransposed = decltype(tile_to_shape(
SmemLayoutAtomKtransposed{},
make_shape(Int<kHeadDim>{}, Int<kBlockN>{})));
// Maybe the KtransposeNoSwizzle just needs to have the right shape
// And the strides don't matter?
using SmemLayoutKtransposedNoSwizzle = decltype(tile_to_shape(
SmemLayoutAtomKtransposedNoSwizzle{},
make_shape(Int<kHeadDim>{}, Int<kBlockN>{})));
// using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn());
// TODO: generalize to other values of kBlockN // TODO: generalize to other values of kBlockN
// TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2 // TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2
...@@ -277,30 +257,15 @@ struct Flash_bwd_kernel_traits : public Base { ...@@ -277,30 +257,15 @@ struct Flash_bwd_kernel_traits : public Base {
using SmemLayoutPdS = decltype(tile_to_shape( using SmemLayoutPdS = decltype(tile_to_shape(
SmemLayoutAtomPdS{}, SmemLayoutAtomPdS{},
make_shape(Int<kBlockM>{}, Int<kBlockN>{}))); make_shape(Int<kBlockM>{}, Int<kBlockN>{})));
using SmemLayoutAtomPdStransposedNoSwizzle = Layout<Shape<Int<kPBlockN>, Int<kBlockM>>, using SmemLayoutPdStransposed = decltype(
Stride<_1, Int<kPBlockN>>>; composition(SmemLayoutPdS{}, make_layout(Shape<Int<kBlockN>, Int<kBlockM>>{}, GenRowMajor{})));
using SmemLayoutAtomPdStransposed = decltype( using SmemLayoutPdStransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutPdStransposed{}));
composition(Swizzle<kSwizzlePdS, 3, 3>{}, SmemLayoutAtomPdStransposedNoSwizzle{}));
using SmemLayoutPdStransposed = decltype(tile_to_shape(
SmemLayoutAtomPdStransposed{},
make_shape(Int<kBlockN>{}, Int<kBlockM>{})));
using SmemLayoutPdStransposedNoSwizzle = decltype(tile_to_shape(
SmemLayoutAtomPdStransposedNoSwizzle{},
make_shape(Int<kBlockN>{}, Int<kBlockM>{})));
// using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn());
using SmemCopyAtomPdS = Copy_Atom<DefaultCopy, elem_type>; using SmemCopyAtomPdS = Copy_Atom<DefaultCopy, elem_type>;
using SmemLayoutAtomQdOtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockM>>, using SmemLayoutQdOtransposed = decltype(
Stride<_1, Int<kBlockKSmem>>>; composition(SmemLayoutQdO{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockM>>{}, GenRowMajor{})));
using SmemLayoutAtomQdOtransposed = decltype( using SmemLayoutQdOtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutQdOtransposed{}));
composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomQdOtransposedNoSwizzle{}));
using SmemLayoutQdOtransposed = decltype(tile_to_shape(
SmemLayoutAtomQdOtransposed{},
make_shape(Int<kHeadDim>{}, Int<kBlockM>{})));
using SmemLayoutQdOtransposedNoSwizzle = decltype(tile_to_shape(
SmemLayoutAtomQdOtransposedNoSwizzle{},
make_shape(Int<kHeadDim>{}, Int<kBlockM>{})));
// using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn());
using SmemLayoutAtomdKV = decltype( using SmemLayoutAtomdKV = decltype(
composition(Swizzle<kSwizzle, 3, 3>{}, composition(Swizzle<kSwizzle, 3, 3>{},
......
...@@ -162,9 +162,9 @@ inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 ...@@ -162,9 +162,9 @@ inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3
template<typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3, template<typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,
typename TiledMma, typename TiledCopy, typename ThrCopy> typename TiledMma, typename TiledCopy, typename ThrCopy>
inline __device__ void gemm_A_in_regs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB, inline __device__ void gemm_rs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
TiledMma tiled_mma, TiledCopy smem_tiled_copy_B, TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
ThrCopy smem_thr_copy_B) { ThrCopy smem_thr_copy_B) {
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N
CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K
...@@ -188,10 +188,7 @@ inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) { ...@@ -188,10 +188,7 @@ inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
static_assert(decltype(size<0>(acc_layout))::value == 4); static_assert(decltype(size<0>(acc_layout))::value == 4);
static_assert(decltype(rank(acc_layout))::value == 3); static_assert(decltype(rank(acc_layout))::value == 3);
auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N) auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N)
// TD [2023-08-13]: Idk why but get<0, 1>(l) doesn't work for Cutlass 3.2, I'm getting return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
// "int_tuple.hpp(74): error: conversion to inaccessible base class"
// return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
return make_layout(make_layout(get<1>(get<0>(l)), get<1>(l)), make_layout(get<0>(get<0>(l)), get<2>(l)));
}; };
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
...@@ -207,13 +204,9 @@ inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) { ...@@ -207,13 +204,9 @@ inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) {
static_assert(mma_shape_K == 8 || mma_shape_K == 16); static_assert(mma_shape_K == 8 || mma_shape_K == 16);
constexpr int MMA_N_divisor = mma_shape_K == 8 ? 1 : 2; constexpr int MMA_N_divisor = mma_shape_K == 8 ? 1 : 2;
auto l = logical_divide(rowcol_layout, Shape<X, Shape<X, Int<MMA_N_divisor>>>{}); // ((2, MMA_M), (2, (2, MMA_N / 2))) auto l = logical_divide(rowcol_layout, Shape<X, Shape<X, Int<MMA_N_divisor>>>{}); // ((2, MMA_M), (2, (2, MMA_N / 2)))
// TD [2023-08-13]: Same error as above on Cutlass 3.2 return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)),
// return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)), get<0, 1>(l),
// get<0, 1>(l), get<1, 1, 1>(l));
// get<1, 1, 1>(l));
return make_layout(make_layout(get<0>(get<1>(l)), get<0>(get<0>(l)), get<0>(get<1>(get<1>(l)))),
get<1>(get<0>(l)),
get<1>(get<1>(get<1>(l))));
}; };
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment