Commit cc1898fc authored by carlushuang's avatar carlushuang
Browse files

fix name

parent dc1c2bf8
...@@ -53,9 +53,9 @@ struct Layernorm2dFwd ...@@ -53,9 +53,9 @@ struct Layernorm2dFwd
static constexpr bool kPadN = Problem::kPadN; static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kTwoPass = Problem::kTwoPass; static constexpr bool kTwoPass = Problem::kTwoPass;
static constexpr index_t Thread_N = Problem::BlockShape::Thread_N; static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
static constexpr index_t Vector_N = Problem::BlockShape::Vector_N; static constexpr index_t Vector_N = Problem::BlockShape::Vector_N;
static constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N; static constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N;
static constexpr auto I0 = number<0>{}; static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{}; static constexpr auto I1 = number<1>{};
...@@ -125,7 +125,7 @@ struct Layernorm2dFwd ...@@ -125,7 +125,7 @@ struct Layernorm2dFwd
#define _SS_ std::string #define _SS_ std::string
#define _TS_ std::to_string #define _TS_ std::to_string
return _SS_("layernorm2d_fwd_") + _SS_(t2s<XDataType>::name) + "_" + return _SS_("layernorm2d_fwd_") + _SS_(t2s<XDataType>::name) + "_" +
_TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::BlockWarps_M) + "x" + _TS_(S_::BlockWarps_N) + "_" + _TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" +
_TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" + _TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" +
_SS_(Pipeline::name) + surfix; _SS_(Pipeline::name) + surfix;
#undef _SS_ #undef _SS_
......
...@@ -9,20 +9,19 @@ namespace ck_tile { ...@@ -9,20 +9,19 @@ namespace ck_tile {
/* /*
// clang-format off // clang-format off
4-level descriptor: BlockTile-> BlockWarps-> WarpTile-> Vector 4-level descriptor: BlockTile-> WarpPerBlock-> WarpTile-> Vector
Block_N (Warp_N * BlockWarps_N * Repeat_N ) Block_N (Warp_N * WarpPerBlock_N * Repeat_N )
+<----------------------< Repeat_N(2)>--------------------->+ +<----------------------< Repeat_N(2)>--------------------->+
| | | |
+<-- <BlockWarps_N(2)> -->+ +<-- <WarpPerBlock_N(2)> -->+
Warp_M Warp_M
+--------------+--------------+--------------+--------------+----+----------------+ +--------------+--------------+--------------+--------------+----+----------------+
Warp_N | wrap_0 | wrap_1 | | ^ ^ Warp_N | wrap_0 | wrap_1 | | ^ ^
+--------------+--------------+ | <BlockWarps_M(2)> | +--------------+--------------+ | <WarpPerBlock_M(2)> |
| wrap_2 | wrap_3 | | v | wrap_2 | wrap_3 | | v
+--------------+--------------+--------------+--------------+----+ Block_M +--------------+--------------+--------------+--------------+----+ Block_M
| | | (Warp_M * | | | (Warp_M * WarpPerBlock_M * Repeat_M )
BlockWarps_M * Repeat_M )
+ + | + + |
| | | v | | | v
+--------------+--------------+--------------+--------------+ + +--------------+--------------+--------------+--------------+ +
...@@ -37,12 +36,12 @@ BlockWarps_M * Repeat_M ) ...@@ -37,12 +36,12 @@ BlockWarps_M * Repeat_M )
+-----------+-----------+-----------+-----------+-----------+ +-----------+-----------+-----------+-----------+-----------+
// clang-format on // clang-format on
*/ */
template <typename BlockTile_, // block size, seq<M, N> template <typename BlockTile_, // block size, seq<M, N>
typename BlockWarps_, // num warps along seq<M, N> typename WarpPerBlock_, // num warps along seq<M, N>
typename WarpTile_, // warp size, seq<M, N> typename WarpTile_, // warp size, seq<M, N>
typename Vector_, // contiguous pixels(vector size) along seq<M, N> typename Vector_, // contiguous pixels(vector size) along seq<M, N>
index_t BlockSize_ = index_t BlockSize_ =
warpSize* reduce_on_sequence(BlockWarps_{}, multiplies{}, number<1>{})> warpSize* reduce_on_sequence(WarpPerBlock_{}, multiplies{}, number<1>{})>
struct Layernorm2dShape struct Layernorm2dShape
{ {
// block size // block size
...@@ -50,18 +49,18 @@ struct Layernorm2dShape ...@@ -50,18 +49,18 @@ struct Layernorm2dShape
static constexpr index_t Block_N = BlockTile_::at(number<1>{}); static constexpr index_t Block_N = BlockTile_::at(number<1>{});
// num warps along seq<M, N>, within each block // num warps along seq<M, N>, within each block
static constexpr index_t BlockWarps_M = BlockWarps_::at(number<0>{}); static constexpr index_t WarpPerBlock_M = WarpPerBlock_::at(number<0>{});
static constexpr index_t BlockWarps_N = BlockWarps_::at(number<1>{}); static constexpr index_t WarpPerBlock_N = WarpPerBlock_::at(number<1>{});
// warp size // warp size
static constexpr index_t Warp_M = WarpTile_::at(number<0>{}); static constexpr index_t Warp_M = WarpTile_::at(number<0>{});
static constexpr index_t Warp_N = WarpTile_::at(number<1>{}); static constexpr index_t Warp_N = WarpTile_::at(number<1>{});
static_assert(Block_M % (BlockWarps_M * Warp_M) == 0); static_assert(Block_M % (WarpPerBlock_M * Warp_M) == 0);
static_assert(Block_N % (BlockWarps_N * Warp_N) == 0); static_assert(Block_N % (WarpPerBlock_N * Warp_N) == 0);
// repeat of each thread along seq<M, N> // repeat of each thread along seq<M, N>
static constexpr index_t Repeat_M = Block_M / (BlockWarps_M * Warp_M); static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M);
static constexpr index_t Repeat_N = Block_N / (BlockWarps_N * Warp_N); static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N);
// vector size along seq<M, N> // vector size along seq<M, N>
static constexpr index_t Vector_M = Vector_::at(number<0>{}); static constexpr index_t Vector_M = Vector_::at(number<0>{});
...@@ -70,8 +69,8 @@ struct Layernorm2dShape ...@@ -70,8 +69,8 @@ struct Layernorm2dShape
static_assert(Warp_M % Vector_M == 0); static_assert(Warp_M % Vector_M == 0);
static_assert(Warp_N % Vector_N == 0); static_assert(Warp_N % Vector_N == 0);
// num of threads along seq<M, N>, within each warp // num of threads along seq<M, N>, within each warp
static constexpr index_t Thread_M = Warp_M / Vector_M; static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M;
static constexpr index_t Thread_N = Warp_N / Vector_N; static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N;
static constexpr index_t BlockSize = BlockSize_; static constexpr index_t BlockSize = BlockSize_;
}; };
......
...@@ -19,8 +19,8 @@ struct Layernorm2dFwdWarpPerRowDefaultPolicy ...@@ -19,8 +19,8 @@ struct Layernorm2dFwdWarpPerRowDefaultPolicy
return make_static_tile_distribution( return make_static_tile_distribution(
tile_distribution_encoding< tile_distribution_encoding<
sequence<>, sequence<>,
tuple<sequence<S::BlockWarps_M, S::Thread_M, S::Vector_M>, tuple<sequence<S::WarpPerBlock_M, S::ThreadPerWarp_M, S::Vector_M>,
sequence<S::Repeat_N, S::BlockWarps_N, S::Thread_N, S::Vector_N>>, sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
tuple<sequence<1, 2>, sequence<1, 2>>, tuple<sequence<1, 2>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<1, 2>>, tuple<sequence<0, 1>, sequence<1, 2>>,
sequence<1, 2, 2>, sequence<1, 2, 2>,
...@@ -33,8 +33,8 @@ struct Layernorm2dFwdWarpPerRowDefaultPolicy ...@@ -33,8 +33,8 @@ struct Layernorm2dFwdWarpPerRowDefaultPolicy
return make_static_tile_distribution( return make_static_tile_distribution(
tile_distribution_encoding< tile_distribution_encoding<
sequence<S::BlockWarps_M, S::Thread_M>, sequence<S::WarpPerBlock_M, S::ThreadPerWarp_M>,
tuple<sequence<S::Repeat_N, S::BlockWarps_N, S::Thread_N, S::Vector_N>>, tuple<sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
tuple<sequence<0, 1>, sequence<0, 1>>, tuple<sequence<0, 1>, sequence<0, 1>>,
tuple<sequence<0, 1>, sequence<1, 2>>, tuple<sequence<0, 1>, sequence<1, 2>>,
sequence<1, 1>, sequence<1, 1>,
......
...@@ -29,8 +29,8 @@ struct Layernorm2dFwdWarpPerRowProblem ...@@ -29,8 +29,8 @@ struct Layernorm2dFwdWarpPerRowProblem
using InvStdDataType = remove_cvref_t<InvStdDataType_>; using InvStdDataType = remove_cvref_t<InvStdDataType_>;
using BlockShape = remove_cvref_t<BlockShape_>; using BlockShape = remove_cvref_t<BlockShape_>;
static constexpr bool kNeedCrossLaneSync = BlockShape::Thread_N > 1; static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
static constexpr bool kNeedCrossWarpSync = BlockShape::BlockWarps_N > 1; static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;
static constexpr bool kPadN = kPadN_; static constexpr bool kPadN = kPadN_;
static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_; static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
......
...@@ -324,11 +324,11 @@ CK_TILE_DEVICE constexpr index_t block_tile_welford_calculate_max_count(int row_ ...@@ -324,11 +324,11 @@ CK_TILE_DEVICE constexpr index_t block_tile_welford_calculate_max_count(int row_
{ {
using S = BlockShape; using S = BlockShape;
index_t LastloopN = row_size % S::Block_N == 0 ? S::Block_N : row_size % S::Block_N; index_t LastloopN = row_size % S::Block_N == 0 ? S::Block_N : row_size % S::Block_N;
constexpr index_t NThread = S::BlockWarps_N * S::Thread_N; constexpr index_t NThread = S::WarpPerBlock_N * S::ThreadPerWarp_N;
index_t iNLane = get_thread_id() % NThread; index_t iNLane = get_thread_id() % NThread;
index_t iN0 = LastloopN / (S::Vector_N * S::Thread_N); index_t iN0 = LastloopN / (S::Vector_N * S::ThreadPerWarp_N);
index_t iN1 = (LastloopN % (S::Vector_N * S::Thread_N)) / S::Vector_N; index_t iN1 = (LastloopN % (S::Vector_N * S::ThreadPerWarp_N)) / S::Vector_N;
index_t N2 = (LastloopN % (S::Vector_N * S::Thread_N)) % S::Vector_N; index_t N2 = (LastloopN % (S::Vector_N * S::ThreadPerWarp_N)) % S::Vector_N;
index_t iN3 = iNLane < iN1 ? S::Vector_N : iNLane == iN1 ? N2 : 0; index_t iN3 = iNLane < iN1 ? S::Vector_N : iNLane == iN1 ? N2 : 0;
return iN0 * S::Vector_N + iN3; return iN0 * S::Vector_N + iN3;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment