Commit 7fb9b2b6 authored by carlushuang

Merge remote-tracking branch 'origin/develop' into ck_tile/layernorm_fusion

parents 50f67a66 3d609534
@@ -39,7 +39,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
    static constexpr index_t kK0 = BlockFmhaShape::kK0;
    static constexpr index_t kN1 = BlockFmhaShape::kN1;
    static constexpr index_t kK1 = BlockFmhaShape::kK1;
-   static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength;
+   static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
+   static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
    static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
    static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
@@ -75,22 +76,22 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
            return Problem::kBlockPerCu;
        else
        {
-           if constexpr(kK0BlockLength <= 32)
+           if constexpr(kQKHeaddim <= 32)
            {
                return 2;
            }
-           else if constexpr(kK0BlockLength <= 64)
+           else if constexpr(kQKHeaddim <= 64)
            {
                return 3;
            }
-           else if constexpr(kK0BlockLength <= 128)
+           else if constexpr(kQKHeaddim <= 128)
            {
                if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
                    return 1;
                else
                    return 2;
            }
-           else if constexpr(kK0BlockLength <= 256)
+           else if constexpr(kQKHeaddim <= 256)
            {
                return 1;
            }
@@ -270,7 +271,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
        // prefetch K tile
        index_t i_total_loops = 0;
-       constexpr index_t k0_loops = kK0BlockLength / kK0;
+       constexpr index_t k0_loops = kQKHeaddim / kK0;
        constexpr index_t k1_loops = kN0 / kK1;
        static_assert(2 <= k0_loops);
...
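The change is the same across all of the FMHA pipelines below: the per-block Q/K head dimension, formerly `kK0BlockLength`, is renamed to `kQKHeaddim`, and most pipelines also pick up the new `kSubQKHeaddim`. Both the blocks-per-CU heuristic and the gemm loop counts key off this value. A minimal host-side sketch of that arithmetic, using illustrative tile sizes that are not taken from this commit (and ignoring the bias special case), is:

// Sketch only: hypothetical tile sizes to show how the loop counts fall out.
#include <cstdio>

int main()
{
    constexpr int kQKHeaddim = 128; // head dim of the Q*K^T gemm (example value)
    constexpr int kK0        = 32;  // unroll step of the qk gemm (example value)
    constexpr int kN0        = 128; // seqlen-k tile              (example value)
    constexpr int kK1        = 32;  // unroll step of the kv gemm (example value)

    constexpr int k0_loops = kQKHeaddim / kK0; // 4 inner iterations for Q*K^T
    constexpr int k1_loops = kN0 / kK1;        // 4 inner iterations for P*V

    // mirrors the heuristic above (bias special case ignored): larger head dims -> fewer blocks per CU
    constexpr int blocks_per_cu =
        kQKHeaddim <= 32 ? 2 : kQKHeaddim <= 64 ? 3 : kQKHeaddim <= 128 ? 2 : 1;

    std::printf("k0_loops=%d k1_loops=%d blocks_per_cu=%d\n", k0_loops, k1_loops, blocks_per_cu);
}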
@@ -42,7 +42,8 @@ struct BlockFmhaPipelineQRKSVS
    static constexpr index_t kK0 = BlockFmhaShape::kK0;
    static constexpr index_t kN1 = BlockFmhaShape::kN1;
    static constexpr index_t kK1 = BlockFmhaShape::kK1;
-   static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength;
+   static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
+   static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
    static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
    static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
@@ -76,22 +77,22 @@ struct BlockFmhaPipelineQRKSVS
            return Problem::kBlockPerCu;
        else
        {
-           if constexpr(kK0BlockLength <= 32)
+           if constexpr(kQKHeaddim <= 32)
            {
                return 2;
            }
-           else if constexpr(kK0BlockLength <= 64)
+           else if constexpr(kQKHeaddim <= 64)
            {
                return 3;
            }
-           else if constexpr(kK0BlockLength <= 128)
+           else if constexpr(kQKHeaddim <= 128)
            {
                if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
                    return 1;
                else
                    return 2;
            }
-           else if constexpr(kK0BlockLength <= 256)
+           else if constexpr(kQKHeaddim <= 256)
            {
                return 1;
            }
@@ -261,7 +262,7 @@ struct BlockFmhaPipelineQRKSVS
        // prefetch K tile
        index_t i_total_loops = 0;
-       constexpr index_t k0_loops = kK0BlockLength / kK0;
+       constexpr index_t k0_loops = kQKHeaddim / kK0;
        constexpr index_t k1_loops = kN0 / kK1;
        static_assert(2 <= k0_loops);
...
@@ -43,7 +43,8 @@ struct BlockFmhaPipelineQRKSVSAsync
    static constexpr index_t kK0 = BlockFmhaShape::kK0;
    static constexpr index_t kN1 = BlockFmhaShape::kN1;
    static constexpr index_t kK1 = BlockFmhaShape::kK1;
-   static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength;
+   static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
+   static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
    static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
    // TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x)
@@ -87,7 +88,7 @@ struct BlockFmhaPipelineQRKSVSAsync
                return 1;
            }
-           if constexpr(kK0BlockLength <= 32)
+           if constexpr(kQKHeaddim <= 32)
            {
                if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS &&
                             FmhaMask::IsMasking)
@@ -95,21 +96,21 @@ struct BlockFmhaPipelineQRKSVSAsync
                else
                    return 2;
            }
-           else if constexpr(kK0BlockLength <= 64)
+           else if constexpr(kQKHeaddim <= 64)
            {
                if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
                    return 2;
                else
                    return 3;
            }
-           else if constexpr(kK0BlockLength <= 128)
+           else if constexpr(kQKHeaddim <= 128)
            {
                if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
                    return 1;
                else
                    return 2;
            }
-           else if constexpr(kK0BlockLength <= 256)
+           else if constexpr(kQKHeaddim <= 256)
            {
                return 1;
            }
@@ -339,7 +340,7 @@ struct BlockFmhaPipelineQRKSVSAsync
        // auto q_tile = q; // tile_elementwise_in(q_element_func, q);
        index_t i_total_loops = 0;
-       constexpr index_t k0_loops = kK0BlockLength / kK0;
+       constexpr index_t k0_loops = kQKHeaddim / kK0;
        constexpr index_t k1_loops = kN0 / kK1;
        static_assert(1 <= k0_loops);
...
@@ -41,7 +41,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
    static constexpr index_t kK0 = BlockFmhaShape::kK0;
    static constexpr index_t kN1 = BlockFmhaShape::kN1;
    static constexpr index_t kK1 = BlockFmhaShape::kK1;
-   static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength;
+   static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
    static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
    static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
@@ -75,22 +75,22 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
            return Problem::kBlockPerCu;
        else
        {
-           if constexpr(kK0BlockLength <= 32)
+           if constexpr(kQKHeaddim <= 32)
            {
                return 2;
            }
-           else if constexpr(kK0BlockLength <= 64)
+           else if constexpr(kQKHeaddim <= 64)
            {
                return 3;
            }
-           else if constexpr(kK0BlockLength <= 128)
+           else if constexpr(kQKHeaddim <= 128)
            {
                if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
                    return 1;
                else
                    return 2;
            }
-           else if constexpr(kK0BlockLength <= 256)
+           else if constexpr(kQKHeaddim <= 256)
            {
                return 1;
            }
@@ -232,7 +232,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
        // prefetch K tile
        index_t i_total_loops = 0;
-       constexpr index_t k0_loops = kK0BlockLength / kK0;
+       constexpr index_t k0_loops = kQKHeaddim / kK0;
        constexpr index_t k1_loops = kN0 / kK1;
        static_assert(2 <= k0_loops);
...
@@ -41,7 +41,8 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
    static constexpr index_t kK0 = BlockFmhaShape::kK0;
    static constexpr index_t kN1 = BlockFmhaShape::kN1;
    static constexpr index_t kK1 = BlockFmhaShape::kK1;
-   static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength;
+   static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
+   static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
    static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
    static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
@@ -56,22 +57,22 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
            return Problem::kBlockPerCu;
        else
        {
-           if constexpr(kK0BlockLength <= 32)
+           if constexpr(kQKHeaddim <= 32)
            {
                return 2;
            }
-           else if constexpr(kK0BlockLength <= 64)
+           else if constexpr(kQKHeaddim <= 64)
            {
                return 3;
            }
-           else if constexpr(kK0BlockLength <= 128)
+           else if constexpr(kQKHeaddim <= 128)
            {
                if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
                    return 1;
                else
                    return 2;
            }
-           else if constexpr(kK0BlockLength <= 256)
+           else if constexpr(kQKHeaddim <= 256)
            {
                return 1;
            }
@@ -235,7 +236,7 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
        // prefetch K tile
        index_t i_total_loops = 0;
-       constexpr index_t k0_loops = kK0BlockLength / kK0;
+       constexpr index_t k0_loops = kQKHeaddim / kK0;
        constexpr index_t k1_loops = kN0 / kK1;
        static_assert(2 <= k0_loops);
...
@@ -55,7 +55,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
        constexpr index_t MWarp = config.template at<1>();
        constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
-       constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0BlockLength;
+       constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
        constexpr index_t K2 = WG::kK / WG::WarpGemmAttribute::Impl::kABKLane;
        constexpr index_t K1 = WG::WarpGemmAttribute::Impl::kABKLane;
@@ -323,6 +323,9 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
    template<> struct
    LdsBufferSequence<3, 3, 3, 3> { using type = sequence<1, 2, 0, 1, 2, 0>; };
+   template<> struct
+   LdsBufferSequence<3, 3, 3, 4> { using type = sequence<1, 2, 0, 0, 1, 2, 0>; };
    template<> struct
    LdsBufferSequence<3, 3, 2, 2> { using type = sequence<1, 2, 1, 0>;};
    // clang-format on
@@ -335,9 +338,9 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
        constexpr index_t kN0 = BlockFmhaShape::kN0;
        constexpr index_t kK0 = BlockFmhaShape::kK0;
        constexpr index_t kK1 = BlockFmhaShape::kK1;
-       constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength;
-       constexpr index_t k0_loops = kK0BlockLength / kK0;
+       constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
+       constexpr index_t k0_loops = kQKHeaddim / kK0;
        constexpr index_t k1_loops = kN0 / kK1;
        return typename LdsBufferSequence<NumPrefetchK, NumPrefetchV, k0_loops, k1_loops>::type{};
...
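The added LdsBufferSequence<3, 3, 3, 4> specialization covers the k0_loops = 3, k1_loops = 4 split that the return statement above selects for shapes with that loop structure. Judging from the existing specializations, the sequence appears to hold one entry per gemm iteration (k0_loops + k1_loops of them), each picking one of the NumPrefetch LDS buffers. A small compile-time check of that reading, with names that are illustrative only and not part of the commit:

// Hedged sanity check of the pattern visible in the specializations above.
#include <array>
#include <cstddef>

template <std::size_t N>
constexpr bool check_lds_sequence(const std::array<int, N>& seq, int k0_loops, int k1_loops, int num_prefetch)
{
    if(static_cast<int>(N) != k0_loops + k1_loops) // one buffer index per gemm iteration
        return false;
    for(int v : seq)
        if(v < 0 || v >= num_prefetch) // indices must address an existing LDS buffer
            return false;
    return true;
}

// the new <NumPrefetchK=3, NumPrefetchV=3, k0_loops=3, k1_loops=4> specialization
static_assert(check_lds_sequence(std::array<int, 7>{1, 2, 0, 0, 1, 2, 0}, 3, 4, 3), "pattern holds");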
@@ -7,6 +7,20 @@
namespace ck_tile {
+static CK_TILE_HOST_DEVICE constexpr index_t ceil_to_qualified_tile_length(index_t len)
+{
+    if(len == 96)
+        return 128;
+    if(len == 160)
+        return 256;
+    // only lengths of 96, 160, and powers of two are supported
+    if(!(len & (len - 1)))
+        return len;
+    return 0;
+};
template <typename BlockTile_, // sequence<...
          typename Gemm0BlockWarps_,
          typename Gemm0WarpTile_,
@@ -36,10 +50,12 @@ struct TileFmhaShape
    static constexpr index_t kK0 = BlockTile::at(number<2>{}); // tile size along qk gemm unroll
    static constexpr index_t kN1 = BlockTile::at(number<3>{}); // tile size along v head_dim
    static constexpr index_t kK1 = BlockTile::at(number<4>{}); // tile size along kv gemm unroll
-   static constexpr index_t kK0BlockLength =
+   static constexpr index_t kQKHeaddim =
        BlockTile::at(number<5>{}); // total length of K0, used for pipelines that need to load Q at
                                    // once (or repeatedly load Q as a whole tile)
-   static_assert(kK0BlockLength % kK0 == 0, "kK0BlockLength should be divisible by kK0");
+   static_assert(kQKHeaddim % kK0 == 0, "kQKHeaddim should be divisible by kK0");
+   static constexpr index_t kSubQKHeaddim = ceil_to_qualified_tile_length(kQKHeaddim);
    // v, rowmajor : seqlen*hdim, colmajor : hdim*seqlen
    static constexpr bool IsVLayoutRowMajor = IsVLayoutRowMajor_;
...
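The new kSubQKHeaddim rounds the QK head dimension up to a supported tile length: 96 goes to 128, 160 goes to 256, powers of two pass through unchanged, and anything else maps to 0. A standalone restatement of that mapping, runnable on the host purely as an illustration (the _demo suffix is mine, not part of the commit):

// Standalone illustration of ceil_to_qualified_tile_length; the real function
// lives in ck_tile and is constexpr/host-device qualified.
constexpr int ceil_to_qualified_tile_length_demo(int len)
{
    if(len == 96)
        return 128;
    if(len == 160)
        return 256;
    if(!(len & (len - 1))) // power of two (and 0) passes through unchanged
        return len;
    return 0; // unsupported length
}

static_assert(ceil_to_qualified_tile_length_demo(96) == 128);
static_assert(ceil_to_qualified_tile_length_demo(160) == 256);
static_assert(ceil_to_qualified_tile_length_demo(64) == 64);
static_assert(ceil_to_qualified_tile_length_demo(72) == 0); // not a qualified length

int main() { return 0; }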
@@ -41,9 +41,9 @@ struct Layernorm2dFwdPipelineOnePass
    static constexpr const char* name = []() {
        if constexpr(kNeedCrossWarpSync)
-           return "bpr"; // block per row
+           return "bpr_op"; // block per row
        else
-           return "wpr"; // warp per row
+           return "wpr_op"; // warp per row
    }();
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
...
@@ -40,9 +40,9 @@ struct Layernorm2dFwdPipelineTwoPass
    static constexpr const char* name = []() {
        if constexpr(kNeedCrossWarpSync)
-           return "bpr"; // block per row
+           return "bpr_tp"; // block per row
        else
-           return "wpr"; // warp per row
+           return "wpr_tp"; // warp per row
    }();
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
@@ -151,8 +151,6 @@ struct Layernorm2dFwdPipelineTwoPass
        ck_tile::index_t stride_to_right_most_window =
            row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N;
-       // x_window.foo();
-       // gamma_window.foo();
        move_tile_window(x_window, {0, -Block_N});
        move_tile_window(sx_window, {0, -Block_N});
        move_tile_window(gamma_window, {stride_to_right_most_window});
...
@@ -4,4 +4,7 @@
#pragma once
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
+#include "ck_tile/ops/reduce/block/block_reduce2d.hpp"
+#include "ck_tile/ops/reduce/block/block_reduce2d_default_policy.hpp"
+#include "ck_tile/ops/reduce/block/block_reduce2d_problem.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
@@ -6,6 +6,7 @@
#include "ck_tile/core.hpp"
#include <tuple>
+// this file does not support cross-warp reduce
namespace ck_tile {
/*
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <typename Problem_, typename Policy_ = void>
struct BlockReduce2d
{
// in-thread reduction
using Problem = remove_cvref_t<Problem_>;
using XDataType = typename Problem::XDataType;
using ComputeDataType = typename Problem::ComputeDataType;
CK_TILE_DEVICE constexpr BlockReduce2d() {}
template <typename XDistributedTensor_, typename YDistributedTensor_, typename ReduceFunc>
CK_TILE_DEVICE void operator()(const XDistributedTensor_& x_tensor,
YDistributedTensor_& y_tensor,
const ReduceFunc& reduce_func)
{
constexpr auto I0 = number<0>{};
constexpr auto I1 = number<1>{};
constexpr auto spans = XDistributedTensor_::get_distributed_spans();
// FIXME: hard coded to reduce 2nd axis
sweep_tile_span(spans[I0], [&](auto dstr_idx_i0) {
constexpr auto y_dstr_idx = make_tuple(dstr_idx_i0);
auto y = y_tensor[y_dstr_idx];
sweep_tile_span(spans[I1], [&](auto dstr_idx_i1) {
constexpr auto in_dstr_idx = make_tuple(dstr_idx_i0, dstr_idx_i1);
const auto x = ck_tile::type_convert<ComputeDataType>(x_tensor[in_dstr_idx]);
y = reduce_func(y, x);
});
y_tensor(y_dstr_idx) = y;
});
}
template <typename XDistributedTensor_>
CK_TILE_DEVICE static auto MakeYBlockTile()
{
static_assert(std::is_same_v<XDataType, typename XDistributedTensor_::DataType>, "wrong!");
// FIXME: hard coded to reduce 2nd axis
constexpr auto reduce_dims = sequence<1>{};
constexpr auto dstr =
make_static_tile_distribution(detail::make_reduce_tile_distribution_encoding(
XDistributedTensor_::get_tile_distribution()
.get_static_tile_distribution_encoding(),
reduce_dims));
auto tensor = make_static_distributed_tensor<ComputeDataType>(dstr);
return tensor;
}
template <typename XDistributedTensor_, typename ReduceFunc>
CK_TILE_DEVICE auto operator()(const XDistributedTensor_& x_tensor,
const ComputeDataType& reduce_init,
const ReduceFunc& reduce_func)
{
auto y_tensor = MakeYBlockTile<XDistributedTensor_>();
set_tile(y_tensor, reduce_init);
(*this)(x_tensor, y_tensor, reduce_func);
return y_tensor;
}
};
template <typename Problem_, typename Policy_ = void>
struct BlockReduce2dSync
{
using Problem = remove_cvref_t<Problem_>;
template <typename YDistributedTensor_, typename ReduceFunc>
CK_TILE_DEVICE void operator()(YDistributedTensor_& y_tensor, const ReduceFunc& reduce_func)
{
using Dstr = typename YDistributedTensor_::StaticTileDistribution;
using DstrEncode = typename Dstr::DstrEncode;
using DstrEncodeDetail = typename DstrEncode::detail;
constexpr index_t NDimP = Dstr::get_num_of_dimension_p();
constexpr index_t NDimR = Dstr::get_num_of_dimension_r();
constexpr index_t idim_p_lane = NDimP - 1;
// const auto ps_idx = make_array<index_t>(get_warp_id(), get_lane_id());
// const auto rs_idx =
// y_tensor.get_tile_distribution().calculate_rs_index_from_ps_index(ps_idx);
constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();
// loop over thread data
static_for<0, thread_buf_size, 1>{}([&](auto i) {
auto v_local = y_tensor.get_thread_buffer()[i];
// cross-lane reduce for replication
// only reduce on R dimension correspond to lane
// (lane id maps to this R dimension)
static_for<0, NDimR, 1>{}([&](auto idim_r) {
// FIXME: nasty to use does_p_own_r_
if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r])
{
constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
constexpr index_t lid_over_rid_derivative =
DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r];
static_assert(is_power_of_two_integer(r_length),
"wrong! only support power of 2 reduction");
constexpr index_t nstage = integer_log2_floor(r_length);
// reduction sweep forward
static_for<0, nstage, 1>{}([&](auto istage) {
// xor
index_t src_lane =
(__lane_id()) ^
(number<lid_over_rid_derivative << istage.value>{}.value);
// pull data from remote lane
const auto v_remote = warp_shuffle(v_local, src_lane);
// reduce
v_local = reduce_func(v_local, v_remote);
});
}
});
// TODO - Do we need to broadcast to other lane?
y_tensor.get_thread_buffer()(i) = v_local;
});
}
};
template <typename Problem_, typename Policy_ = void>
struct BlockReduce2dCrossWarpSync
{
using Problem = remove_cvref_t<Problem_>;
using BlockShape = typename Problem::BlockShape;
template <typename YDistributedTensor_>
CK_TILE_DEVICE static constexpr index_t GetReduceWarps()
{
constexpr index_t num_reduce_warps = [&]() {
using Dstr = typename YDistributedTensor_::StaticTileDistribution;
using DstrEncode = typename Dstr::DstrEncode;
using DstrEncodeDetail = typename DstrEncode::detail;
constexpr index_t NDimR = Dstr::get_num_of_dimension_r();
constexpr index_t idim_p_warp = 0;
index_t len_ = 1;
static_for<0, NDimR, 1>{}([&](auto idim_r) {
if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_warp][idim_r])
{
constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
len_ *= r_length;
}
});
return len_;
}();
return num_reduce_warps;
}
// return in byte
template <typename YDistributedTensor_>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
using DataType = typename YDistributedTensor_::DataType;
// constexpr auto num_reduce_warps = GetReduceWarps<YDistributedTensor_>();
constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();
// we need to store all data from every wave into smem
// e.g. 2x2 reduce along N
// -------------> reduce N
// | w0 | w1 | ___> | w01 |
// | w2 | w3 | | w23 |
//
// -> store data from every wave into LDS
//
//
// -------------> reduce N
// | w0 | w1 | w2 | w3 | -----> | w0123 |
//
// -> also store data from every wave into LDS
constexpr index_t num_warps = BlockShape::BlockSize / warpSize;
return num_warps * thread_buf_size * sizeof(DataType);
}
template <typename YDistributedTensor_, typename ReduceFunc>
CK_TILE_DEVICE void
operator()(YDistributedTensor_& y_tensor, void* smem, const ReduceFunc& reduce_func)
{
using DataType = typename YDistributedTensor_::DataType;
constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();
DataType* smem_ptr = reinterpret_cast<DataType*>(smem);
const index_t lane_id = get_lane_id();
const index_t warp_id = get_warp_id();
constexpr auto num_reduce_warps = GetReduceWarps<YDistributedTensor_>();
constexpr index_t num_warps = BlockShape::BlockSize / warpSize;
const index_t smem_offset = warp_id;
// skip if nothing to do
if constexpr(num_reduce_warps == 1)
return;
// store into smem only for lane-0 within one warp
if(lane_id == 0)
{
static_for<0, thread_buf_size, 1>{}([&](auto i) {
smem_ptr[smem_offset + i * num_warps] = y_tensor.get_thread_buffer()[i];
});
}
block_sync_lds();
// load from smem. here we let every thread do the compute :)
index_t local_warp_id = warp_id / num_reduce_warps;
index_t local_smem_os = local_warp_id * num_reduce_warps;
DataType all_scratch[thread_buf_size * num_reduce_warps];
static_for<0, thread_buf_size, 1>{}([&](auto i_0) {
static_for<0, num_reduce_warps, 1>{}([&](auto i_1) {
all_scratch[i_0 * num_reduce_warps + i_1] =
smem_ptr[i_0 * num_warps + local_smem_os + i_1];
});
});
block_sync_lds(); // TODO: we don't need sync here
static_for<0, thread_buf_size, 1>{}([&](auto i_0) {
// TODO: use descriptor for this
auto v_local = all_scratch[i_0 * num_reduce_warps];
// further reduce the per-warp partial results
static_for<0, num_reduce_warps - 1, 1>{}([&](auto i_1_n1) {
constexpr auto i_1 = number<i_1_n1 + 1>{};
const DataType v_remote = all_scratch[i_0 * num_reduce_warps + i_1];
// reduce
v_local = reduce_func(v_local, v_remote);
});
y_tensor.get_thread_buffer()(i_0) = v_local;
});
}
};
} // namespace ck_tile
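The three structs above split the row reduction into stages: BlockReduce2d folds each thread's own elements along the second (N) axis, BlockReduce2dSync folds across the lanes of a warp with XOR shuffles, and BlockReduce2dCrossWarpSync finishes through LDS when several warps share a row. As a plain host-side analogue (not ck_tile code) of the per-row fold with a caller-supplied init and functor, a minimal sketch:

// Host-side sketch only: reduces each row of a 2D array with a caller-supplied
// functor, mirroring the "reduce along the 2nd axis" contract of BlockReduce2d.
#include <array>
#include <cstdio>
#include <functional>

template <typename T, std::size_t M, std::size_t N, typename ReduceFunc>
std::array<T, M> reduce_rows(const std::array<std::array<T, N>, M>& x, T init, ReduceFunc reduce_func)
{
    std::array<T, M> y{};
    for(std::size_t m = 0; m < M; ++m)
    {
        T acc = init;                      // reduce_init
        for(std::size_t n = 0; n < N; ++n) // sweep the N (reduced) axis
            acc = reduce_func(acc, x[m][n]);
        y[m] = acc;
    }
    return y;
}

int main()
{
    std::array<std::array<float, 4>, 2> x{{{1, 2, 3, 4}, {5, 6, 7, 8}}};
    auto row_sum = reduce_rows(x, 0.0f, std::plus<float>{});
    std::printf("%f %f\n", row_sum[0], row_sum[1]); // 10 26
}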
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/reduce/block/block_reduce2d_problem.hpp"
#include "ck_tile/ops/reduce/block/block_reduce2d.hpp"
namespace ck_tile {
struct BlockReduce2dDefaultPolicy
{
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution()
{
using S = typename Problem::BlockShape;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<>,
tuple<sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M, S::Vector_M>,
sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
tuple<sequence<1, 2>, sequence<1, 2>>,
tuple<sequence<1, 1>, sequence<2, 2>>,
sequence<1, 1, 2, 2>,
sequence<0, 3, 0, 3>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2d()
{
using P_ = BlockReduce2dProblem<typename Problem::XDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape>;
return BlockReduce2d<P_>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dSync()
{
using P_ = BlockReduce2dProblem<typename Problem::XDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape>;
return BlockReduce2dSync<P_>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dCrossWarpSync()
{
using P_ = BlockReduce2dProblem<typename Problem::XDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape>;
return BlockReduce2dCrossWarpSync<P_>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
if constexpr(Problem::kNeedCrossWarpSync)
{
using P_ = BlockReduce2dProblem<typename Problem::XDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape>;
using block_reduce2d = BlockReduce2d<P_>;
using x_block_tile =
decltype(make_static_distributed_tensor<typename Problem::XDataType>(
MakeXBlockTileDistribution<Problem>()));
using y_block_tile = decltype(block_reduce2d::template MakeYBlockTile<x_block_tile>());
return GetBlockReduce2dCrossWarpSync<Problem>().template GetSmemSize<y_block_tile>();
}
else
{
return 1; // zero size arrays are an extension
}
}
};
} // namespace ck_tile
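For the cross-warp case, the LDS scratch this policy requests is forwarded to BlockReduce2dCrossWarpSync::GetSmemSize above: one y-tile entry per warp, per thread-buffer element. A small host-side restatement of that arithmetic, with example sizes that are not taken from this commit:

// Illustrative arithmetic only, mirroring num_warps * thread_buf_size * sizeof(DataType).
#include <cstdio>

int main()
{
    constexpr int block_size      = 256;                     // example block size
    constexpr int warp_size       = 64;                      // AMD wavefront
    constexpr int num_warps       = block_size / warp_size;  // 4
    constexpr int thread_buf_size = 2;                       // example y-tile entries per thread
    constexpr int smem_bytes      = num_warps * thread_buf_size * static_cast<int>(sizeof(float));
    std::printf("LDS scratch: %d bytes\n", smem_bytes);      // 32
}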
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <typename XDataType_, typename ComputeDataType_, typename BlockShape_>
struct BlockReduce2dProblem
{
using XDataType = remove_cvref_t<XDataType_>;
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
using BlockShape = remove_cvref_t<BlockShape_>;
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp"
#include "ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_shape.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
This diff is collapsed.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
/*
// clang-format off
4-level descriptor: BlockTile-> WarpPerBlock-> WarpTile-> Vector
Block_N (Warp_N * WarpPerBlock_N * Repeat_N )
+<----------------------< Repeat_N(2)>--------------------->+
| |
+<-- <WarpPerBlock_N(2)> -->+
Warp_N
+--------------+--------------+--------------+--------------+----+----------------+
Warp_M | warp_0 | warp_1 | | ^ ^
+--------------+--------------+ | <WarpPerBlock_M(2)> |
| warp_2 | warp_3 | | v
+--------------+--------------+--------------+--------------+----+ Block_M
| | |
+ + |
| | | v
+--------------+--------------+--------------+--------------+ +
each Warp-tile (e.g. 16 thrd per row)
Vector_N (contiguous pixels each thrd holds along N, or vector size)
+-----------+-----------+-----------+-----------+-----------+
| thrd_0 | thrd_1 | thrd_2 | thrd_3 | ... Vector_M
+-----------+-----------+-----------+-----------+-----------+
| thrd_16 | thrd_17 | thrd_18 | thrd_19 | ...
+-----------+-----------+-----------+-----------+-----------+
// clang-format on
*/
template <typename BlockTile_, // block size, seq<M, N>
typename WarpPerBlock_, // num warps along seq<M, N>
typename WarpTile_, // warp size, seq<M, N>
typename Vector_, // contiguous pixels(vector size) along seq<M, N>
index_t BlockSize_ =
warpSize* reduce_on_sequence(WarpPerBlock_{}, multiplies{}, number<1>{})>
struct Rmsnorm2dShape
{
// block size
static constexpr index_t Block_M = BlockTile_::at(number<0>{});
static constexpr index_t Block_N = BlockTile_::at(number<1>{});
// num warps along seq<M, N>, within each block
static constexpr index_t WarpPerBlock_M = WarpPerBlock_::at(number<0>{});
static constexpr index_t WarpPerBlock_N = WarpPerBlock_::at(number<1>{});
// warp size
static constexpr index_t Warp_M = WarpTile_::at(number<0>{});
static constexpr index_t Warp_N = WarpTile_::at(number<1>{});
static_assert(Block_M % (WarpPerBlock_M * Warp_M) == 0);
static_assert(Block_N % (WarpPerBlock_N * Warp_N) == 0);
// repeat of each thread along seq<M, N>
static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M);
static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N);
// vector size along seq<M, N>
static constexpr index_t Vector_M = Vector_::at(number<0>{});
static constexpr index_t Vector_N = Vector_::at(number<1>{});
static_assert(Warp_M % Vector_M == 0);
static_assert(Warp_N % Vector_N == 0);
// num of threads along seq<M, N>, within each warp
static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M;
static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N;
static constexpr index_t BlockSize = BlockSize_;
};
} // namespace ck_tile
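To make the derived members concrete, a hedged worked instantiation follows. The tile sizes are illustrative choices rather than values taken from this commit, the header path is the one listed in the rmsnorm2d umbrella header above, and warpSize is assumed to be 64 (AMD wavefront):

// Illustrative instantiation only: example tile sizes to show how the derived members fall out.
#include "ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_shape.hpp"

using ExampleShape = ck_tile::Rmsnorm2dShape<ck_tile::sequence<2, 1024>, // Block_M, Block_N
                                             ck_tile::sequence<1, 2>,    // WarpPerBlock_M, WarpPerBlock_N
                                             ck_tile::sequence<1, 256>,  // Warp_M, Warp_N
                                             ck_tile::sequence<1, 4>>;   // Vector_M, Vector_N

static_assert(ExampleShape::Repeat_M == 2);         // Block_M / (WarpPerBlock_M * Warp_M) = 2 / 1
static_assert(ExampleShape::Repeat_N == 2);         // 1024 / (2 * 256)
static_assert(ExampleShape::ThreadPerWarp_N == 64); // Warp_N / Vector_N = 256 / 4
static_assert(ExampleShape::BlockSize == 128);      // warpSize * (1 * 2), assuming warpSize == 64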