"vscode:/vscode.git/clone" did not exist on "b85bb0753ee9bfde6ba4e9217203f8d820aef4c2"
Commit dbd79237 authored by Tri Dao's avatar Tri Dao
Browse files

Prepare for Cutlass 3.2

parent c5e87b11
...@@ -91,17 +91,20 @@ struct Flash_fwd_kernel_traits : public Base { ...@@ -91,17 +91,20 @@ struct Flash_fwd_kernel_traits : public Base {
SmemLayoutAtomQ{}, SmemLayoutAtomQ{},
Shape<Int<kBlockN>, Int<kHeadDim>>{})); Shape<Int<kBlockN>, Int<kHeadDim>>{}));
// This has to be kBlockN and not 8, otherwise we get wrong results for d=128
using SmemLayoutAtomVtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
Stride<_1, Int<kBlockKSmem>>>;
using SmemLayoutAtomVtransposed = decltype( using SmemLayoutAtomVtransposed = decltype(
composition(Swizzle<kSwizzle, 3, 3>{}, composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomVtransposedNoSwizzle{}));
// This has to be kBlockN and not 8, otherwise we get wrong results for d=128
Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
Stride<_1, Int<kBlockKSmem>>>{}));
using SmemLayoutVtransposed = decltype(tile_to_shape( using SmemLayoutVtransposed = decltype(tile_to_shape(
SmemLayoutAtomVtransposed{}, SmemLayoutAtomVtransposed{},
Shape<Int<kHeadDim>, Int<kBlockN>>{})); Shape<Int<kHeadDim>, Int<kBlockN>>{}));
// Maybe the VtransposeNoSwizzle just needs to have the right shape // Maybe the VtransposeNoSwizzle just needs to have the right shape
// And the strides don't matter? // And the strides don't matter?
using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn()); using SmemLayoutVtransposedNoSwizzle = decltype(tile_to_shape(
SmemLayoutAtomVtransposedNoSwizzle{},
Shape<Int<kHeadDim>, Int<kBlockN>>{}));
// using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn());
using SmemLayoutAtomO = decltype( using SmemLayoutAtomO = decltype(
composition(Swizzle<kSwizzle, 3, 3>{}, composition(Swizzle<kSwizzle, 3, 3>{},
...@@ -223,16 +226,19 @@ struct Flash_bwd_kernel_traits : public Base { ...@@ -223,16 +226,19 @@ struct Flash_bwd_kernel_traits : public Base {
SmemLayoutAtomKV{}, SmemLayoutAtomKV{},
make_shape(Int<kBlockN>{}, Int<kHeadDim>{}))); make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));
using SmemLayoutAtomKtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
Stride<_1, Int<kBlockKSmem>>>;
using SmemLayoutAtomKtransposed = decltype( using SmemLayoutAtomKtransposed = decltype(
composition(Swizzle<kSwizzle, 3, 3>{}, composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomKtransposedNoSwizzle{}));
Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
Stride<_1, Int<kBlockKSmem>>>{}));
using SmemLayoutKtransposed = decltype(tile_to_shape( using SmemLayoutKtransposed = decltype(tile_to_shape(
SmemLayoutAtomKtransposed{}, SmemLayoutAtomKtransposed{},
make_shape(Int<kHeadDim>{}, Int<kBlockN>{}))); make_shape(Int<kHeadDim>{}, Int<kBlockN>{})));
// Maybe the KtransposeNoSwizzle just needs to have the right shape // Maybe the KtransposeNoSwizzle just needs to have the right shape
// And the strides don't matter? // And the strides don't matter?
using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn()); using SmemLayoutKtransposedNoSwizzle = decltype(tile_to_shape(
SmemLayoutAtomKtransposedNoSwizzle{},
make_shape(Int<kHeadDim>{}, Int<kBlockN>{})));
// using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn());
// TODO: generalize to other values of kBlockN // TODO: generalize to other values of kBlockN
// TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2 // TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2
...@@ -250,24 +256,30 @@ struct Flash_bwd_kernel_traits : public Base { ...@@ -250,24 +256,30 @@ struct Flash_bwd_kernel_traits : public Base {
using SmemLayoutPdS = decltype(tile_to_shape( using SmemLayoutPdS = decltype(tile_to_shape(
SmemLayoutAtomPdS{}, SmemLayoutAtomPdS{},
make_shape(Int<kBlockM>{}, Int<kBlockN>{}))); make_shape(Int<kBlockM>{}, Int<kBlockN>{})));
using SmemLayoutAtomPdStransposedNoSwizzle = Layout<Shape<Int<kPBlockN>, Int<kBlockM>>,
Stride<_1, Int<kPBlockN>>>;
using SmemLayoutAtomPdStransposed = decltype( using SmemLayoutAtomPdStransposed = decltype(
composition(Swizzle<kSwizzlePdS, 3, 3>{}, composition(Swizzle<kSwizzlePdS, 3, 3>{}, SmemLayoutAtomPdStransposedNoSwizzle{}));
Layout<Shape<Int<kPBlockN>, Int<kBlockM>>,
Stride<_1, Int<kPBlockN>>>{}));
using SmemLayoutPdStransposed = decltype(tile_to_shape( using SmemLayoutPdStransposed = decltype(tile_to_shape(
SmemLayoutAtomPdStransposed{}, SmemLayoutAtomPdStransposed{},
make_shape(Int<kBlockN>{}, Int<kBlockM>{}))); make_shape(Int<kBlockN>{}, Int<kBlockM>{})));
using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn()); using SmemLayoutPdStransposedNoSwizzle = decltype(tile_to_shape(
SmemLayoutAtomPdStransposedNoSwizzle{},
make_shape(Int<kBlockN>{}, Int<kBlockM>{})));
// using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn());
using SmemCopyAtomPdS = Copy_Atom<DefaultCopy, elem_type>; using SmemCopyAtomPdS = Copy_Atom<DefaultCopy, elem_type>;
using SmemLayoutAtomQdOtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockM>>,
Stride<_1, Int<kBlockKSmem>>>;
using SmemLayoutAtomQdOtransposed = decltype( using SmemLayoutAtomQdOtransposed = decltype(
composition(Swizzle<kSwizzle, 3, 3>{}, composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomQdOtransposedNoSwizzle{}));
Layout<Shape<Int<kBlockKSmem>, Int<kBlockM>>,
Stride<_1, Int<kBlockKSmem>>>{}));
using SmemLayoutQdOtransposed = decltype(tile_to_shape( using SmemLayoutQdOtransposed = decltype(tile_to_shape(
SmemLayoutAtomQdOtransposed{}, SmemLayoutAtomQdOtransposed{},
make_shape(Int<kHeadDim>{}, Int<kBlockM>{}))); make_shape(Int<kHeadDim>{}, Int<kBlockM>{})));
using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn()); using SmemLayoutQdOtransposedNoSwizzle = decltype(tile_to_shape(
SmemLayoutAtomQdOtransposedNoSwizzle{},
make_shape(Int<kHeadDim>{}, Int<kBlockM>{})));
// using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn());
using SmemLayoutAtomdKV = decltype( using SmemLayoutAtomdKV = decltype(
composition(Swizzle<kSwizzle, 3, 3>{}, composition(Swizzle<kSwizzle, 3, 3>{},
......
...@@ -228,7 +228,10 @@ inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) { ...@@ -228,7 +228,10 @@ inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
static_assert(decltype(size<0>(acc_layout))::value == 4); static_assert(decltype(size<0>(acc_layout))::value == 4);
static_assert(decltype(rank(acc_layout))::value == 3); static_assert(decltype(rank(acc_layout))::value == 3);
auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N) auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N)
return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); // TD [2023-08-13]: Idk why but get<0, 1>(l) doesn't work for Cutlass 3.2, I'm getting
// "int_tuple.hpp(74): error: conversion to inaccessible base class"
// return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
return make_layout(make_layout(get<1>(get<0>(l)), get<1>(l)), make_layout(get<0>(get<0>(l)), get<2>(l)));
}; };
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
...@@ -244,9 +247,13 @@ inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) { ...@@ -244,9 +247,13 @@ inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) {
static_assert(mma_shape_K == 8 || mma_shape_K == 16); static_assert(mma_shape_K == 8 || mma_shape_K == 16);
constexpr int MMA_N_divisor = mma_shape_K == 8 ? 1 : 2; constexpr int MMA_N_divisor = mma_shape_K == 8 ? 1 : 2;
auto l = logical_divide(rowcol_layout, Shape<X, Shape<X, Int<MMA_N_divisor>>>{}); // ((2, MMA_M), (2, (2, MMA_N / 2))) auto l = logical_divide(rowcol_layout, Shape<X, Shape<X, Int<MMA_N_divisor>>>{}); // ((2, MMA_M), (2, (2, MMA_N / 2)))
return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)), // TD [2023-08-13]: Same error as above on Cutlass 3.2
get<0, 1>(l), // return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)),
get<1, 1, 1>(l)); // get<0, 1>(l),
// get<1, 1, 1>(l));
return make_layout(make_layout(get<0>(get<1>(l)), get<0>(get<0>(l)), get<0>(get<1>(get<1>(l)))),
get<1>(get<0>(l)),
get<1>(get<1>(get<1>(l))));
}; };
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment