Unverified Commit a11cf2c6 authored by arai713's avatar arai713 Committed by GitHub
Browse files

Merge branch 'develop' into codegen_hiprtc

parents a72e9efa 64d5c4d6
...@@ -70,11 +70,16 @@ struct FusedMoeGemmPipeline_FlatmmUk ...@@ -70,11 +70,16 @@ struct FusedMoeGemmPipeline_FlatmmUk
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{ {
#if 1
constexpr index_t smem_0 = Policy::template GetUK_0<Problem>().GetSmemSize(); constexpr index_t smem_0 = Policy::template GetUK_0<Problem>().GetSmemSize();
constexpr index_t smem_1 = Policy::template GetUK_1<Problem>().GetSmemSize(); constexpr index_t smem_1 = Policy::template GetUK_1<Problem>().GetSmemSize();
constexpr index_t smem_bridge = constexpr index_t smem_bridge =
BlockShape::Block_M0 * BlockShape::Block_N0 * sizeof(YDataType); BlockShape::Block_M0 * BlockShape::Block_N0 * sizeof(YDataType);
return max(smem_0, max(smem_1, smem_bridge)); return max(smem_0 + smem_1, smem_bridge);
#else
// keep it here purposely in case we have regression
return 65536;
#endif
} }
// this is the thread-offset along row/col // this is the thread-offset along row/col
...@@ -125,6 +130,9 @@ struct FusedMoeGemmPipeline_FlatmmUk ...@@ -125,6 +130,9 @@ struct FusedMoeGemmPipeline_FlatmmUk
array<index_t, n_size> row_ids; array<index_t, n_size> row_ids;
static_for<0, n_size, 1>{}([&](auto i) { static_for<0, n_size, 1>{}([&](auto i) {
row_ids.at(i) = sorted_token_ids_ptr[coords[i]]; // base_coord + i * MLans; row_ids.at(i) = sorted_token_ids_ptr[coords[i]]; // base_coord + i * MLans;
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
row_ids.at(i) &= 0xffffff;
#endif
}); });
return row_ids; return row_ids;
...@@ -164,9 +172,12 @@ struct FusedMoeGemmPipeline_FlatmmUk ...@@ -164,9 +172,12 @@ struct FusedMoeGemmPipeline_FlatmmUk
index_t sorted_tile_id, index_t sorted_tile_id,
index_t intermediate_tile_id) index_t intermediate_tile_id)
{ {
constexpr index_t hidden_radio_0 = IsGateOnly ? 1 : 2; constexpr index_t hidden_radio_0 = IsGateOnly ? 1 : 2;
ck_tile::index_t shared_intermediate_size_0 = kargs.intermediate_size; ck_tile::index_t shared_intermediate_size_0 =
ck_tile::index_t shared_intermediate_size_1 = kargs.intermediate_size / hidden_radio_0; kargs.intermediate_size * hidden_radio_0; // total gate+up
ck_tile::index_t shared_intermediate_size_1 = kargs.intermediate_size;
// after weight shuffling, gate-only: [nr0, kr0, w0], gate+up: [nr0_gate + nr0_up, kr0, w0]
index_t nr_0 = shared_intermediate_size_0 / BlockShape::Warp_N0; // divide N in W index_t nr_0 = shared_intermediate_size_0 / BlockShape::Warp_N0; // divide N in W
index_t kr_0 = kargs.hidden_size / BlockShape::Warp_K0; // divide K in W index_t kr_0 = kargs.hidden_size / BlockShape::Warp_K0; // divide K in W
...@@ -200,29 +211,35 @@ struct FusedMoeGemmPipeline_FlatmmUk ...@@ -200,29 +211,35 @@ struct FusedMoeGemmPipeline_FlatmmUk
make_wave_buffer_resource(reinterpret_cast<const ADataType*>(kargs.a_ptr), make_wave_buffer_resource(reinterpret_cast<const ADataType*>(kargs.a_ptr),
kargs.num_tokens * kargs.stride_token * sizeof(ADataType)); kargs.num_tokens * kargs.stride_token * sizeof(ADataType));
auto g_win = [&]() { auto make_gu_win = [&](const auto* ptr_) {
const GDataType* g_ptr = reinterpret_cast<const GDataType*>(kargs.g_ptr) + auto view_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<long_index_t>(expert_id) * expert_stride_0 + ptr_,
interm_idx_nr0 * kr_0 * BlockShape::Block_W0;
auto g_view_ = make_naive_tensor_view<address_space_enum::global>(
g_ptr,
make_tuple(nr_0, kr_0, number<BlockShape::Block_W0>{}), make_tuple(nr_0, kr_0, number<BlockShape::Block_W0>{}),
make_tuple(kr_0 * BlockShape::Block_W0, number<BlockShape::Block_W0>{}, 1), make_tuple(kr_0 * BlockShape::Block_W0, number<BlockShape::Block_W0>{}, 1),
number<kAlignmentG>{}, number<kAlignmentG>{},
number<1>{}); number<1>{});
auto g_window_ = make_tile_window_linear_raw( auto win_ = make_tile_window_linear_raw(
g_view_, view_,
make_tuple(number<BlockShape::Block_Nr0>{}, make_tuple(number<BlockShape::Block_Nr0>{},
number<BlockShape::Block_Kr0>{}, number<BlockShape::Block_Kr0>{},
number<BlockShape::Block_W0>{}), number<BlockShape::Block_W0>{}),
{0, 0, 0}, {0, 0, 0},
Policy::template MakeGlobalTileDistribution_G<Problem>(), Policy::template MakeGlobalTileDistribution_G<Problem>(),
sequence<0, 1, 1>{}); sequence<0, 1, 1>{});
return g_window_; return win_;
}(); };
const GDataType* gu_ptr = reinterpret_cast<const GDataType*>(kargs.g_ptr) +
static_cast<long_index_t>(expert_id) * expert_stride_0 +
interm_idx_nr0 * kr_0 * BlockShape::Block_W0;
auto g_win = make_gu_win(gu_ptr);
// Note: gu swizzled, [nr_u+nr_g, kr, w], hence base offset to up is just interm*hidden
auto u_win = make_gu_win(gu_ptr + kargs.intermediate_size * kargs.hidden_size);
auto g_res = g_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_; auto g_res = g_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_;
auto u_res = u_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_;
auto g_coords = generate_tuple([&](auto i) { return g_win.cached_coords_[i].get_offset(); }, auto g_coords = generate_tuple([&](auto i) { return g_win.cached_coords_[i].get_offset(); },
number<decltype(g_win)::NumAccess_NonLinear>{}); number<decltype(g_win)::NumAccess_NonLinear>{});
...@@ -309,28 +326,73 @@ struct FusedMoeGemmPipeline_FlatmmUk ...@@ -309,28 +326,73 @@ struct FusedMoeGemmPipeline_FlatmmUk
auto w_scale = GetWeightScale( auto w_scale = GetWeightScale(
row_coords_o, reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr)); row_coords_o, reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr));
auto uk_0 = Policy::template GetUK_0<Problem>(); auto uk_0 = Policy::template GetUK_0<Problem>();
auto acc_0 = uk_0(a_res,
a_coords, auto y_pre = [&]() {
g_res, if constexpr(IsGateOnly)
g_coords, {
smem, auto acc_0 = uk_0(a_res,
kargs.hidden_size, a_coords,
BlockShape::Block_K0, // tile offset for B matrix each unroll g_res,
BlockShape::Block_Kr0 * g_coords,
BlockShape::Block_W0); // tile offset for B matrix each unroll smem,
kargs.hidden_size,
sweep_tile( BlockShape::Block_K0, // tile offset for B matrix each unroll
acc_0, BlockShape::Block_Kr0 *
[&](auto idx0, auto idx1) { BlockShape::Block_W0); // tile offset for B matrix each unroll
fp32x2_t v_{acc_0(idx0), acc_0(idx1)};
typename Problem::GateActivation{}(v_, v_); sweep_tile(
acc_0(idx0) = v_.x; acc_0,
acc_0(idx1) = v_.y; [&](auto idx0, auto idx1) {
}, fp32x2_t v_{acc_0(idx0), acc_0(idx1)};
sequence<1, 2>{}); typename Problem::GateActivation{}(v_, v_);
acc_0(idx0) = v_.x;
auto y_pre = cast_tile<YDataType>(acc_0); acc_0(idx1) = v_.y;
},
sequence<1, 2>{});
return cast_tile<YDataType>(acc_0);
}
else
{
uint32x8_t gu_res;
gu_res[0] = g_res[0];
gu_res[1] = g_res[1];
gu_res[2] = g_res[2];
gu_res[3] = g_res[3];
gu_res[4] = u_res[0];
gu_res[5] = u_res[1];
gu_res[6] = u_res[2];
gu_res[7] = u_res[3];
auto acc_0 = uk_0(a_res,
a_coords,
gu_res,
g_coords,
smem,
kargs.hidden_size,
BlockShape::Block_K0, // tile offset for B matrix each unroll
BlockShape::Block_Kr0 * BlockShape::Block_W0,
bool_constant<true>{}); // tile offset for B matrix each unroll
sweep_tile(
acc_0.at(number<0>{}),
[&](auto idx0, auto idx1) {
fp32x2_t v_{acc_0.at(number<0>{})(idx0), acc_0.at(number<0>{})(idx1)};
typename Problem::GateActivation{}(v_, v_);
acc_0.at(number<0>{})(idx0) = v_.x;
acc_0.at(number<0>{})(idx1) = v_.y;
},
sequence<1, 2>{});
auto reduced_acc_0 =
tile_elementwise_in([&](const auto& a_, const auto& b_) { return a_ * b_; },
acc_0.at(number<0>{}),
acc_0.at(number<1>{}));
return cast_tile<YDataType>(reduced_acc_0);
}
}();
block_sync_lds(); block_sync_lds();
......
...@@ -101,9 +101,12 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep ...@@ -101,9 +101,12 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
CK_TILE_DEVICE void operator()(BatchedGemmKernelArgs kargs) const CK_TILE_DEVICE void operator()(BatchedGemmKernelArgs kargs) const
{ {
const auto [i_m, i_n] = TilePartitioner{}(); const auto [iM, iN] = TilePartitioner::GetOutputTileIndex(blockIdx.x, blockIdx.y);
const auto i_batch = __builtin_amdgcn_readfirstlane(blockIdx.z / kargs.KBatch); const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
const auto i_k = __builtin_amdgcn_readfirstlane(blockIdx.z - i_batch * kargs.KBatch); const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
const auto i_batch = __builtin_amdgcn_readfirstlane(blockIdx.z / kargs.KBatch);
const auto i_k = __builtin_amdgcn_readfirstlane(blockIdx.z - i_batch * kargs.KBatch);
const typename Base::SplitKBatchOffset splitk_batch_offset(kargs, i_k); const typename Base::SplitKBatchOffset splitk_batch_offset(kargs, i_k);
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -174,7 +174,7 @@ struct GemmKernel ...@@ -174,7 +174,7 @@ struct GemmKernel
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{ {
if(kargs.K % TilePartitioner::kK != 0 && GemmPipeline::kPadK == false) if(kargs.K % TilePartitioner::KPerBlock != 0 && GemmPipeline::kPadK == false)
{ {
return false; return false;
} }
...@@ -185,7 +185,7 @@ struct GemmKernel ...@@ -185,7 +185,7 @@ struct GemmKernel
} }
else else
{ {
if(kargs.M % TilePartitioner::kM != 0 && GemmPipeline::kPadM == false) if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
{ {
return false; return false;
} }
...@@ -197,7 +197,7 @@ struct GemmKernel ...@@ -197,7 +197,7 @@ struct GemmKernel
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
{ {
if(kargs.N % TilePartitioner::kN != 0 && GemmPipeline::kPadN == false) if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
{ {
return false; return false;
} }
...@@ -208,7 +208,7 @@ struct GemmKernel ...@@ -208,7 +208,7 @@ struct GemmKernel
} }
else else
{ {
if(kargs.K % TilePartitioner::kK != 0 && GemmPipeline::kPadK == false) if(kargs.K % TilePartitioner::KPerBlock != 0 && GemmPipeline::kPadK == false)
{ {
return false; return false;
} }
...@@ -220,7 +220,7 @@ struct GemmKernel ...@@ -220,7 +220,7 @@ struct GemmKernel
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{ {
if(kargs.N % TilePartitioner::kN != 0 && GemmPipeline::kPadN == false) if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
{ {
return false; return false;
} }
...@@ -231,7 +231,7 @@ struct GemmKernel ...@@ -231,7 +231,7 @@ struct GemmKernel
} }
else else
{ {
if(kargs.M % TilePartitioner::kM != 0 && GemmPipeline::kPadM == false) if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
{ {
return false; return false;
} }
...@@ -323,17 +323,17 @@ struct GemmKernel ...@@ -323,17 +323,17 @@ struct GemmKernel
const auto& a_tensor_view = views.at(I0); const auto& a_tensor_view = views.at(I0);
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{ {
return pad_tensor_view( return pad_tensor_view(a_tensor_view,
a_tensor_view, make_tuple(number<TilePartitioner::MPerBlock>{},
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}), number<TilePartitioner::KPerBlock>{}),
sequence<false, GemmPipeline::kPadK>{}); sequence<false, GemmPipeline::kPadK>{});
} }
else else
{ {
return pad_tensor_view( return pad_tensor_view(a_tensor_view,
a_tensor_view, make_tuple(number<TilePartitioner::MPerBlock>{},
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}), number<TilePartitioner::KPerBlock>{}),
sequence<GemmPipeline::kPadM, false>{}); sequence<GemmPipeline::kPadM, false>{});
} }
}(); }();
...@@ -341,17 +341,17 @@ struct GemmKernel ...@@ -341,17 +341,17 @@ struct GemmKernel
const auto& b_tensor_view = views.at(I1); const auto& b_tensor_view = views.at(I1);
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>) if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
{ {
return pad_tensor_view( return pad_tensor_view(b_tensor_view,
b_tensor_view, make_tuple(number<TilePartitioner::NPerBlock>{},
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}), number<TilePartitioner::KPerBlock>{}),
sequence<false, GemmPipeline::kPadK>{}); sequence<false, GemmPipeline::kPadK>{});
} }
else else
{ {
return pad_tensor_view( return pad_tensor_view(b_tensor_view,
b_tensor_view, make_tuple(number<TilePartitioner::NPerBlock>{},
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}), number<TilePartitioner::KPerBlock>{}),
sequence<GemmPipeline::kPadN, false>{}); sequence<GemmPipeline::kPadN, false>{});
} }
}(); }();
...@@ -359,17 +359,17 @@ struct GemmKernel ...@@ -359,17 +359,17 @@ struct GemmKernel
const auto& c_tensor_view = views.at(I2); const auto& c_tensor_view = views.at(I2);
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{ {
return pad_tensor_view( return pad_tensor_view(c_tensor_view,
c_tensor_view, make_tuple(number<TilePartitioner::MPerBlock>{},
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}), number<TilePartitioner::NPerBlock>{}),
sequence<false, GemmPipeline::kPadN>{}); sequence<false, GemmPipeline::kPadN>{});
} }
else else
{ {
return pad_tensor_view( return pad_tensor_view(c_tensor_view,
c_tensor_view, make_tuple(number<TilePartitioner::MPerBlock>{},
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}), number<TilePartitioner::NPerBlock>{}),
sequence<GemmPipeline::kPadM, false>{}); sequence<GemmPipeline::kPadM, false>{});
} }
}(); }();
...@@ -383,19 +383,19 @@ struct GemmKernel ...@@ -383,19 +383,19 @@ struct GemmKernel
const auto& a_pad_view = views.at(I0); const auto& a_pad_view = views.at(I0);
const auto& a_block_window = make_tile_window( const auto& a_block_window = make_tile_window(
a_pad_view, a_pad_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}), make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
{i_m, 0}); {i_m, 0});
const auto& b_pad_view = views.at(I1); const auto& b_pad_view = views.at(I1);
const auto& b_block_window = make_tile_window( const auto& b_block_window = make_tile_window(
b_pad_view, b_pad_view,
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}), make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
{i_n, 0}); {i_n, 0});
const auto& c_pad_view = views.at(I2); const auto& c_pad_view = views.at(I2);
auto c_block_window = make_tile_window( auto c_block_window = make_tile_window(
c_pad_view, c_pad_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}), make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
{i_m, i_n}); {i_m, i_n});
return make_tuple(a_block_window, b_block_window, c_block_window); return make_tuple(a_block_window, b_block_window, c_block_window);
...@@ -426,7 +426,7 @@ struct GemmKernel ...@@ -426,7 +426,7 @@ struct GemmKernel
// Create Gemm tensor views, pad views and tile windows // Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple = const auto& gemm_tensor_views_tuple =
MakeGemmTensorViews<DstInMemOp>(a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset); MakeGemmTensorViews<DstInMemOp>(a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset);
;
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
...@@ -456,7 +456,10 @@ struct GemmKernel ...@@ -456,7 +456,10 @@ struct GemmKernel
CK_TILE_DEVICE void operator()(GemmKernelArgs kargs) const CK_TILE_DEVICE void operator()(GemmKernelArgs kargs) const
{ {
const auto [i_m, i_n] = TilePartitioner{}(); const auto [iM, iN] = TilePartitioner::GetOutputTileIndex(blockIdx.x, blockIdx.y);
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
const SplitKBatchOffset splitk_batch_offset(kargs); const SplitKBatchOffset splitk_batch_offset(kargs);
// options // options
const ADataType* a_ptr = const ADataType* a_ptr =
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
namespace ck_tile { namespace ck_tile {
template <typename BlockGemmShape_>
struct GemmTilePartitioner /** @brief Struct representing 2D block index mapping into 3D output tile space. */
template <typename BlockGemmShapeType>
struct GemmTile2DPartitioner
{ {
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>; using BlockGemmShape = remove_cvref_t<BlockGemmShapeType>;
static constexpr index_t kM = BlockGemmShape::kM; static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t kN = BlockGemmShape::kN; static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t kK = BlockGemmShape::kK; static constexpr index_t KPerBlock = BlockGemmShape::kK;
CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t batch_size) /** @brief Returns 3D grid size. */
CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t batch_size) noexcept(
noexcept(MPerBlock != 0 && NPerBlock != 0)) -> dim3
{ {
index_t GridDimX = (M + kM - 1) / kM; const index_t GridDimX = (M + MPerBlock - 1) / MPerBlock;
index_t GridDimY = (N + kN - 1) / kN; const index_t GridDimY = (N + NPerBlock - 1) / NPerBlock;
index_t GridDimZ = batch_size; const index_t GridDimZ = batch_size;
return dim3(GridDimX, GridDimY, GridDimZ); return dim3(GridDimX, GridDimY, GridDimZ);
} }
CK_TILE_HOST_DEVICE static constexpr auto GetLoopNum(index_t K) /**
* @brief Returns the number of loops.
* @param [in] K is dimension
*/
CK_TILE_HOST_DEVICE static constexpr auto GetLoopNum(index_t K) noexcept -> index_t
{ {
return integer_divide_ceil(K, kK); return integer_divide_ceil(K, KPerBlock);
} }
CK_TILE_DEVICE auto operator()() /**
* @brief The function returns 2D output tile space.
* @param [in] blockIdx is blockIdx.x
* @param [in] blockIdy is blockIdx.y
* @return Returns the output tile indexes.
*/
CK_TILE_DEVICE static constexpr auto GetOutputTileIndex(index_t blockIdx,
index_t blockIdy) noexcept
-> const tuple<index_t, index_t>
{ {
const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx.x * kM); const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx);
const index_t iN = __builtin_amdgcn_readfirstlane(blockIdx.y * kN); const index_t iN = __builtin_amdgcn_readfirstlane(blockIdy);
return make_tuple(iM, iN); return make_tuple(iM, iN);
} }
}; };
template <typename BlockGemmShape_> /**
* @brief Struct representing 1D block index mapping into 2D output tile space.
*/
template <typename BlockGemmShapeType>
struct GemmTile1DPartitioner struct GemmTile1DPartitioner
{ {
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>; using BlockGemmShape = remove_cvref_t<BlockGemmShapeType>;
static constexpr index_t MPerBlock = BlockGemmShape::kM; static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN; static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK; static constexpr index_t KPerBlock = BlockGemmShape::kK;
CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N) /** @brief delete default ctr with no any object */
constexpr GemmTile1DPartitioner() noexcept = delete;
/** @brief constructs an object that does contain a N value. */
constexpr GemmTile1DPartitioner(index_t N) noexcept { N_ = N; }
/** @brief Returns 1D grid size. */
CK_TILE_HOST static constexpr auto
GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0)) -> dim3
{ {
index_t GridDimX = (M + MPerBlock - 1) / MPerBlock; const index_t GridDimX = (M + MPerBlock - 1) / MPerBlock;
index_t GridDimY = (N + NPerBlock - 1) / NPerBlock; const index_t GridDimY = (N + NPerBlock - 1) / NPerBlock;
return dim3(GridDimX * GridDimY, 1, 1); return dim3(GridDimX * GridDimY, 1, 1);
} }
CK_TILE_HOST_DEVICE static constexpr auto GetNBlock(index_t N) /**
* @brief Returns the number of blocks in N.
* @param [in] N is dimension
*/
CK_TILE_HOST_DEVICE static constexpr auto GetNBlock(index_t N) noexcept -> index_t
{ {
return integer_divide_ceil(N, NPerBlock); return integer_divide_ceil(N, NPerBlock);
} }
CK_TILE_HOST_DEVICE static constexpr auto GetLoopNum(index_t K) /**
* @brief Returns the number of loops.
* @param [in] K is dimension
*/
CK_TILE_HOST_DEVICE static constexpr auto GetLoopNum(index_t K) noexcept -> index_t
{ {
return integer_divide_ceil(K, KPerBlock); return integer_divide_ceil(K, KPerBlock);
} }
CK_TILE_DEVICE auto operator()(index_t blockOffset, index_t NBlockSize) /**
* @brief The function returns 2D output tile space.
* @param [in] blockIdx is blockIdx.x - block_start.
* */
CK_TILE_DEVICE static constexpr auto GetOutputTileIndex(index_t blockIdx) noexcept
-> const tuple<index_t, index_t>
{
const index_t NBlock = GetNBlock(N_);
const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx / NBlock);
const index_t iN = __builtin_amdgcn_readfirstlane(blockIdx - (iM)*NBlock);
return make_tuple(iM, iN);
}
private:
CK_TILE_DEVICE static index_t N_;
};
/**
* @brief `GemmTile1DPartitioner::GetOutputTileIndex`'s std::false specialization,
* checking expression validity in-place for ill-formed.
*/
template <typename, typename = void>
struct HasFnOneArgImpl : std::false_type
{
};
/**
* @brief `GemmTile1DPartitioner::GetOutputTileIndex`'s std::true specialization,
* checking expression validity in-place for well-formed.
* @note: `1` - a constant value indicating the number of parameters in the function.
*/
template <typename T>
struct HasFnOneArgImpl<T, std::void_t<decltype(std::declval<T>().GetOutputTileIndex(1))>>
: std::true_type
{
};
/**
* @brief Struct used to calculate offseted tile indexes.
* @note: The struct supports the 1D-Partitioner mechanism,
* enable-if `GetOutputTileIndex`-fn is std::true_type when `GetOutputTileIndex`-fn is well-formed,
* otherwise std::false_type.
*/
template <typename PartitionerFn,
typename = typename std::enable_if_t<HasFnOneArgImpl<PartitionerFn>{}>>
struct OffsettedTile1DPartitioner
{
/**
* @brief The function subtracts the block's start (offset) from 1D raw-indexes.
* @param [in] block_start is `blockIdx.x - block_start`.
* @return Returns a `tuple` [Im, In] shifted index, used to shift 1d-tile index.
*/
[[nodiscard]] CK_TILE_DEVICE static constexpr auto GetOffsetedTileIndex(index_t block_start,
index_t N) noexcept
-> const tuple<index_t, index_t>
{ {
index_t iM = __builtin_amdgcn_readfirstlane((blockIdx.x - blockOffset) / const auto [iM, iN] = PartitionerFn(N).GetOutputTileIndex(blockIdx.x - block_start);
GetNBlock(NBlockSize) * MPerBlock);
index_t iN = __builtin_amdgcn_readfirstlane((blockIdx.x - blockOffset) %
GetNBlock(NBlockSize) * NPerBlock);
return make_tuple(iM, iN); return make_tuple(iM, iN);
} }
}; };
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include <iostream>
#include <string>
#include "ck_tile/core/numeric/math.hpp" #include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/literals.hpp" #include "ck_tile/core/utility/literals.hpp"
#include "ck_tile/core/utility/amd_address_space.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/core.hpp" #include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/host.hpp" #include "ck_tile/host.hpp"
namespace ck_tile { namespace ck_tile {
struct GroupedGemmHostArgs struct GroupedGemmHostArgs : public ck_tile::GemmHostArgs
{ {
const void* a_ptr; CK_TILE_HOST GroupedGemmHostArgs() noexcept = default;
const void* b_ptr; CK_TILE_HOST GroupedGemmHostArgs(const void* a_ptr_,
void* c_ptr; const void* b_ptr_,
index_t M; void* c_ptr_,
index_t N; ck_tile::index_t M_,
index_t K; ck_tile::index_t N_,
index_t stride_A; ck_tile::index_t K_,
index_t stride_B; ck_tile::index_t stride_A_,
index_t stride_C; ck_tile::index_t stride_B_,
ck_tile::index_t stride_C_)
: GemmHostArgs(a_ptr_, b_ptr_, c_ptr_, KBatch, M_, N_, K_, stride_A_, stride_B_, stride_C_)
{
}
private:
static constexpr index_t KBatch = 1;
}; };
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_> template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct GroupedGemmKernel struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>
{ {
using TilePartitioner = remove_cvref_t<TilePartitioner_>; using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using GemmPipeline = remove_cvref_t<GemmPipeline_>; using GemmPipeline = remove_cvref_t<GemmPipeline_>;
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>; using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>; using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>; using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>; using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>; using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>; using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>; using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
using OffsetTile1DPartitioner = OffsettedTile1DPartitioner<TilePartitioner>;
using Base = GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
using GemmKernelArgs = typename Base::GemmKernelArgs;
static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
static constexpr index_t KBatch = 1;
struct GemmTransKernelArg struct GemmTransKernelArg
{ {
GroupedGemmHostArgs group_karg; GemmKernelArgs group_karg;
ck_tile::index_t block_start; ck_tile::index_t block_start;
ck_tile::index_t block_end; ck_tile::index_t block_end;
GemmTransKernelArg() = default; GemmTransKernelArg() = default;
GemmTransKernelArg(GroupedGemmHostArgs&& karg, index_t bl_start, index_t bl_end) GemmTransKernelArg(GemmKernelArgs&& karg, index_t bl_start, index_t bl_end)
: group_karg{karg}, block_start{bl_start}, block_end{bl_end} : group_karg{karg}, block_start{bl_start}, block_end{bl_end}
{ {
} }
}; };
__host__ static size_t GetWorkSpaceSize(const std::vector<GroupedGemmHostArgs>& gemm_descs) __host__ static auto GetWorkSpaceSize(const std::vector<GroupedGemmHostArgs>& gemm_descs)
-> std::size_t
{ {
return gemm_descs.size() * sizeof(GemmTransKernelArg); return gemm_descs.size() * sizeof(GemmTransKernelArg);
} }
__host__ static constexpr auto BlockSize() { return dim3(KernelBlockSize); } __host__ static constexpr auto BlockSize() -> dim3 { return dim3(KernelBlockSize); }
using Hargs = GroupedGemmHostArgs;
__host__ static constexpr auto GridSize(const std::vector<Hargs>& gemm_descs) __host__ static constexpr auto GridSize(const std::vector<GroupedGemmHostArgs>& gemm_descs)
{ {
index_t grid_size = 0; index_t grid_size = 0;
for(const auto& it_desc : gemm_descs) for(const auto& it_desc : gemm_descs)
...@@ -77,7 +84,8 @@ struct GroupedGemmKernel ...@@ -77,7 +84,8 @@ struct GroupedGemmKernel
return dim3(grid_size, 1, 1); return dim3(grid_size, 1, 1);
} }
CK_TILE_HOST static auto MakeKargs(const std::vector<Hargs>& gemm_descs) CK_TILE_HOST static auto MakeKargs(const std::vector<GroupedGemmHostArgs>& gemm_descs)
-> std::vector<GemmTransKernelArg>
{ {
std::vector<GemmTransKernelArg> gemm_kernel_args_; std::vector<GemmTransKernelArg> gemm_kernel_args_;
index_t group_count = ck_tile::type_convert<ck_tile::index_t>(gemm_descs.size()); index_t group_count = ck_tile::type_convert<ck_tile::index_t>(gemm_descs.size());
...@@ -100,22 +108,23 @@ struct GroupedGemmKernel ...@@ -100,22 +108,23 @@ struct GroupedGemmKernel
const index_t stride_c = gemm_descs[i].stride_C; const index_t stride_c = gemm_descs[i].stride_C;
const auto dim3 = TilePartitioner::GridSize(M, N); const auto dim3 = TilePartitioner::GridSize(M, N);
const index_t grid_size_grp = dim3.x * 1 * 1; const index_t grid_size_grp = dim3.x;
const index_t block_start = grid_size; const index_t block_start = grid_size;
const index_t block_end = grid_size + grid_size_grp; const index_t block_end = grid_size + grid_size_grp;
grid_size += grid_size_grp; grid_size += grid_size_grp;
auto karg = GroupedGemmHostArgs{type_convert<const ADataType*>(gemm_descs[i].a_ptr), auto karg = GemmKernelArgs{type_convert<const ADataType*>(gemm_descs[i].a_ptr),
type_convert<const BDataType*>(gemm_descs[i].b_ptr), type_convert<const BDataType*>(gemm_descs[i].b_ptr),
type_convert<CDataType*>(gemm_descs[i].c_ptr), type_convert<CDataType*>(gemm_descs[i].c_ptr),
M, M,
N, N,
K, K,
stride_a, stride_a,
stride_b, stride_b,
stride_c}; stride_c,
KBatch};
gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end); gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end);
} }
...@@ -123,162 +132,34 @@ struct GroupedGemmKernel ...@@ -123,162 +132,34 @@ struct GroupedGemmKernel
return gemm_kernel_args_; return gemm_kernel_args_;
} }
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize() -> index_t
{ {
return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
} }
CK_TILE_DEVICE void Run(const Hargs& kargs, const index_t block_start) const CK_TILE_DEVICE void Run(const GemmTransKernelArg& kargs) const
{ {
const auto [i_m, i_n] = TilePartitioner{}(block_start, kargs.N); const auto [iM, iN] =
// options OffsetTile1DPartitioner::GetOffsetedTileIndex(kargs.block_start, kargs.group_karg.N);
const ADataType* a_start = static_cast<const ADataType*>(kargs.a_ptr);
const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr);
// Convert pointers to tensor views
auto a_tensor_view = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
a_start,
make_tuple(kargs.M, kargs.K),
make_tuple(kargs.stride_A, 1),
number<GemmPipeline::VectorSizeA>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
a_start,
make_tuple(kargs.M, kargs.K),
make_tuple(1, kargs.stride_A),
number<1>{},
number<1>{});
}
}();
auto b_tensor_view = [&]() { const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>) const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
{
return make_naive_tensor_view<address_space_enum::global>(
b_start,
make_tuple(kargs.N, kargs.K),
make_tuple(1, kargs.stride_B),
number<1>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
b_start,
make_tuple(kargs.N, kargs.K),
make_tuple(kargs.stride_B, 1),
number<GemmPipeline::VectorSizeB>{},
number<1>{});
}
}();
auto a_pad_view = [&]() { const typename Base::SplitKBatchOffset splitk_batch_offset(kargs.group_karg, blockIdx.z);
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<false, GemmPipeline::kPadK>{});
}
else
{
return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<GemmPipeline::kPadM, false>{});
}
}();
// clang-format on
auto a_block_window = make_tile_window( const ADataType* a_ptr = static_cast<const ADataType*>(kargs.group_karg.a_ptr);
a_pad_view, const BDataType* b_ptr = static_cast<const BDataType*>(kargs.group_karg.b_ptr);
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}), CDataType* c_ptr = static_cast<CDataType*>(kargs.group_karg.c_ptr);
{i_m, 0});
auto b_pad_view = [&]() {
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
{
return pad_tensor_view(b_tensor_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<false, GemmPipeline::kPadK>{});
}
else
{
return pad_tensor_view(b_tensor_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<GemmPipeline::kPadN, false>{});
}
}();
auto b_block_window = make_tile_window(
b_pad_view,
make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
{i_n, 0});
// allocate LDS // allocate LDS
__shared__ char smem_ptr[GetSmemSize()]; __shared__ char smem_ptr[GetSmemSize()];
const index_t num_loop = TilePartitioner::GetLoopNum(kargs.K); this->RunGemm(
a_ptr, b_ptr, c_ptr, smem_ptr, kargs.group_karg, splitk_batch_offset, i_m, i_n);
// Run GEMM cooperatively by whole wokrgroup.
auto c_block_tile =
GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr);
CDataType* c_start = static_cast<CDataType*>(kargs.c_ptr);
auto c_tensor_view = [&]() {
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
c_start,
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.stride_C, 1),
number<GemmPipeline::VectorSizeC>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
c_start,
make_tuple(kargs.M, kargs.N),
make_tuple(1, kargs.stride_C),
number<1>{},
number<1>{});
}
}();
auto c_pad_view = [&]() {
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(c_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<false, GemmPipeline::kPadN>{});
}
else
{
return pad_tensor_view(c_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<GemmPipeline::kPadM, false>{});
}
}();
auto CBlockWindow_pad = make_tile_window(
c_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});
EpiloguePipeline{}(CBlockWindow_pad, c_block_tile);
} }
CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
int group_count) const index_t group_count) const
{ {
const index_t block_id = ck_tile::get_block_1d_id(); const index_t block_id = ck_tile::get_block_1d_id();
const auto gemm_desc_ptr = reinterpret_cast<const GemmTransKernelArg*>( const auto gemm_desc_ptr = reinterpret_cast<const GemmTransKernelArg*>(
...@@ -286,7 +167,7 @@ struct GroupedGemmKernel ...@@ -286,7 +167,7 @@ struct GroupedGemmKernel
index_t left = 0; index_t left = 0;
index_t right = group_count; index_t right = group_count;
index_t group_id = index_t((left + right) / 2); index_t group_id = index_t((left + right) >> 1);
while((!(block_id >= gemm_desc_ptr[group_id].block_start && while((!(block_id >= gemm_desc_ptr[group_id].block_start &&
block_id < gemm_desc_ptr[group_id].block_end)) && block_id < gemm_desc_ptr[group_id].block_end)) &&
...@@ -300,10 +181,10 @@ struct GroupedGemmKernel ...@@ -300,10 +181,10 @@ struct GroupedGemmKernel
{ {
left = group_id; left = group_id;
} }
group_id = index_t((left + right) / 2); group_id = index_t((left + right) >> 1);
} }
Run(gemm_desc_ptr[group_id].group_karg, gemm_desc_ptr[group_id].block_start); Run(gemm_desc_ptr[group_id]);
} }
}; };
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -14,7 +14,7 @@ struct Layernorm2dFwdHostArgs ...@@ -14,7 +14,7 @@ struct Layernorm2dFwdHostArgs
{ {
const void* p_x; // [m ,n], input, fp16/bf16 const void* p_x; // [m ,n], input, fp16/bf16
const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used
const void* p_x_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used const void* p_sm_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used
const void* p_x_bias; // [1, n], bias, prec same as input const void* p_x_bias; // [1, n], bias, prec same as input
const void* p_gamma; // [1, n], gamma, prec same as input const void* p_gamma; // [1, n], gamma, prec same as input
const void* p_beta; // [1, n], beta, prec same as input const void* p_beta; // [1, n], beta, prec same as input
...@@ -43,16 +43,16 @@ struct Layernorm2dFwd ...@@ -43,16 +43,16 @@ struct Layernorm2dFwd
using Epilogue = remove_cvref_t<Epilogue_>; using Epilogue = remove_cvref_t<Epilogue_>;
using Problem = typename Pipeline::Problem; using Problem = typename Pipeline::Problem;
using XDataType = remove_cvref_t<typename Problem::XDataType>; using XDataType = remove_cvref_t<typename Problem::XDataType>;
using XBiasDataType = remove_cvref_t<typename Problem::XBiasDataType>; using XBiasDataType = remove_cvref_t<typename Problem::XBiasDataType>;
using GammaDataType = remove_cvref_t<typename Problem::GammaDataType>; using GammaDataType = remove_cvref_t<typename Problem::GammaDataType>;
using BetaDataType = remove_cvref_t<typename Problem::BetaDataType>; using BetaDataType = remove_cvref_t<typename Problem::BetaDataType>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>; using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using YDataType = remove_cvref_t<typename Problem::YDataType>; using YDataType = remove_cvref_t<typename Problem::YDataType>;
using MeanDataType = remove_cvref_t<typename Problem::MeanDataType>; using MeanDataType = remove_cvref_t<typename Problem::MeanDataType>;
using InvStdDataType = remove_cvref_t<typename Problem::InvStdDataType>; using InvStdDataType = remove_cvref_t<typename Problem::InvStdDataType>;
using XScaleDataType = remove_cvref_t<typename Problem::XScaleDataType>; using SmoothScaleDataType = remove_cvref_t<typename Problem::SmoothScaleDataType>;
using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>; using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>;
// for simplicity, shortcut input/output type is same as X // for simplicity, shortcut input/output type is same as X
using XResidualDataType = XDataType; using XResidualDataType = XDataType;
...@@ -84,7 +84,7 @@ struct Layernorm2dFwd ...@@ -84,7 +84,7 @@ struct Layernorm2dFwd
{ {
const void* p_x; // [m ,n], input, fp16/bf16 const void* p_x; // [m ,n], input, fp16/bf16
const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used
const void* p_x_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used const void* p_sm_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used
const void* p_x_bias; // [1, n], bias, prec same as input const void* p_x_bias; // [1, n], bias, prec same as input
const void* p_gamma; // [1, n], gamma, prec same as input const void* p_gamma; // [1, n], gamma, prec same as input
const void* p_beta; // [1, n], beta, prec same as input const void* p_beta; // [1, n], beta, prec same as input
...@@ -111,7 +111,7 @@ struct Layernorm2dFwd ...@@ -111,7 +111,7 @@ struct Layernorm2dFwd
{ {
return Kargs{hargs.p_x, return Kargs{hargs.p_x,
hargs.p_x_residual, hargs.p_x_residual,
hargs.p_x_scale, hargs.p_sm_scale,
hargs.p_x_bias, hargs.p_x_bias,
hargs.p_gamma, hargs.p_gamma,
hargs.p_beta, hargs.p_beta,
...@@ -171,7 +171,7 @@ struct Layernorm2dFwd ...@@ -171,7 +171,7 @@ struct Layernorm2dFwd
base_str += _SS_("_") + _SS_(t2s<YDataType>::name); base_str += _SS_("_") + _SS_(t2s<YDataType>::name);
} }
if (kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT) { if (kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT) {
base_str += _SS_("_sx") + _SS_(t2s<XScaleDataType>::name); base_str += _SS_("_sx") + _SS_(t2s<SmoothScaleDataType>::name);
base_str += _SS_("_sy") + _SS_(t2s<YScaleDataType>::name); base_str += _SS_("_sy") + _SS_(t2s<YScaleDataType>::name);
} }
if (kFusedQuant == Layernorm2dFusedQuantEnum::DYNAMIC_QUANT) { if (kFusedQuant == Layernorm2dFusedQuantEnum::DYNAMIC_QUANT) {
...@@ -356,18 +356,18 @@ struct Layernorm2dFwd ...@@ -356,18 +356,18 @@ struct Layernorm2dFwd
return make_null_tile_window(make_tuple(number<Block_M>{})); return make_null_tile_window(make_tuple(number<Block_M>{}));
}(); }();
auto x_scale_window = [&]() { auto sm_scale_window = [&]() {
if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT) if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
{ {
const auto win_ = [&]() { const auto win_ = [&]() {
const auto tmp_0_ = make_naive_tensor_view_packed<address_space_enum::global>( const auto tmp_0_ = make_naive_tensor_view_packed<address_space_enum::global>(
static_cast<const XScaleDataType*>(kargs.p_x_scale), static_cast<const SmoothScaleDataType*>(kargs.p_sm_scale),
make_tuple(kargs.n), make_tuple(kargs.n),
number<Vector_N>{}); number<Vector_N>{});
return pad_tensor_view(tmp_0_, return pad_tensor_view(tmp_0_,
make_tuple(number<Block_N>{}), make_tuple(number<Block_N>{}),
sequence<false>{}); // x_scale no need pad sequence<false>{}); // sm_scale no need pad
}(); }();
return make_tile_window(win_, make_tuple(number<Block_N>{}), {0}); return make_tile_window(win_, make_tuple(number<Block_N>{}), {0});
} }
...@@ -405,7 +405,7 @@ struct Layernorm2dFwd ...@@ -405,7 +405,7 @@ struct Layernorm2dFwd
y_residual_window, y_residual_window,
mean_window, mean_window,
inv_std_window, inv_std_window,
x_scale_window, sm_scale_window,
y_scale_window, y_scale_window,
static_cast<const ComputeDataType>(kargs.epsilon), static_cast<const ComputeDataType>(kargs.epsilon),
kargs.n, kargs.n,
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -64,7 +64,7 @@ struct Layernorm2dFwdPipelineOnePass ...@@ -64,7 +64,7 @@ struct Layernorm2dFwdPipelineOnePass
typename YResidualWindow, typename YResidualWindow,
typename MeanWindow, typename MeanWindow,
typename InvStdWindow, typename InvStdWindow,
typename XScaleWindow, typename SmoothScaleWindow,
typename YScaleWindow, typename YScaleWindow,
typename Epilogue> typename Epilogue>
CK_TILE_DEVICE auto operator()(const XWindow& x_window_, CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
...@@ -76,7 +76,7 @@ struct Layernorm2dFwdPipelineOnePass ...@@ -76,7 +76,7 @@ struct Layernorm2dFwdPipelineOnePass
const YResidualWindow& y_residual_window_, const YResidualWindow& y_residual_window_,
MeanWindow& mean_window, MeanWindow& mean_window,
InvStdWindow& inv_std_window, InvStdWindow& inv_std_window,
const XScaleWindow& x_scale_window_, const SmoothScaleWindow& sm_scale_window_,
YScaleWindow& y_scale_window, YScaleWindow& y_scale_window,
ComputeDataType epsilon, ComputeDataType epsilon,
ck_tile::index_t row_size, ck_tile::index_t row_size,
...@@ -190,7 +190,7 @@ struct Layernorm2dFwdPipelineOnePass ...@@ -190,7 +190,7 @@ struct Layernorm2dFwdPipelineOnePass
if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::DYNAMIC_QUANT || if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::DYNAMIC_QUANT ||
kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT) kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
{ {
Epilogue{}(y_window_, x_scale_window_, y_scale_window, ln, smem); Epilogue{}(y_window_, sm_scale_window_, y_scale_window, ln, smem);
} }
else else
Epilogue{}(y_window_, ln); Epilogue{}(y_window_, ln);
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -15,23 +15,23 @@ template <typename XDataType_, ...@@ -15,23 +15,23 @@ template <typename XDataType_,
typename YDataType_, typename YDataType_,
typename MeanDataType_, typename MeanDataType_,
typename InvStdDataType_, typename InvStdDataType_,
typename XScaleDataType_, typename SmoothScaleDataType_,
typename YScaleDataType_, typename YScaleDataType_,
typename BlockShape_, typename BlockShape_,
typename Traits_> typename Traits_>
struct Layernorm2dFwdPipelineProblem struct Layernorm2dFwdPipelineProblem
{ {
using XDataType = remove_cvref_t<XDataType_>; using XDataType = remove_cvref_t<XDataType_>;
using XBiasDataType = remove_cvref_t<XBiasDataType_>; using XBiasDataType = remove_cvref_t<XBiasDataType_>;
using GammaDataType = remove_cvref_t<GammaDataType_>; using GammaDataType = remove_cvref_t<GammaDataType_>;
using BetaDataType = remove_cvref_t<BetaDataType_>; using BetaDataType = remove_cvref_t<BetaDataType_>;
using ComputeDataType = remove_cvref_t<ComputeDataType_>; using ComputeDataType = remove_cvref_t<ComputeDataType_>;
using YDataType = remove_cvref_t<YDataType_>; using YDataType = remove_cvref_t<YDataType_>;
using MeanDataType = remove_cvref_t<MeanDataType_>; using MeanDataType = remove_cvref_t<MeanDataType_>;
using InvStdDataType = remove_cvref_t<InvStdDataType_>; using InvStdDataType = remove_cvref_t<InvStdDataType_>;
using XScaleDataType = remove_cvref_t<XScaleDataType_>; using SmoothScaleDataType = remove_cvref_t<SmoothScaleDataType_>;
using YScaleDataType = remove_cvref_t<YScaleDataType_>; using YScaleDataType = remove_cvref_t<YScaleDataType_>;
using BlockShape = remove_cvref_t<BlockShape_>; using BlockShape = remove_cvref_t<BlockShape_>;
static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1; static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1; static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -63,7 +63,7 @@ struct Layernorm2dFwdPipelineTwoPass ...@@ -63,7 +63,7 @@ struct Layernorm2dFwdPipelineTwoPass
typename YResidualWindow, typename YResidualWindow,
typename MeanWindow, typename MeanWindow,
typename InvStdWindow, typename InvStdWindow,
typename XScaleWindow, typename SmoothScaleWindow,
typename YScaleWindow, typename YScaleWindow,
typename Epilogue> typename Epilogue>
CK_TILE_DEVICE auto operator()(const XWindow& x_window_, CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
...@@ -75,7 +75,7 @@ struct Layernorm2dFwdPipelineTwoPass ...@@ -75,7 +75,7 @@ struct Layernorm2dFwdPipelineTwoPass
const YResidualWindow& y_residual_window_, const YResidualWindow& y_residual_window_,
MeanWindow& mean_window, MeanWindow& mean_window,
InvStdWindow& inv_std_window, InvStdWindow& inv_std_window,
const XScaleWindow& /*x_scale_window*/, const SmoothScaleWindow& /*sm_scale_window*/,
YScaleWindow& /*y_scale_window*/, YScaleWindow& /*y_scale_window*/,
ComputeDataType epsilon, ComputeDataType epsilon,
ck_tile::index_t row_size, ck_tile::index_t row_size,
......
...@@ -8,5 +8,6 @@ ...@@ -8,5 +8,6 @@
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp" #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp" #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp" #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/tensor_layout.hpp"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp" #include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp"
namespace ck_tile { namespace ck_tile {
// host side args // host side args
struct Rmsnorm2dFwdHostArgs struct Rmsnorm2dFwdHostArgs
{ {
const void* p_x; // [m ,n], input, fp16/bf16 const void* p_x; // [m ,n], input, fp16/bf16
const void* p_gamma; // [1, n], gamma, prec same as input const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used
const void* p_sm_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used
const void* p_gamma; // [1, n], gamma, prec same as input
void* p_y; // [m, n], output, fp16/bf16 void* p_y; // [m, n], output, fp16/bf16
void* p_invRms; // [m, 1], output inv-rms, prec same as input, nullptr if not used void* p_y_residual; // [m, n], shortcut output, prec same as input, nullptr if not used
void* p_y_scale; // [m, 1], output a dynamic quant per row, nullptr if not used
void* p_invRms; // [m, 1], output inv-rms, prec same as input, nullptr if not used
float epsilon; float epsilon;
index_t m; index_t m;
index_t n; index_t n;
index_t stride; // row_stride index_t x_stride; // x row_stride
index_t xr_stride; // x residule row stride
index_t y_stride; // y row stride
index_t yr_stride; // y residule row stride
}; };
// TODO: Extract some type to wrapper class // TODO: Extract some type to wrapper class
template <typename Pipeline_> template <typename Pipeline_, typename Epilogue_>
struct Rmsnorm2dFwd struct Rmsnorm2dFwd
{ {
using Pipeline = remove_cvref_t<Pipeline_>; using Pipeline = remove_cvref_t<Pipeline_>;
using Epilogue = remove_cvref_t<Epilogue_>;
using Problem = typename Pipeline::Problem; using Problem = typename Pipeline::Problem;
using XDataType = remove_cvref_t<typename Problem::XDataType>; using XDataType = remove_cvref_t<typename Problem::XDataType>;
using GammaDataType = remove_cvref_t<typename Problem::GammaDataType>; using GammaDataType = remove_cvref_t<typename Problem::GammaDataType>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>; using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using YDataType = remove_cvref_t<typename Problem::YDataType>; using YDataType = remove_cvref_t<typename Problem::YDataType>;
using InvRmsDataType = remove_cvref_t<typename Problem::InvRmsDataType>; using InvRmsDataType = remove_cvref_t<typename Problem::InvRmsDataType>;
using SmoothScaleDataType = remove_cvref_t<typename Problem::SmoothScaleDataType>;
using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>;
// for simplicity, shortcut input/output type is same as X
using XResidualDataType = XDataType;
using YResidualDataType = XDataType;
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, null_type>; static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, null_type>;
static constexpr bool kSaveInvRms = Problem::kSaveInvRms; static constexpr bool kSaveInvRms = Problem::Traits::kSaveInvRms;
static constexpr index_t Block_M = Problem::BlockShape::Block_M; static constexpr index_t Block_M = Problem::BlockShape::Block_M;
static constexpr index_t Block_N = Problem::BlockShape::Block_N; static constexpr index_t Block_N = Problem::BlockShape::Block_N;
static constexpr bool kPadM = false; // always no need to pad along M static constexpr bool kPadM = false; // always no need to pad along M
static constexpr bool kPadN = Problem::kPadN; static constexpr bool kPadN = Problem::Traits::kPadN;
static constexpr bool kTwoPass = Problem::kTwoPass; static constexpr bool kTwoPass = Problem::Traits::kTwoPass;
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N; static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
static constexpr index_t Vector_N = Problem::BlockShape::Vector_N; static constexpr index_t Vector_N = Problem::BlockShape::Vector_N;
...@@ -56,29 +73,43 @@ struct Rmsnorm2dFwd ...@@ -56,29 +73,43 @@ struct Rmsnorm2dFwd
struct Kargs struct Kargs
{ {
const void* p_x; const void* p_x;
const void* p_x_residual;
const void* p_sm_scale;
const void* p_gamma; const void* p_gamma;
void* p_y; void* p_y;
void* p_y_residual;
void* p_y_scale;
void* p_invRms; void* p_invRms;
float epsilon; float epsilon;
index_t m; index_t m;
index_t n; index_t n;
index_t stride; // row_stride index_t x_stride; // x row_stride
index_t xr_stride; // x residule row stride
index_t y_stride; // y row stride
index_t yr_stride; // y residule row stride
}; };
using Hargs = Rmsnorm2dFwdHostArgs; using Hargs = Rmsnorm2dFwdHostArgs;
CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs) CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
{ {
return Kargs{hargs.p_x, return Kargs{hargs.p_x,
hargs.p_x_residual,
hargs.p_sm_scale,
hargs.p_gamma, hargs.p_gamma,
hargs.p_y, hargs.p_y,
hargs.p_y_residual,
hargs.p_y_scale,
hargs.p_invRms, hargs.p_invRms,
hargs.epsilon, hargs.epsilon,
hargs.m, hargs.m,
hargs.n, hargs.n,
hargs.stride}; hargs.x_stride,
hargs.xr_stride,
hargs.y_stride,
hargs.yr_stride};
} }
CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs) CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
...@@ -95,6 +126,7 @@ struct Rmsnorm2dFwd ...@@ -95,6 +126,7 @@ struct Rmsnorm2dFwd
template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; }; template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; }; template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; }; template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
template <> struct t2s<ck_tile::int8_t> { static constexpr const char * name = "int8"; };
// clang-format on // clang-format on
// in byte // in byte
...@@ -102,24 +134,41 @@ struct Rmsnorm2dFwd ...@@ -102,24 +134,41 @@ struct Rmsnorm2dFwd
CK_TILE_HOST static std::string GetName() CK_TILE_HOST static std::string GetName()
{ {
#define _SS_ std::string
#define _TS_ std::to_string
// clang-format off // clang-format off
using S_ = typename Problem::BlockShape; using S_ = typename Problem::BlockShape;
auto surfix = [&] () { auto surfix = [&] () {
std::string n; std::string n;
if (kFusedAdd != Rmsnorm2dFusedAddEnum::NO_ADD) n += _SS_("_") + Rmsnorm2dFusedAddEnumName<kFusedAdd>::name;
if (kFusedQuant != Rmsnorm2dFusedQuantEnum::NO_SWEEP) n += _SS_("_") + Rmsnorm2dFusedQuantEnumName<kFusedQuant>::name;
if (kPadN) n += "_pn"; if (kPadN) n += "_pn";
if (kSaveInvRms) n += "_rms"; if (kSaveInvRms) n += "_rms";
if (kTwoPass) n += "_2p"; if (kTwoPass) n += "_2p";
return n; }(); return n; }();
#define _SS_ std::string auto prec_str = [&] () {
#define _TS_ std::to_string std::string base_str = _SS_(t2s<XDataType>::name);
return _SS_("rmsnorm2d_fwd_") + _SS_(t2s<XDataType>::name) + "_" + if (!std::is_same_v<XDataType, YDataType>) {
base_str += _SS_("_") + _SS_(t2s<YDataType>::name);
}
if (kFusedQuant == Rmsnorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT) {
base_str += _SS_("_sx") + _SS_(t2s<SmoothScaleDataType>::name);
base_str += _SS_("_sy") + _SS_(t2s<YScaleDataType>::name);
}
if (kFusedQuant == Rmsnorm2dFusedQuantEnum::DYNAMIC_QUANT) {
base_str += _SS_("_sy") + _SS_(t2s<YScaleDataType>::name);
}
return base_str;
}();
return _SS_("rmsnorm2d_fwd_") + _SS_(prec_str) + "_" +
_TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" + _TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" +
_TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" + _TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" +
_SS_(Pipeline::name) + surfix; _SS_(Pipeline::name) + surfix;
#undef _SS_
#undef _TS_
// clang-format on // clang-format on
#undef _SS_
#undef _TS_
} }
CK_TILE_DEVICE void operator()(Kargs kargs) const CK_TILE_DEVICE void operator()(Kargs kargs) const
...@@ -130,7 +179,7 @@ struct Rmsnorm2dFwd ...@@ -130,7 +179,7 @@ struct Rmsnorm2dFwd
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>( const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const XDataType*>(kargs.p_x), static_cast<const XDataType*>(kargs.p_x),
make_tuple(kargs.m, kargs.n), make_tuple(kargs.m, kargs.n),
make_tuple(kargs.stride, 1), make_tuple(kargs.x_stride, 1),
number<Vector_N>{}, number<Vector_N>{},
number<1>{}); number<1>{});
...@@ -140,6 +189,29 @@ struct Rmsnorm2dFwd ...@@ -140,6 +189,29 @@ struct Rmsnorm2dFwd
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0}); tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
}(); }();
const auto x_residual_window = [&]() {
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD ||
kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
{
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const XResidualDataType*>(kargs.p_x_residual),
make_tuple(kargs.m, kargs.n),
make_tuple(kargs.xr_stride, 1),
number<Vector_N>{},
number<1>{});
const auto tmp2_ = pad_tensor_view(tmp_,
make_tuple(number<Block_M>{}, number<Block_N>{}),
sequence<kPadM, kPadN>{});
return make_tile_window(
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
}
else
{
return make_null_tile_window(make_tuple(number<Block_M>{}, number<Block_N>{}));
}
}();
const auto gamma_window = [&]() { const auto gamma_window = [&]() {
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>( const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const GammaDataType*>(kargs.p_gamma), static_cast<const GammaDataType*>(kargs.p_gamma),
...@@ -158,7 +230,7 @@ struct Rmsnorm2dFwd ...@@ -158,7 +230,7 @@ struct Rmsnorm2dFwd
auto tmp_ = make_naive_tensor_view<address_space_enum::global>( auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<YDataType*>(kargs.p_y), static_cast<YDataType*>(kargs.p_y),
make_tuple(kargs.m, kargs.n), make_tuple(kargs.m, kargs.n),
make_tuple(kargs.stride, 1), make_tuple(kargs.y_stride, 1),
number<Vector_N>{}, number<Vector_N>{},
number<1>{}); number<1>{});
...@@ -168,6 +240,28 @@ struct Rmsnorm2dFwd ...@@ -168,6 +240,28 @@ struct Rmsnorm2dFwd
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0}); tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
}(); }();
auto y_residual_window = [&]() {
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
{
auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<YResidualDataType*>(kargs.p_y_residual),
make_tuple(kargs.m, kargs.n),
make_tuple(kargs.yr_stride, 1),
number<Vector_N>{},
number<1>{});
auto tmp2_ = pad_tensor_view(tmp_,
make_tuple(number<Block_M>{}, number<Block_N>{}),
sequence<kPadM, kPadN>{});
return make_tile_window(
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
}
else
{
return make_null_tile_window(make_tuple(number<Block_M>{}, number<Block_N>{}));
}
}();
auto inv_rms_window = [&]() { auto inv_rms_window = [&]() {
if constexpr(kSaveInvRms) if constexpr(kSaveInvRms)
{ {
...@@ -187,15 +281,62 @@ struct Rmsnorm2dFwd ...@@ -187,15 +281,62 @@ struct Rmsnorm2dFwd
return make_null_tile_window(make_tuple(number<Block_M>{})); return make_null_tile_window(make_tuple(number<Block_M>{}));
}(); }();
auto sm_scale_window = [&]() {
if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
{
const auto win_ = [&]() {
const auto tmp_0_ = make_naive_tensor_view_packed<address_space_enum::global>(
static_cast<const SmoothScaleDataType*>(kargs.p_sm_scale),
make_tuple(kargs.n),
number<Vector_N>{});
return pad_tensor_view(tmp_0_,
make_tuple(number<Block_N>{}),
sequence<false>{}); // sm_scale no need pad
}();
return make_tile_window(win_, make_tuple(number<Block_N>{}), {0});
}
else
{
return make_null_tile_window(make_tuple(number<Block_N>{}));
}
}();
auto y_scale_window = [&]() {
if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT ||
kFusedQuant == Rmsnorm2dFusedQuantEnum::DYNAMIC_QUANT)
{
const auto win_ = [&]() {
const auto tmp_0_ = make_naive_tensor_view_packed<address_space_enum::global>(
static_cast<YScaleDataType*>(kargs.p_y_scale),
make_tuple(kargs.m),
number<1>{});
return pad_tensor_view(
tmp_0_, make_tuple(number<Block_M>{}), sequence<kPadM>{});
}();
return make_tile_window(win_, make_tuple(number<Block_M>{}), {iM});
}
else
{
return make_null_tile_window(make_tuple(number<Block_M>{}));
}
}();
__shared__ char smem[GetSmemSize()]; __shared__ char smem[GetSmemSize()];
Pipeline{}(x_window, Pipeline{}(x_window,
x_residual_window,
gamma_window, gamma_window,
y_window, y_window,
y_residual_window,
inv_rms_window, inv_rms_window,
sm_scale_window,
y_scale_window,
static_cast<const ComputeDataType>(kargs.epsilon), static_cast<const ComputeDataType>(kargs.epsilon),
kargs.n, kargs.n,
smem); smem,
Epilogue{});
} }
}; };
......
...@@ -45,7 +45,7 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy ...@@ -45,7 +45,7 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2d() CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2d()
{ {
using P_ = BlockReduce2dProblem<typename Problem::XDataType, using P_ = BlockReduce2dProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType, typename Problem::ComputeDataType,
typename Problem::BlockShape>; typename Problem::BlockShape>;
return BlockReduce2d<P_>{}; return BlockReduce2d<P_>{};
...@@ -54,7 +54,7 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy ...@@ -54,7 +54,7 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dSync() CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dSync()
{ {
using P_ = BlockReduce2dProblem<typename Problem::XDataType, using P_ = BlockReduce2dProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType, typename Problem::ComputeDataType,
typename Problem::BlockShape>; typename Problem::BlockShape>;
return BlockReduce2dSync<P_>{}; return BlockReduce2dSync<P_>{};
...@@ -63,7 +63,7 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy ...@@ -63,7 +63,7 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dCrossWarpSync() CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dCrossWarpSync()
{ {
using P_ = BlockReduce2dProblem<typename Problem::XDataType, using P_ = BlockReduce2dProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType, typename Problem::ComputeDataType,
typename Problem::BlockShape>; typename Problem::BlockShape>;
return BlockReduce2dCrossWarpSync<P_>{}; return BlockReduce2dCrossWarpSync<P_>{};
...@@ -74,13 +74,13 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy ...@@ -74,13 +74,13 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy
{ {
if constexpr(Problem::kNeedCrossWarpSync) if constexpr(Problem::kNeedCrossWarpSync)
{ {
using P_ = BlockReduce2dProblem<typename Problem::XDataType, using P_ = BlockReduce2dProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType, typename Problem::ComputeDataType,
typename Problem::BlockShape>; typename Problem::BlockShape>;
using block_reduce2d = BlockReduce2d<P_>; using block_reduce2d = BlockReduce2d<P_>;
using x_block_tile = using x_block_tile =
decltype(make_static_distributed_tensor<typename Problem::XDataType>( decltype(make_static_distributed_tensor<typename Problem::ComputeDataType>(
MakeXBlockTileDistribution<Problem>())); MakeXBlockTileDistribution<Problem>()));
using y_block_tile = decltype(block_reduce2d::template MakeYBlockTile<x_block_tile>()); using y_block_tile = decltype(block_reduce2d::template MakeYBlockTile<x_block_tile>());
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -22,12 +22,17 @@ struct Rmsnorm2dFwdPipelineOnePass ...@@ -22,12 +22,17 @@ struct Rmsnorm2dFwdPipelineOnePass
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>; using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
using InvRmsDataType = ck_tile::remove_cvref_t<typename Problem::InvRmsDataType>; using InvRmsDataType = ck_tile::remove_cvref_t<typename Problem::InvRmsDataType>;
using XResidualDataType = XDataType;
using YResidualDataType = XDataType;
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>; static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>;
static constexpr bool kSaveInvRms = Problem::kSaveInvRms; static constexpr bool kSaveInvRms = Problem::Traits::kSaveInvRms;
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync; static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
static constexpr bool kPadM = false; // TODO - BlockRmsnorm2dFwdProblem::kPadM static constexpr bool kPadM = false; // TODO - BlockRmsnorm2dFwdProblem::kPadM
static constexpr bool kPadN = Problem::kPadN; static constexpr bool kPadN = Problem::Traits::kPadN;
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
static constexpr const char* name = []() { static constexpr const char* name = []() {
if constexpr(kNeedCrossWarpSync) if constexpr(kNeedCrossWarpSync)
...@@ -41,19 +46,36 @@ struct Rmsnorm2dFwdPipelineOnePass ...@@ -41,19 +46,36 @@ struct Rmsnorm2dFwdPipelineOnePass
return Policy::template GetSmemSize<Problem>(); return Policy::template GetSmemSize<Problem>();
} }
template <typename XWindow, typename GammaWindow, typename YWindow, typename InvRmsWindow> template <typename XWindow,
typename XResidualWindow,
typename GammaWindow,
typename YWindow,
typename YResidualWindow,
typename InvRmsWindow,
typename SmoothScaleWindow,
typename YScaleWindow,
typename Epilogue>
CK_TILE_DEVICE auto operator()(const XWindow& x_window_, CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
const XResidualWindow& x_residual_window_,
const GammaWindow& gamma_window_, const GammaWindow& gamma_window_,
YWindow& y_window, YWindow& y_window_,
const YResidualWindow& y_residual_window_,
InvRmsWindow& inv_rms_window, InvRmsWindow& inv_rms_window,
const SmoothScaleWindow& sm_scale_window_,
YScaleWindow& y_scale_window_,
ComputeDataType epsilon, ComputeDataType epsilon,
ck_tile::index_t row_size, ck_tile::index_t row_size,
void* smem) const void* smem,
Epilogue) const
{ {
const auto x_window = const auto x_window =
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>()); make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
const auto gamma_window = make_tile_window( const auto gamma_window = make_tile_window(
gamma_window_, Policy::template MakeGammaBlockTileDistribution<Problem>()); gamma_window_, Policy::template MakeGammaBlockTileDistribution<Problem>());
const auto x_residual_window = make_tile_window(
x_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
auto y_residual_window = make_tile_window(
y_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
auto reduce_square_sum_func = ReduceOp::SquareAdd{}; auto reduce_square_sum_func = ReduceOp::SquareAdd{};
auto reduce_sum_func = ReduceOp::Add{}; auto reduce_sum_func = ReduceOp::Add{};
...@@ -62,13 +84,31 @@ struct Rmsnorm2dFwdPipelineOnePass ...@@ -62,13 +84,31 @@ struct Rmsnorm2dFwdPipelineOnePass
auto block_reduce2d_cross_warp_sync = auto block_reduce2d_cross_warp_sync =
Policy::template GetBlockReduce2dCrossWarpSync<Problem>(); Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
const auto x = load_tile(x_window); auto x = load_tile(x_window);
auto x_resi = load_tile(x_residual_window);
// load gamma (TODO: support no gamma?) // load gamma (TODO: support no gamma?)
const auto gamma = load_tile(gamma_window); const auto gamma = load_tile(gamma_window);
auto acc = cast_tile<ComputeDataType>(x);
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD ||
kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
{
sweep_tile(x_resi, [&](auto idx) {
// compute x = x_resi + x
acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
});
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
{
store_tile(y_residual_window, cast_tile<YResidualDataType>(acc));
}
}
// compute mean square each-thread->cross-lane->cross-warp // compute mean square each-thread->cross-lane->cross-warp
auto square_sum = block_reduce2d( auto square_sum = block_reduce2d(acc,
x, reduce_square_sum_func.GetIdentityValue<ComputeDataType>(), reduce_square_sum_func); reduce_square_sum_func.GetIdentityValue<ComputeDataType>(),
reduce_square_sum_func);
block_reduce2d_sync(square_sum, reduce_sum_func); block_reduce2d_sync(square_sum, reduce_sum_func);
block_reduce2d_cross_warp_sync(square_sum, smem, reduce_sum_func); block_reduce2d_cross_warp_sync(square_sum, smem, reduce_sum_func);
...@@ -83,19 +123,30 @@ struct Rmsnorm2dFwdPipelineOnePass ...@@ -83,19 +123,30 @@ struct Rmsnorm2dFwdPipelineOnePass
store_tile(inv_rms_window, cast_tile<InvRmsDataType>(inv_rms)); store_tile(inv_rms_window, cast_tile<InvRmsDataType>(inv_rms));
// rmsnorm computation // rmsnorm computation
auto y = make_static_distributed_tensor<YDataType>(x.get_tile_distribution()); auto rmsn = make_static_distributed_tensor<ComputeDataType>(x.get_tile_distribution());
sweep_tile(y, [&, inv_rms_ = inv_rms](auto idx) { sweep_tile(rmsn, [&, inv_rms_ = inv_rms](auto idx) {
constexpr auto i_idx = make_tuple(idx[number<0>{}]); constexpr auto i_idx = make_tuple(idx[number<0>{}]);
constexpr auto j_idx = make_tuple(idx[number<1>{}]); constexpr auto j_idx = make_tuple(idx[number<1>{}]);
const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]); const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
const auto x_ = type_convert<ComputeDataType>(x[idx]); auto rmsn_ = acc[idx] * inv_rms_[i_idx] * gamma_;
auto y_ = x_ * inv_rms_[i_idx] * gamma_;
y(idx) = type_convert<YDataType>(y_); rmsn(idx) = rmsn_;
}); });
store_tile(y_window, y);
if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
{
Epilogue{}(y_window_, sm_scale_window_, y_scale_window_, rmsn, smem);
}
else if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::DYNAMIC_QUANT)
{
Epilogue{}(y_window_, y_scale_window_, rmsn, smem);
}
else
{
Epilogue{}(y_window_, rmsn);
}
} }
}; };
} // namespace ck_tile } // namespace ck_tile
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -12,25 +12,25 @@ template <typename XDataType_, ...@@ -12,25 +12,25 @@ template <typename XDataType_,
typename ComputeDataType_, typename ComputeDataType_,
typename YDataType_, typename YDataType_,
typename InvRmsDataType_, typename InvRmsDataType_,
typename SmoothScaleDataType_,
typename YScaleDataType_,
typename BlockShape_, typename BlockShape_,
bool kPadN_, typename Traits_>
bool kSaveInvRms_,
bool kTwoPass_>
struct Rmsnorm2dFwdPipelineProblem struct Rmsnorm2dFwdPipelineProblem
{ {
using XDataType = remove_cvref_t<XDataType_>; using XDataType = remove_cvref_t<XDataType_>;
using GammaDataType = remove_cvref_t<GammaDataType_>; using GammaDataType = remove_cvref_t<GammaDataType_>;
using ComputeDataType = remove_cvref_t<ComputeDataType_>; using ComputeDataType = remove_cvref_t<ComputeDataType_>;
using YDataType = remove_cvref_t<YDataType_>; using YDataType = remove_cvref_t<YDataType_>;
using InvRmsDataType = remove_cvref_t<InvRmsDataType_>; using InvRmsDataType = remove_cvref_t<InvRmsDataType_>;
using BlockShape = remove_cvref_t<BlockShape_>; using SmoothScaleDataType = remove_cvref_t<SmoothScaleDataType_>;
using YScaleDataType = remove_cvref_t<YScaleDataType_>;
using BlockShape = remove_cvref_t<BlockShape_>;
static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1; static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1; static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;
static constexpr bool kPadN = kPadN_; using Traits = remove_cvref_t<Traits_>;
static constexpr bool kSaveInvRms = kSaveInvRms_;
static constexpr bool kTwoPass = kTwoPass_;
}; };
} // namespace ck_tile } // namespace ck_tile
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -22,12 +22,17 @@ struct Rmsnorm2dFwdPipelineTwoPass ...@@ -22,12 +22,17 @@ struct Rmsnorm2dFwdPipelineTwoPass
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>; using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
using InvRmsDataType = ck_tile::remove_cvref_t<typename Problem::InvRmsDataType>; using InvRmsDataType = ck_tile::remove_cvref_t<typename Problem::InvRmsDataType>;
using XResidualDataType = XDataType;
using YResidualDataType = XDataType;
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>; static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>;
static constexpr bool kSaveInvRms = Problem::kSaveInvRms; static constexpr bool kSaveInvRms = Problem::Traits::kSaveInvRms;
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync; static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
static constexpr bool kPadM = false; // TODO - BlockRmsnorm2dFwdProblem::kPadM static constexpr bool kPadM = false; // TODO - BlockRmsnorm2dFwdProblem::kPadM
static constexpr bool kPadN = Problem::kPadN; static constexpr bool kPadN = Problem::Traits::kPadN;
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
static constexpr const char* name = []() { static constexpr const char* name = []() {
if constexpr(kNeedCrossWarpSync) if constexpr(kNeedCrossWarpSync)
...@@ -41,19 +46,36 @@ struct Rmsnorm2dFwdPipelineTwoPass ...@@ -41,19 +46,36 @@ struct Rmsnorm2dFwdPipelineTwoPass
return Policy::template GetSmemSize<Problem>(); return Policy::template GetSmemSize<Problem>();
} }
template <typename XWindow, typename GammaWindow, typename YWindow, typename InvRmsWindow> template <typename XWindow,
typename XResidualWindow,
typename GammaWindow,
typename YWindow,
typename YResidualWindow,
typename InvRmsWindow,
typename SmoothScaleWindow,
typename YScaleWindow,
typename Epilogue>
CK_TILE_DEVICE auto operator()(const XWindow& x_window_, CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
const XResidualWindow& x_residual_window_,
const GammaWindow& gamma_window_, const GammaWindow& gamma_window_,
YWindow& y_window, YWindow& y_window,
const YResidualWindow& y_residual_window_,
InvRmsWindow& inv_rms_window, InvRmsWindow& inv_rms_window,
const SmoothScaleWindow& /*sm_scale_window_*/,
YScaleWindow& /*y_scale_window*/,
ComputeDataType epsilon, ComputeDataType epsilon,
ck_tile::index_t row_size, ck_tile::index_t row_size,
void* smem) const void* smem,
Epilogue) const
{ {
auto x_window = auto x_window =
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>()); make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
auto gamma_window = make_tile_window( auto gamma_window = make_tile_window(
gamma_window_, Policy::template MakeGammaBlockTileDistribution<Problem>()); gamma_window_, Policy::template MakeGammaBlockTileDistribution<Problem>());
auto x_residual_window = make_tile_window(
x_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
auto y_residual_window = make_tile_window(
y_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
// Problem::BlockShape // Problem::BlockShape
static constexpr index_t Block_N = Problem::BlockShape::Block_N; static constexpr index_t Block_N = Problem::BlockShape::Block_N;
...@@ -67,15 +89,34 @@ struct Rmsnorm2dFwdPipelineTwoPass ...@@ -67,15 +89,34 @@ struct Rmsnorm2dFwdPipelineTwoPass
auto block_reduce2d_cross_warp_sync = auto block_reduce2d_cross_warp_sync =
Policy::template GetBlockReduce2dCrossWarpSync<Problem>(); Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
using XTensorType = decltype(load_tile(x_window)); using ComputeTensorType = decltype(cast_tile<ComputeDataType>(load_tile(x_window)));
auto square_sum = block_reduce2d.template MakeYBlockTile<XTensorType>(); auto square_sum = block_reduce2d.template MakeYBlockTile<ComputeTensorType>();
set_tile(square_sum, reduce_square_sum_func.GetIdentityValue<ComputeDataType>()); set_tile(square_sum, reduce_square_sum_func.GetIdentityValue<ComputeDataType>());
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{ {
const auto x = load_tile(x_window); auto x = load_tile(x_window);
block_reduce2d(x, square_sum, reduce_square_sum_func); auto x_resi = load_tile(x_residual_window);
move_tile_window(x_window, {0, Block_N}); move_tile_window(x_window, {0, Block_N});
move_tile_window(x_residual_window, {0, Block_N});
auto acc = cast_tile<ComputeDataType>(x);
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD ||
kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
{
sweep_tile(x_resi, [&](auto idx) {
// compute x = x_resi + x
acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
});
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
{
store_tile(y_residual_window, cast_tile<YResidualDataType>(acc));
move_tile_window(y_residual_window, {0, Block_N});
}
}
block_reduce2d(acc, square_sum, reduce_square_sum_func);
} }
block_reduce2d_sync(square_sum, reduce_sum_func); block_reduce2d_sync(square_sum, reduce_sum_func);
...@@ -96,33 +137,47 @@ struct Rmsnorm2dFwdPipelineTwoPass ...@@ -96,33 +137,47 @@ struct Rmsnorm2dFwdPipelineTwoPass
row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N; row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N;
move_tile_window(x_window, {0, -Block_N}); move_tile_window(x_window, {0, -Block_N});
move_tile_window(x_residual_window, {0, -Block_N});
move_tile_window(gamma_window, {stride_to_right_most_window}); move_tile_window(gamma_window, {stride_to_right_most_window});
move_tile_window(y_window, {0, stride_to_right_most_window}); move_tile_window(y_window, {0, stride_to_right_most_window});
// rmsnorm computation // rmsnorm computation
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{ {
const auto x = load_tile(x_window); auto x = load_tile(x_window);
// load gamma/beta (TODO: support no gamma/beta?) auto x_resi = load_tile(x_residual_window);
auto acc = cast_tile<ComputeDataType>(x);
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE ||
kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD)
{
sweep_tile(x_resi, [&](auto idx) {
// compute x = x_resi + x
acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
});
}
// load gamma (TODO: support no gamma?)
const auto gamma = load_tile(gamma_window); const auto gamma = load_tile(gamma_window);
auto y = make_static_distributed_tensor<YDataType>(x.get_tile_distribution()); // rmsnorm computation
auto rmsn = make_static_distributed_tensor<ComputeDataType>(x.get_tile_distribution());
sweep_tile(y, [&, inv_rms_ = inv_rms](auto idx) { sweep_tile(rmsn, [&, inv_rms_ = inv_rms](auto idx) {
constexpr auto i_idx = make_tuple(idx[number<0>{}]); constexpr auto i_idx = make_tuple(idx[number<0>{}]);
constexpr auto j_idx = make_tuple(idx[number<1>{}]); constexpr auto j_idx = make_tuple(idx[number<1>{}]);
const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]); const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
const auto x_ = type_convert<ComputeDataType>(x[idx]); auto rmsn_ = acc(idx) * inv_rms_[i_idx] * gamma_;
auto y_ = x_ * inv_rms_[i_idx] * gamma_;
y(idx) = type_convert<YDataType>(y_); rmsn(idx) = rmsn_;
}); });
store_tile(y_window, y); static_assert(kFusedQuant == Rmsnorm2dFusedQuantEnum::NO_SWEEP);
Epilogue{}(y_window, rmsn);
move_tile_window(x_window, {0, -Block_N}); move_tile_window(x_window, {0, -Block_N});
move_tile_window(x_residual_window, {0, -Block_N});
move_tile_window(gamma_window, {-Block_N}); move_tile_window(gamma_window, {-Block_N});
move_tile_window(y_window, {0, -Block_N}); move_tile_window(y_window, {0, -Block_N});
} }
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
enum class Rmsnorm2dFusedAddEnum
{
NO_ADD = 0,
// fused add before RMSNorm and store result to global
PRE_ADD_STORE = 1,
// fused add before RMSNorm, but not store result
PRE_ADD = 2,
};
// clang-format off
template<Rmsnorm2dFusedAddEnum> struct Rmsnorm2dFusedAddEnumName;
template<> struct Rmsnorm2dFusedAddEnumName<Rmsnorm2dFusedAddEnum::NO_ADD> { static constexpr const char * name = "no"; };
template<> struct Rmsnorm2dFusedAddEnumName<Rmsnorm2dFusedAddEnum::PRE_ADD_STORE> { static constexpr const char * name = "pras"; };
template<> struct Rmsnorm2dFusedAddEnumName<Rmsnorm2dFusedAddEnum::PRE_ADD> { static constexpr const char * name = "pra"; };
// clang-format on
enum class Rmsnorm2dFusedQuantEnum
{
NO_SWEEP = 0,
SMOOTH_DYNAMIC_QUANT = 1, // smooth oulier + rowwise quant, need input x-scale and store y_scale
DYNAMIC_QUANT = 2, // rowwise quant, store out a y-scale
};
// clang-format off
template<Rmsnorm2dFusedQuantEnum> struct Rmsnorm2dFusedQuantEnumName;
template<> struct Rmsnorm2dFusedQuantEnumName<Rmsnorm2dFusedQuantEnum::NO_SWEEP> { static constexpr const char * name = "no"; };
template<> struct Rmsnorm2dFusedQuantEnumName<Rmsnorm2dFusedQuantEnum::DYNAMIC_QUANT> { static constexpr const char * name = "dqt"; };
template<> struct Rmsnorm2dFusedQuantEnumName<Rmsnorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT> { static constexpr const char * name = "smdqt"; };
// clang-format on
template <bool kPadN_,
bool kSaveInvRms_,
bool kTwoPass_,
Rmsnorm2dFusedAddEnum kFusedAdd_,
Rmsnorm2dFusedQuantEnum kFusedQuant_>
struct Rmsnorm2dFwdTraits
{
static constexpr bool kPadN = kPadN_;
static constexpr bool kSaveInvRms = kSaveInvRms_;
static constexpr bool kTwoPass = kTwoPass_;
static constexpr Rmsnorm2dFusedAddEnum kFusedAdd = kFusedAdd_;
static constexpr Rmsnorm2dFusedQuantEnum kFusedQuant = kFusedQuant_;
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -12,7 +12,7 @@ namespace ck_tile { ...@@ -12,7 +12,7 @@ namespace ck_tile {
struct MoeSmoothquantHostArgs struct MoeSmoothquantHostArgs
{ {
const void* p_x; // [tokens ,hidden_size], input, fp16/bf16 const void* p_x; // [tokens ,hidden_size], input, fp16/bf16
const void* p_xscale; // [experts, hidden_size], input, columnwise scale, fp32 const void* p_smscale; // [experts, hidden_size], input, columnwise scale, fp32
const void* p_topk_ids; // [tokens, topk] const void* p_topk_ids; // [tokens, topk]
void* p_yscale; // [topk * tokens, 1], output, rowwise quant scale void* p_yscale; // [topk * tokens, 1], output, rowwise quant scale
...@@ -33,11 +33,11 @@ struct MoeSmoothquant ...@@ -33,11 +33,11 @@ struct MoeSmoothquant
using Pipeline = remove_cvref_t<Pipeline_>; using Pipeline = remove_cvref_t<Pipeline_>;
using Problem = typename Pipeline::Problem; using Problem = typename Pipeline::Problem;
using XDataType = remove_cvref_t<typename Problem::XDataType>; using XDataType = remove_cvref_t<typename Problem::XDataType>;
using XScaleDataType = remove_cvref_t<typename Problem::XScaleDataType>; using SmoothScaleDataType = remove_cvref_t<typename Problem::SmoothScaleDataType>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>; using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>; using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>;
using QYDataType = remove_cvref_t<typename Problem::QYDataType>; using QYDataType = remove_cvref_t<typename Problem::QYDataType>;
static constexpr index_t Block_M = Problem::BlockShape::Block_M; static constexpr index_t Block_M = Problem::BlockShape::Block_M;
static constexpr index_t Block_N = Problem::BlockShape::Block_N; static constexpr index_t Block_N = Problem::BlockShape::Block_N;
...@@ -57,7 +57,7 @@ struct MoeSmoothquant ...@@ -57,7 +57,7 @@ struct MoeSmoothquant
struct Kargs struct Kargs
{ {
const void* p_x; // [tokens ,hidden_size], input, fp16/bf16 const void* p_x; // [tokens ,hidden_size], input, fp16/bf16
const void* p_xscale; // [experts, hidden_size], input, columnwise scale, fp32 const void* p_smscale; // [experts, hidden_size], input, columnwise scale, fp32
const void* p_topk_ids; // [tokens, topk] const void* p_topk_ids; // [tokens, topk]
void* p_yscale; // [topk, tokens, 1], output, rowwise quant scale void* p_yscale; // [topk, tokens, 1], output, rowwise quant scale
...@@ -75,7 +75,7 @@ struct MoeSmoothquant ...@@ -75,7 +75,7 @@ struct MoeSmoothquant
CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs) CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
{ {
return Kargs{hargs.p_x, return Kargs{hargs.p_x,
hargs.p_xscale, hargs.p_smscale,
hargs.p_topk_ids, hargs.p_topk_ids,
hargs.p_yscale, hargs.p_yscale,
hargs.p_qy, hargs.p_qy,
...@@ -101,6 +101,7 @@ struct MoeSmoothquant ...@@ -101,6 +101,7 @@ struct MoeSmoothquant
template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; }; template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; }; template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; }; template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
template <> struct t2s<ck_tile::int8_t> { static constexpr const char * name = "i8"; };
// clang-format on // clang-format on
// in byte // in byte
...@@ -118,7 +119,7 @@ struct MoeSmoothquant ...@@ -118,7 +119,7 @@ struct MoeSmoothquant
#define _SS_ std::string #define _SS_ std::string
#define _TS_ std::to_string #define _TS_ std::to_string
return _SS_("moe_smoothquant_") + _SS_(t2s<XDataType>::name) + "_" + return _SS_("moe_smoothquant_") + _SS_(t2s<XDataType>::name) + "_" + _SS_(t2s<QYDataType>::name) + "_" +
_TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" + _TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" +
_TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" + _TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" +
_SS_(Pipeline::name) + surfix; _SS_(Pipeline::name) + surfix;
...@@ -153,9 +154,10 @@ struct MoeSmoothquant ...@@ -153,9 +154,10 @@ struct MoeSmoothquant
}(); }();
// [experts, hidden_size], // [experts, hidden_size],
const auto xscale_window = [&]() { const auto smscale_window = [&]() {
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>( const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const XScaleDataType*>(kargs.p_xscale) + i_expert * kargs.hidden_size, static_cast<const SmoothScaleDataType*>(kargs.p_smscale) +
i_expert * kargs.hidden_size,
make_tuple(kargs.hidden_size), make_tuple(kargs.hidden_size),
make_tuple(1), make_tuple(1),
number<Vector_N>{}, number<Vector_N>{},
...@@ -198,7 +200,7 @@ struct MoeSmoothquant ...@@ -198,7 +200,7 @@ struct MoeSmoothquant
__shared__ char smem[GetSmemSize()]; __shared__ char smem[GetSmemSize()];
Pipeline{}(x_window, xscale_window, yscale_window, qy_window, kargs.hidden_size, smem); Pipeline{}(x_window, smscale_window, yscale_window, qy_window, kargs.hidden_size, smem);
} }
}; };
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -11,11 +11,11 @@ namespace ck_tile { ...@@ -11,11 +11,11 @@ namespace ck_tile {
// host side args // host side args
struct SmoothquantHostArgs struct SmoothquantHostArgs
{ {
const void* p_x; // [m ,n], input, fp16/bf16 const void* p_x; // [m ,n], input, fp16/bf16
const void* p_xscale; // [1, n], input, columnwise scale, fp32 const void* p_smscale; // [1, n], input, columnwise scale, fp32
void* p_yscale; // [m, 1], output, rowwise quant scale (amax / 127) of (p_x * p_xscale) void* p_yscale; // [m, 1], output, rowwise quant scale (amax / 127) of (p_x * p_smscale)
void* p_qy; // [m, n], output, p_x * p_xscale / p_yscale void* p_qy; // [m, n], output, p_x * p_smscale / p_yscale
index_t m; index_t m;
index_t n; index_t n;
...@@ -30,11 +30,11 @@ struct Smoothquant ...@@ -30,11 +30,11 @@ struct Smoothquant
using Pipeline = remove_cvref_t<Pipeline_>; using Pipeline = remove_cvref_t<Pipeline_>;
using Problem = typename Pipeline::Problem; using Problem = typename Pipeline::Problem;
using XDataType = remove_cvref_t<typename Problem::XDataType>; using XDataType = remove_cvref_t<typename Problem::XDataType>;
using XScaleDataType = remove_cvref_t<typename Problem::XScaleDataType>; using SmoothScaleDataType = remove_cvref_t<typename Problem::SmoothScaleDataType>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>; using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>; using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>;
using QYDataType = remove_cvref_t<typename Problem::QYDataType>; using QYDataType = remove_cvref_t<typename Problem::QYDataType>;
static constexpr index_t Block_M = Problem::BlockShape::Block_M; static constexpr index_t Block_M = Problem::BlockShape::Block_M;
static constexpr index_t Block_N = Problem::BlockShape::Block_N; static constexpr index_t Block_N = Problem::BlockShape::Block_N;
...@@ -52,7 +52,7 @@ struct Smoothquant ...@@ -52,7 +52,7 @@ struct Smoothquant
struct Kargs struct Kargs
{ {
const void* p_x; const void* p_x;
const void* p_xscale; const void* p_smscale;
void* p_yscale; void* p_yscale;
void* p_qy; void* p_qy;
...@@ -67,7 +67,7 @@ struct Smoothquant ...@@ -67,7 +67,7 @@ struct Smoothquant
CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs) CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
{ {
return Kargs{hargs.p_x, return Kargs{hargs.p_x,
hargs.p_xscale, hargs.p_smscale,
hargs.p_yscale, hargs.p_yscale,
hargs.p_qy, hargs.p_qy,
hargs.m, hargs.m,
...@@ -134,9 +134,9 @@ struct Smoothquant ...@@ -134,9 +134,9 @@ struct Smoothquant
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0}); tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
}(); }();
const auto xscale_window = [&]() { const auto smscale_window = [&]() {
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>( const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const XScaleDataType*>(kargs.p_xscale), static_cast<const SmoothScaleDataType*>(kargs.p_smscale),
make_tuple(kargs.n), make_tuple(kargs.n),
make_tuple(1), make_tuple(1),
number<Vector_N>{}, number<Vector_N>{},
...@@ -177,7 +177,7 @@ struct Smoothquant ...@@ -177,7 +177,7 @@ struct Smoothquant
__shared__ char smem[GetSmemSize()]; __shared__ char smem[GetSmemSize()];
Pipeline{}(x_window, xscale_window, yscale_window, qy_window, kargs.n, smem); Pipeline{}(x_window, smscale_window, yscale_window, qy_window, kargs.n, smem);
} }
}; };
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -28,7 +28,7 @@ struct SmoothquantPipelineDefaultPolicy ...@@ -28,7 +28,7 @@ struct SmoothquantPipelineDefaultPolicy
} }
template <typename Problem> template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeXScaleBlockTileDistribution() CK_TILE_DEVICE static constexpr auto MakeSmoothScaleBlockTileDistribution()
{ {
using S = typename Problem::BlockShape; using S = typename Problem::BlockShape;
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -16,11 +16,11 @@ struct SmoothquantPipelineOnePass ...@@ -16,11 +16,11 @@ struct SmoothquantPipelineOnePass
using Problem = ck_tile::remove_cvref_t<Problem_>; using Problem = ck_tile::remove_cvref_t<Problem_>;
using Policy = ck_tile::remove_cvref_t<Policy_>; using Policy = ck_tile::remove_cvref_t<Policy_>;
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>; using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
using XScaleDataType = ck_tile::remove_cvref_t<typename Problem::XScaleDataType>; using SmoothScaleDataType = ck_tile::remove_cvref_t<typename Problem::SmoothScaleDataType>;
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>; using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
using QYDataType = ck_tile::remove_cvref_t<typename Problem::QYDataType>; using QYDataType = ck_tile::remove_cvref_t<typename Problem::QYDataType>;
using YScaleDataType = ck_tile::remove_cvref_t<typename Problem::YScaleDataType>; using YScaleDataType = ck_tile::remove_cvref_t<typename Problem::YScaleDataType>;
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync; static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
static constexpr bool kPadM = false; // TODO - BlockSmoothquantProblem::kPadM static constexpr bool kPadM = false; // TODO - BlockSmoothquantProblem::kPadM
...@@ -39,9 +39,12 @@ struct SmoothquantPipelineOnePass ...@@ -39,9 +39,12 @@ struct SmoothquantPipelineOnePass
return Policy::template GetSmemSize<Problem>(); return Policy::template GetSmemSize<Problem>();
} }
template <typename XWindow, typename XScaleWindow, typename QYWindow, typename YScaleWindow> template <typename XWindow,
typename SmoothScaleWindow,
typename QYWindow,
typename YScaleWindow>
CK_TILE_DEVICE auto operator()(const XWindow& x_window_, CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
const XScaleWindow& xscale_window_, const SmoothScaleWindow& smscale_window_,
YScaleWindow& yscale_window, YScaleWindow& yscale_window,
QYWindow& qy_window, QYWindow& qy_window,
ck_tile::index_t, ck_tile::index_t,
...@@ -49,8 +52,8 @@ struct SmoothquantPipelineOnePass ...@@ -49,8 +52,8 @@ struct SmoothquantPipelineOnePass
{ {
auto x_window = auto x_window =
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>()); make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
auto xscale_window = make_tile_window( auto smscale_window = make_tile_window(
xscale_window_, Policy::template MakeXScaleBlockTileDistribution<Problem>()); smscale_window_, Policy::template MakeSmoothScaleBlockTileDistribution<Problem>());
auto reduce_absmax_func = ReduceOp::AbsMax{}; auto reduce_absmax_func = ReduceOp::AbsMax{};
auto reduce_absmax3_func = [](auto acc_, auto v_0_, auto v_1_) { auto reduce_absmax3_func = [](auto acc_, auto v_0_, auto v_1_) {
...@@ -67,14 +70,14 @@ struct SmoothquantPipelineOnePass ...@@ -67,14 +70,14 @@ struct SmoothquantPipelineOnePass
auto block_reduce2d_cross_warp_sync = auto block_reduce2d_cross_warp_sync =
Policy::template GetBlockReduce2dCrossWarpSync<Problem>(); Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
const auto x = load_tile(x_window); const auto x = load_tile(x_window);
const auto xscale = load_tile(xscale_window); const auto smscale = load_tile(smscale_window);
auto y = tile_elementwise_in( auto y = tile_elementwise_in(
[&](const auto& a, const auto& b) { [&](const auto& a, const auto& b) {
return type_convert<ComputeDataType>(a) * type_convert<ComputeDataType>(b); return type_convert<ComputeDataType>(a) * type_convert<ComputeDataType>(b);
}, },
x, x,
xscale); smscale);
// compute absmax, cross-lane->cross-warp // compute absmax, cross-lane->cross-warp
auto absmax = [&]() { auto absmax = [&]() {
...@@ -110,7 +113,7 @@ struct SmoothquantPipelineOnePass ...@@ -110,7 +113,7 @@ struct SmoothquantPipelineOnePass
sweep_tile(qy, [&](auto idx) { sweep_tile(qy, [&](auto idx) {
constexpr auto i_idx = make_tuple(idx[number<0>{}]); constexpr auto i_idx = make_tuple(idx[number<0>{}]);
auto qy_ = y[idx] / yscale[i_idx]; auto qy_ = y[idx] / yscale[i_idx];
qy(idx) = saturates<QYDataType>{}(qy_); qy(idx) = type_convert<QYDataType>(saturates<QYDataType>{}(qy_));
}); });
store_tile(qy_window, qy); store_tile(qy_window, qy);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment