Unverified Commit a11cf2c6 authored by arai713's avatar arai713 Committed by GitHub
Browse files

Merge branch 'develop' into codegen_hiprtc

parents a72e9efa 64d5c4d6
......@@ -70,11 +70,16 @@ struct FusedMoeGemmPipeline_FlatmmUk
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
#if 1
constexpr index_t smem_0 = Policy::template GetUK_0<Problem>().GetSmemSize();
constexpr index_t smem_1 = Policy::template GetUK_1<Problem>().GetSmemSize();
constexpr index_t smem_bridge =
BlockShape::Block_M0 * BlockShape::Block_N0 * sizeof(YDataType);
return max(smem_0, max(smem_1, smem_bridge));
return max(smem_0 + smem_1, smem_bridge);
#else
// keep it here purposely in case we have regression
return 65536;
#endif
}
// this is the thread-offset along row/col
......@@ -125,6 +130,9 @@ struct FusedMoeGemmPipeline_FlatmmUk
array<index_t, n_size> row_ids;
static_for<0, n_size, 1>{}([&](auto i) {
row_ids.at(i) = sorted_token_ids_ptr[coords[i]]; // base_coord + i * MLans;
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
row_ids.at(i) &= 0xffffff;
#endif
});
return row_ids;
......@@ -164,9 +172,12 @@ struct FusedMoeGemmPipeline_FlatmmUk
index_t sorted_tile_id,
index_t intermediate_tile_id)
{
constexpr index_t hidden_radio_0 = IsGateOnly ? 1 : 2;
ck_tile::index_t shared_intermediate_size_0 = kargs.intermediate_size;
ck_tile::index_t shared_intermediate_size_1 = kargs.intermediate_size / hidden_radio_0;
constexpr index_t hidden_radio_0 = IsGateOnly ? 1 : 2;
ck_tile::index_t shared_intermediate_size_0 =
kargs.intermediate_size * hidden_radio_0; // total gate+up
ck_tile::index_t shared_intermediate_size_1 = kargs.intermediate_size;
// after weight shuffling, gate-only: [nr0, kr0, w0], gate+up: [nr0_gate + nr0_up, kr0, w0]
index_t nr_0 = shared_intermediate_size_0 / BlockShape::Warp_N0; // divide N in W
index_t kr_0 = kargs.hidden_size / BlockShape::Warp_K0; // divide K in W
......@@ -200,29 +211,35 @@ struct FusedMoeGemmPipeline_FlatmmUk
make_wave_buffer_resource(reinterpret_cast<const ADataType*>(kargs.a_ptr),
kargs.num_tokens * kargs.stride_token * sizeof(ADataType));
auto g_win = [&]() {
const GDataType* g_ptr = reinterpret_cast<const GDataType*>(kargs.g_ptr) +
static_cast<long_index_t>(expert_id) * expert_stride_0 +
interm_idx_nr0 * kr_0 * BlockShape::Block_W0;
auto g_view_ = make_naive_tensor_view<address_space_enum::global>(
g_ptr,
auto make_gu_win = [&](const auto* ptr_) {
auto view_ = make_naive_tensor_view<address_space_enum::global>(
ptr_,
make_tuple(nr_0, kr_0, number<BlockShape::Block_W0>{}),
make_tuple(kr_0 * BlockShape::Block_W0, number<BlockShape::Block_W0>{}, 1),
number<kAlignmentG>{},
number<1>{});
auto g_window_ = make_tile_window_linear_raw(
g_view_,
auto win_ = make_tile_window_linear_raw(
view_,
make_tuple(number<BlockShape::Block_Nr0>{},
number<BlockShape::Block_Kr0>{},
number<BlockShape::Block_W0>{}),
{0, 0, 0},
Policy::template MakeGlobalTileDistribution_G<Problem>(),
sequence<0, 1, 1>{});
return g_window_;
}();
return win_;
};
const GDataType* gu_ptr = reinterpret_cast<const GDataType*>(kargs.g_ptr) +
static_cast<long_index_t>(expert_id) * expert_stride_0 +
interm_idx_nr0 * kr_0 * BlockShape::Block_W0;
auto g_win = make_gu_win(gu_ptr);
// Note: gu swizzled, [nr_u+nr_g, kr, w], hence base offset to up is just interm*hidden
auto u_win = make_gu_win(gu_ptr + kargs.intermediate_size * kargs.hidden_size);
auto g_res = g_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_;
auto u_res = u_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_;
auto g_coords = generate_tuple([&](auto i) { return g_win.cached_coords_[i].get_offset(); },
number<decltype(g_win)::NumAccess_NonLinear>{});
......@@ -309,28 +326,73 @@ struct FusedMoeGemmPipeline_FlatmmUk
auto w_scale = GetWeightScale(
row_coords_o, reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr));
auto uk_0 = Policy::template GetUK_0<Problem>();
auto acc_0 = uk_0(a_res,
a_coords,
g_res,
g_coords,
smem,
kargs.hidden_size,
BlockShape::Block_K0, // tile offset for B matrix each unroll
BlockShape::Block_Kr0 *
BlockShape::Block_W0); // tile offset for B matrix each unroll
sweep_tile(
acc_0,
[&](auto idx0, auto idx1) {
fp32x2_t v_{acc_0(idx0), acc_0(idx1)};
typename Problem::GateActivation{}(v_, v_);
acc_0(idx0) = v_.x;
acc_0(idx1) = v_.y;
},
sequence<1, 2>{});
auto y_pre = cast_tile<YDataType>(acc_0);
auto uk_0 = Policy::template GetUK_0<Problem>();
auto y_pre = [&]() {
if constexpr(IsGateOnly)
{
auto acc_0 = uk_0(a_res,
a_coords,
g_res,
g_coords,
smem,
kargs.hidden_size,
BlockShape::Block_K0, // tile offset for B matrix each unroll
BlockShape::Block_Kr0 *
BlockShape::Block_W0); // tile offset for B matrix each unroll
sweep_tile(
acc_0,
[&](auto idx0, auto idx1) {
fp32x2_t v_{acc_0(idx0), acc_0(idx1)};
typename Problem::GateActivation{}(v_, v_);
acc_0(idx0) = v_.x;
acc_0(idx1) = v_.y;
},
sequence<1, 2>{});
return cast_tile<YDataType>(acc_0);
}
else
{
uint32x8_t gu_res;
gu_res[0] = g_res[0];
gu_res[1] = g_res[1];
gu_res[2] = g_res[2];
gu_res[3] = g_res[3];
gu_res[4] = u_res[0];
gu_res[5] = u_res[1];
gu_res[6] = u_res[2];
gu_res[7] = u_res[3];
auto acc_0 = uk_0(a_res,
a_coords,
gu_res,
g_coords,
smem,
kargs.hidden_size,
BlockShape::Block_K0, // tile offset for B matrix each unroll
BlockShape::Block_Kr0 * BlockShape::Block_W0,
bool_constant<true>{}); // tile offset for B matrix each unroll
sweep_tile(
acc_0.at(number<0>{}),
[&](auto idx0, auto idx1) {
fp32x2_t v_{acc_0.at(number<0>{})(idx0), acc_0.at(number<0>{})(idx1)};
typename Problem::GateActivation{}(v_, v_);
acc_0.at(number<0>{})(idx0) = v_.x;
acc_0.at(number<0>{})(idx1) = v_.y;
},
sequence<1, 2>{});
auto reduced_acc_0 =
tile_elementwise_in([&](const auto& a_, const auto& b_) { return a_ * b_; },
acc_0.at(number<0>{}),
acc_0.at(number<1>{}));
return cast_tile<YDataType>(reduced_acc_0);
}
}();
block_sync_lds();
......
......@@ -101,9 +101,12 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
CK_TILE_DEVICE void operator()(BatchedGemmKernelArgs kargs) const
{
const auto [i_m, i_n] = TilePartitioner{}();
const auto i_batch = __builtin_amdgcn_readfirstlane(blockIdx.z / kargs.KBatch);
const auto i_k = __builtin_amdgcn_readfirstlane(blockIdx.z - i_batch * kargs.KBatch);
const auto [iM, iN] = TilePartitioner::GetOutputTileIndex(blockIdx.x, blockIdx.y);
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
const auto i_batch = __builtin_amdgcn_readfirstlane(blockIdx.z / kargs.KBatch);
const auto i_k = __builtin_amdgcn_readfirstlane(blockIdx.z - i_batch * kargs.KBatch);
const typename Base::SplitKBatchOffset splitk_batch_offset(kargs, i_k);
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -174,7 +174,7 @@ struct GemmKernel
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
if(kargs.K % TilePartitioner::kK != 0 && GemmPipeline::kPadK == false)
if(kargs.K % TilePartitioner::KPerBlock != 0 && GemmPipeline::kPadK == false)
{
return false;
}
......@@ -185,7 +185,7 @@ struct GemmKernel
}
else
{
if(kargs.M % TilePartitioner::kM != 0 && GemmPipeline::kPadM == false)
if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
{
return false;
}
......@@ -197,7 +197,7 @@ struct GemmKernel
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
{
if(kargs.N % TilePartitioner::kN != 0 && GemmPipeline::kPadN == false)
if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
{
return false;
}
......@@ -208,7 +208,7 @@ struct GemmKernel
}
else
{
if(kargs.K % TilePartitioner::kK != 0 && GemmPipeline::kPadK == false)
if(kargs.K % TilePartitioner::KPerBlock != 0 && GemmPipeline::kPadK == false)
{
return false;
}
......@@ -220,7 +220,7 @@ struct GemmKernel
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
if(kargs.N % TilePartitioner::kN != 0 && GemmPipeline::kPadN == false)
if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
{
return false;
}
......@@ -231,7 +231,7 @@ struct GemmKernel
}
else
{
if(kargs.M % TilePartitioner::kM != 0 && GemmPipeline::kPadM == false)
if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
{
return false;
}
......@@ -323,17 +323,17 @@ struct GemmKernel
const auto& a_tensor_view = views.at(I0);
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(
a_tensor_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
sequence<false, GemmPipeline::kPadK>{});
return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<false, GemmPipeline::kPadK>{});
}
else
{
return pad_tensor_view(
a_tensor_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
sequence<GemmPipeline::kPadM, false>{});
return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<GemmPipeline::kPadM, false>{});
}
}();
......@@ -341,17 +341,17 @@ struct GemmKernel
const auto& b_tensor_view = views.at(I1);
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
{
return pad_tensor_view(
b_tensor_view,
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
sequence<false, GemmPipeline::kPadK>{});
return pad_tensor_view(b_tensor_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<false, GemmPipeline::kPadK>{});
}
else
{
return pad_tensor_view(
b_tensor_view,
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
sequence<GemmPipeline::kPadN, false>{});
return pad_tensor_view(b_tensor_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<GemmPipeline::kPadN, false>{});
}
}();
......@@ -359,17 +359,17 @@ struct GemmKernel
const auto& c_tensor_view = views.at(I2);
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(
c_tensor_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
sequence<false, GemmPipeline::kPadN>{});
return pad_tensor_view(c_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<false, GemmPipeline::kPadN>{});
}
else
{
return pad_tensor_view(
c_tensor_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
sequence<GemmPipeline::kPadM, false>{});
return pad_tensor_view(c_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<GemmPipeline::kPadM, false>{});
}
}();
......@@ -383,19 +383,19 @@ struct GemmKernel
const auto& a_pad_view = views.at(I0);
const auto& a_block_window = make_tile_window(
a_pad_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
{i_m, 0});
const auto& b_pad_view = views.at(I1);
const auto& b_block_window = make_tile_window(
b_pad_view,
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
{i_n, 0});
const auto& c_pad_view = views.at(I2);
auto c_block_window = make_tile_window(
c_pad_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});
return make_tuple(a_block_window, b_block_window, c_block_window);
......@@ -426,7 +426,7 @@ struct GemmKernel
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
MakeGemmTensorViews<DstInMemOp>(a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset);
;
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
......@@ -456,7 +456,10 @@ struct GemmKernel
CK_TILE_DEVICE void operator()(GemmKernelArgs kargs) const
{
const auto [i_m, i_n] = TilePartitioner{}();
const auto [iM, iN] = TilePartitioner::GetOutputTileIndex(blockIdx.x, blockIdx.y);
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
const SplitKBatchOffset splitk_batch_offset(kargs);
// options
const ADataType* a_ptr =
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <typename BlockGemmShape_>
struct GemmTilePartitioner
/** @brief Struct representing 2D block index mapping into 3D output tile space. */
template <typename BlockGemmShapeType>
struct GemmTile2DPartitioner
{
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
using BlockGemmShape = remove_cvref_t<BlockGemmShapeType>;
static constexpr index_t kM = BlockGemmShape::kM;
static constexpr index_t kN = BlockGemmShape::kN;
static constexpr index_t kK = BlockGemmShape::kK;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t batch_size)
/** @brief Returns 3D grid size. */
CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t batch_size) noexcept(
noexcept(MPerBlock != 0 && NPerBlock != 0)) -> dim3
{
index_t GridDimX = (M + kM - 1) / kM;
index_t GridDimY = (N + kN - 1) / kN;
index_t GridDimZ = batch_size;
const index_t GridDimX = (M + MPerBlock - 1) / MPerBlock;
const index_t GridDimY = (N + NPerBlock - 1) / NPerBlock;
const index_t GridDimZ = batch_size;
return dim3(GridDimX, GridDimY, GridDimZ);
}
CK_TILE_HOST_DEVICE static constexpr auto GetLoopNum(index_t K)
/**
* @brief Returns the number of loops.
* @param [in] K is dimension
*/
CK_TILE_HOST_DEVICE static constexpr auto GetLoopNum(index_t K) noexcept -> index_t
{
return integer_divide_ceil(K, kK);
return integer_divide_ceil(K, KPerBlock);
}
CK_TILE_DEVICE auto operator()()
/**
* @brief The function returns 2D output tile space.
* @param [in] blockIdx is blockIdx.x
* @param [in] blockIdy is blockIdx.y
* @return Returns the output tile indexes.
*/
CK_TILE_DEVICE static constexpr auto GetOutputTileIndex(index_t blockIdx,
index_t blockIdy) noexcept
-> const tuple<index_t, index_t>
{
const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx.x * kM);
const index_t iN = __builtin_amdgcn_readfirstlane(blockIdx.y * kN);
const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx);
const index_t iN = __builtin_amdgcn_readfirstlane(blockIdy);
return make_tuple(iM, iN);
}
};
template <typename BlockGemmShape_>
/**
* @brief Struct representing 1D block index mapping into 2D output tile space.
*/
template <typename BlockGemmShapeType>
struct GemmTile1DPartitioner
{
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
using BlockGemmShape = remove_cvref_t<BlockGemmShapeType>;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N)
/** @brief delete default ctr with no any object */
constexpr GemmTile1DPartitioner() noexcept = delete;
/** @brief constructs an object that does contain a N value. */
constexpr GemmTile1DPartitioner(index_t N) noexcept { N_ = N; }
/** @brief Returns 1D grid size. */
CK_TILE_HOST static constexpr auto
GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0)) -> dim3
{
index_t GridDimX = (M + MPerBlock - 1) / MPerBlock;
index_t GridDimY = (N + NPerBlock - 1) / NPerBlock;
const index_t GridDimX = (M + MPerBlock - 1) / MPerBlock;
const index_t GridDimY = (N + NPerBlock - 1) / NPerBlock;
return dim3(GridDimX * GridDimY, 1, 1);
}
CK_TILE_HOST_DEVICE static constexpr auto GetNBlock(index_t N)
/**
* @brief Returns the number of blocks in N.
* @param [in] N is dimension
*/
CK_TILE_HOST_DEVICE static constexpr auto GetNBlock(index_t N) noexcept -> index_t
{
return integer_divide_ceil(N, NPerBlock);
}
CK_TILE_HOST_DEVICE static constexpr auto GetLoopNum(index_t K)
/**
* @brief Returns the number of loops.
* @param [in] K is dimension
*/
CK_TILE_HOST_DEVICE static constexpr auto GetLoopNum(index_t K) noexcept -> index_t
{
return integer_divide_ceil(K, KPerBlock);
}
CK_TILE_DEVICE auto operator()(index_t blockOffset, index_t NBlockSize)
/**
* @brief The function returns 2D output tile space.
* @param [in] blockIdx is blockIdx.x - block_start.
* */
CK_TILE_DEVICE static constexpr auto GetOutputTileIndex(index_t blockIdx) noexcept
-> const tuple<index_t, index_t>
{
const index_t NBlock = GetNBlock(N_);
const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx / NBlock);
const index_t iN = __builtin_amdgcn_readfirstlane(blockIdx - (iM)*NBlock);
return make_tuple(iM, iN);
}
private:
CK_TILE_DEVICE static index_t N_;
};
/**
* @brief `GemmTile1DPartitioner::GetOutputTileIndex`'s std::false specialization,
* checking expression validity in-place for ill-formed.
*/
template <typename, typename = void>
struct HasFnOneArgImpl : std::false_type
{
};
/**
* @brief `GemmTile1DPartitioner::GetOutputTileIndex`'s std::true specialization,
* checking expression validity in-place for well-formed.
* @note: `1` - a constant value indicating the number of parameters in the function.
*/
template <typename T>
struct HasFnOneArgImpl<T, std::void_t<decltype(std::declval<T>().GetOutputTileIndex(1))>>
: std::true_type
{
};
/**
* @brief Struct used to calculate offseted tile indexes.
* @note: The struct supports the 1D-Partitioner mechanism,
* enable-if `GetOutputTileIndex`-fn is std::true_type when `GetOutputTileIndex`-fn is well-formed,
* otherwise std::false_type.
*/
template <typename PartitionerFn,
typename = typename std::enable_if_t<HasFnOneArgImpl<PartitionerFn>{}>>
struct OffsettedTile1DPartitioner
{
/**
* @brief The function subtracts the block's start (offset) from 1D raw-indexes.
* @param [in] block_start is `blockIdx.x - block_start`.
* @return Returns a `tuple` [Im, In] shifted index, used to shift 1d-tile index.
*/
[[nodiscard]] CK_TILE_DEVICE static constexpr auto GetOffsetedTileIndex(index_t block_start,
index_t N) noexcept
-> const tuple<index_t, index_t>
{
index_t iM = __builtin_amdgcn_readfirstlane((blockIdx.x - blockOffset) /
GetNBlock(NBlockSize) * MPerBlock);
index_t iN = __builtin_amdgcn_readfirstlane((blockIdx.x - blockOffset) %
GetNBlock(NBlockSize) * NPerBlock);
const auto [iM, iN] = PartitionerFn(N).GetOutputTileIndex(blockIdx.x - block_start);
return make_tuple(iM, iN);
}
};
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <string>
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/literals.hpp"
#include "ck_tile/core/utility/amd_address_space.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/host.hpp"
namespace ck_tile {
struct GroupedGemmHostArgs
struct GroupedGemmHostArgs : public ck_tile::GemmHostArgs
{
const void* a_ptr;
const void* b_ptr;
void* c_ptr;
index_t M;
index_t N;
index_t K;
index_t stride_A;
index_t stride_B;
index_t stride_C;
CK_TILE_HOST GroupedGemmHostArgs() noexcept = default;
CK_TILE_HOST GroupedGemmHostArgs(const void* a_ptr_,
const void* b_ptr_,
void* c_ptr_,
ck_tile::index_t M_,
ck_tile::index_t N_,
ck_tile::index_t K_,
ck_tile::index_t stride_A_,
ck_tile::index_t stride_B_,
ck_tile::index_t stride_C_)
: GemmHostArgs(a_ptr_, b_ptr_, c_ptr_, KBatch, M_, N_, K_, stride_A_, stride_B_, stride_C_)
{
}
private:
static constexpr index_t KBatch = 1;
};
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct GroupedGemmKernel
struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>
{
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
using OffsetTile1DPartitioner = OffsettedTile1DPartitioner<TilePartitioner>;
using Base = GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
using GemmKernelArgs = typename Base::GemmKernelArgs;
static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
static constexpr index_t KBatch = 1;
struct GemmTransKernelArg
{
GroupedGemmHostArgs group_karg;
GemmKernelArgs group_karg;
ck_tile::index_t block_start;
ck_tile::index_t block_end;
GemmTransKernelArg() = default;
GemmTransKernelArg(GroupedGemmHostArgs&& karg, index_t bl_start, index_t bl_end)
GemmTransKernelArg(GemmKernelArgs&& karg, index_t bl_start, index_t bl_end)
: group_karg{karg}, block_start{bl_start}, block_end{bl_end}
{
}
};
__host__ static size_t GetWorkSpaceSize(const std::vector<GroupedGemmHostArgs>& gemm_descs)
__host__ static auto GetWorkSpaceSize(const std::vector<GroupedGemmHostArgs>& gemm_descs)
-> std::size_t
{
return gemm_descs.size() * sizeof(GemmTransKernelArg);
}
__host__ static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
using Hargs = GroupedGemmHostArgs;
__host__ static constexpr auto BlockSize() -> dim3 { return dim3(KernelBlockSize); }
__host__ static constexpr auto GridSize(const std::vector<Hargs>& gemm_descs)
__host__ static constexpr auto GridSize(const std::vector<GroupedGemmHostArgs>& gemm_descs)
{
index_t grid_size = 0;
for(const auto& it_desc : gemm_descs)
......@@ -77,7 +84,8 @@ struct GroupedGemmKernel
return dim3(grid_size, 1, 1);
}
CK_TILE_HOST static auto MakeKargs(const std::vector<Hargs>& gemm_descs)
CK_TILE_HOST static auto MakeKargs(const std::vector<GroupedGemmHostArgs>& gemm_descs)
-> std::vector<GemmTransKernelArg>
{
std::vector<GemmTransKernelArg> gemm_kernel_args_;
index_t group_count = ck_tile::type_convert<ck_tile::index_t>(gemm_descs.size());
......@@ -100,22 +108,23 @@ struct GroupedGemmKernel
const index_t stride_c = gemm_descs[i].stride_C;
const auto dim3 = TilePartitioner::GridSize(M, N);
const index_t grid_size_grp = dim3.x * 1 * 1;
const index_t grid_size_grp = dim3.x;
const index_t block_start = grid_size;
const index_t block_end = grid_size + grid_size_grp;
grid_size += grid_size_grp;
auto karg = GroupedGemmHostArgs{type_convert<const ADataType*>(gemm_descs[i].a_ptr),
type_convert<const BDataType*>(gemm_descs[i].b_ptr),
type_convert<CDataType*>(gemm_descs[i].c_ptr),
M,
N,
K,
stride_a,
stride_b,
stride_c};
auto karg = GemmKernelArgs{type_convert<const ADataType*>(gemm_descs[i].a_ptr),
type_convert<const BDataType*>(gemm_descs[i].b_ptr),
type_convert<CDataType*>(gemm_descs[i].c_ptr),
M,
N,
K,
stride_a,
stride_b,
stride_c,
KBatch};
gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end);
}
......@@ -123,162 +132,34 @@ struct GroupedGemmKernel
return gemm_kernel_args_;
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize() -> index_t
{
return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
}
CK_TILE_DEVICE void Run(const Hargs& kargs, const index_t block_start) const
CK_TILE_DEVICE void Run(const GemmTransKernelArg& kargs) const
{
const auto [i_m, i_n] = TilePartitioner{}(block_start, kargs.N);
// options
const ADataType* a_start = static_cast<const ADataType*>(kargs.a_ptr);
const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr);
// Convert pointers to tensor views
auto a_tensor_view = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
a_start,
make_tuple(kargs.M, kargs.K),
make_tuple(kargs.stride_A, 1),
number<GemmPipeline::VectorSizeA>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
a_start,
make_tuple(kargs.M, kargs.K),
make_tuple(1, kargs.stride_A),
number<1>{},
number<1>{});
}
}();
const auto [iM, iN] =
OffsetTile1DPartitioner::GetOffsetedTileIndex(kargs.block_start, kargs.group_karg.N);
auto b_tensor_view = [&]() {
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
b_start,
make_tuple(kargs.N, kargs.K),
make_tuple(1, kargs.stride_B),
number<1>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
b_start,
make_tuple(kargs.N, kargs.K),
make_tuple(kargs.stride_B, 1),
number<GemmPipeline::VectorSizeB>{},
number<1>{});
}
}();
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
auto a_pad_view = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<false, GemmPipeline::kPadK>{});
}
else
{
return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<GemmPipeline::kPadM, false>{});
}
}();
// clang-format on
const typename Base::SplitKBatchOffset splitk_batch_offset(kargs.group_karg, blockIdx.z);
auto a_block_window = make_tile_window(
a_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
{i_m, 0});
auto b_pad_view = [&]() {
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
{
return pad_tensor_view(b_tensor_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<false, GemmPipeline::kPadK>{});
}
else
{
return pad_tensor_view(b_tensor_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<GemmPipeline::kPadN, false>{});
}
}();
auto b_block_window = make_tile_window(
b_pad_view,
make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
{i_n, 0});
const ADataType* a_ptr = static_cast<const ADataType*>(kargs.group_karg.a_ptr);
const BDataType* b_ptr = static_cast<const BDataType*>(kargs.group_karg.b_ptr);
CDataType* c_ptr = static_cast<CDataType*>(kargs.group_karg.c_ptr);
// allocate LDS
__shared__ char smem_ptr[GetSmemSize()];
const index_t num_loop = TilePartitioner::GetLoopNum(kargs.K);
// Run GEMM cooperatively by whole wokrgroup.
auto c_block_tile =
GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr);
CDataType* c_start = static_cast<CDataType*>(kargs.c_ptr);
auto c_tensor_view = [&]() {
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
c_start,
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.stride_C, 1),
number<GemmPipeline::VectorSizeC>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
c_start,
make_tuple(kargs.M, kargs.N),
make_tuple(1, kargs.stride_C),
number<1>{},
number<1>{});
}
}();
auto c_pad_view = [&]() {
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(c_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<false, GemmPipeline::kPadN>{});
}
else
{
return pad_tensor_view(c_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<GemmPipeline::kPadM, false>{});
}
}();
auto CBlockWindow_pad = make_tile_window(
c_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});
EpiloguePipeline{}(CBlockWindow_pad, c_block_tile);
this->RunGemm(
a_ptr, b_ptr, c_ptr, smem_ptr, kargs.group_karg, splitk_batch_offset, i_m, i_n);
}
CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
int group_count) const
index_t group_count) const
{
const index_t block_id = ck_tile::get_block_1d_id();
const auto gemm_desc_ptr = reinterpret_cast<const GemmTransKernelArg*>(
......@@ -286,7 +167,7 @@ struct GroupedGemmKernel
index_t left = 0;
index_t right = group_count;
index_t group_id = index_t((left + right) / 2);
index_t group_id = index_t((left + right) >> 1);
while((!(block_id >= gemm_desc_ptr[group_id].block_start &&
block_id < gemm_desc_ptr[group_id].block_end)) &&
......@@ -300,10 +181,10 @@ struct GroupedGemmKernel
{
left = group_id;
}
group_id = index_t((left + right) / 2);
group_id = index_t((left + right) >> 1);
}
Run(gemm_desc_ptr[group_id].group_karg, gemm_desc_ptr[group_id].block_start);
Run(gemm_desc_ptr[group_id]);
}
};
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -14,7 +14,7 @@ struct Layernorm2dFwdHostArgs
{
const void* p_x; // [m ,n], input, fp16/bf16
const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used
const void* p_x_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used
const void* p_sm_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used
const void* p_x_bias; // [1, n], bias, prec same as input
const void* p_gamma; // [1, n], gamma, prec same as input
const void* p_beta; // [1, n], beta, prec same as input
......@@ -43,16 +43,16 @@ struct Layernorm2dFwd
using Epilogue = remove_cvref_t<Epilogue_>;
using Problem = typename Pipeline::Problem;
using XDataType = remove_cvref_t<typename Problem::XDataType>;
using XBiasDataType = remove_cvref_t<typename Problem::XBiasDataType>;
using GammaDataType = remove_cvref_t<typename Problem::GammaDataType>;
using BetaDataType = remove_cvref_t<typename Problem::BetaDataType>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using YDataType = remove_cvref_t<typename Problem::YDataType>;
using MeanDataType = remove_cvref_t<typename Problem::MeanDataType>;
using InvStdDataType = remove_cvref_t<typename Problem::InvStdDataType>;
using XScaleDataType = remove_cvref_t<typename Problem::XScaleDataType>;
using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>;
using XDataType = remove_cvref_t<typename Problem::XDataType>;
using XBiasDataType = remove_cvref_t<typename Problem::XBiasDataType>;
using GammaDataType = remove_cvref_t<typename Problem::GammaDataType>;
using BetaDataType = remove_cvref_t<typename Problem::BetaDataType>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using YDataType = remove_cvref_t<typename Problem::YDataType>;
using MeanDataType = remove_cvref_t<typename Problem::MeanDataType>;
using InvStdDataType = remove_cvref_t<typename Problem::InvStdDataType>;
using SmoothScaleDataType = remove_cvref_t<typename Problem::SmoothScaleDataType>;
using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>;
// for simplicity, shortcut input/output type is same as X
using XResidualDataType = XDataType;
......@@ -84,7 +84,7 @@ struct Layernorm2dFwd
{
const void* p_x; // [m ,n], input, fp16/bf16
const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used
const void* p_x_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used
const void* p_sm_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used
const void* p_x_bias; // [1, n], bias, prec same as input
const void* p_gamma; // [1, n], gamma, prec same as input
const void* p_beta; // [1, n], beta, prec same as input
......@@ -111,7 +111,7 @@ struct Layernorm2dFwd
{
return Kargs{hargs.p_x,
hargs.p_x_residual,
hargs.p_x_scale,
hargs.p_sm_scale,
hargs.p_x_bias,
hargs.p_gamma,
hargs.p_beta,
......@@ -171,7 +171,7 @@ struct Layernorm2dFwd
base_str += _SS_("_") + _SS_(t2s<YDataType>::name);
}
if (kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT) {
base_str += _SS_("_sx") + _SS_(t2s<XScaleDataType>::name);
base_str += _SS_("_sx") + _SS_(t2s<SmoothScaleDataType>::name);
base_str += _SS_("_sy") + _SS_(t2s<YScaleDataType>::name);
}
if (kFusedQuant == Layernorm2dFusedQuantEnum::DYNAMIC_QUANT) {
......@@ -356,18 +356,18 @@ struct Layernorm2dFwd
return make_null_tile_window(make_tuple(number<Block_M>{}));
}();
auto x_scale_window = [&]() {
auto sm_scale_window = [&]() {
if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
{
const auto win_ = [&]() {
const auto tmp_0_ = make_naive_tensor_view_packed<address_space_enum::global>(
static_cast<const XScaleDataType*>(kargs.p_x_scale),
static_cast<const SmoothScaleDataType*>(kargs.p_sm_scale),
make_tuple(kargs.n),
number<Vector_N>{});
return pad_tensor_view(tmp_0_,
make_tuple(number<Block_N>{}),
sequence<false>{}); // x_scale no need pad
sequence<false>{}); // sm_scale no need pad
}();
return make_tile_window(win_, make_tuple(number<Block_N>{}), {0});
}
......@@ -405,7 +405,7 @@ struct Layernorm2dFwd
y_residual_window,
mean_window,
inv_std_window,
x_scale_window,
sm_scale_window,
y_scale_window,
static_cast<const ComputeDataType>(kargs.epsilon),
kargs.n,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -64,7 +64,7 @@ struct Layernorm2dFwdPipelineOnePass
typename YResidualWindow,
typename MeanWindow,
typename InvStdWindow,
typename XScaleWindow,
typename SmoothScaleWindow,
typename YScaleWindow,
typename Epilogue>
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
......@@ -76,7 +76,7 @@ struct Layernorm2dFwdPipelineOnePass
const YResidualWindow& y_residual_window_,
MeanWindow& mean_window,
InvStdWindow& inv_std_window,
const XScaleWindow& x_scale_window_,
const SmoothScaleWindow& sm_scale_window_,
YScaleWindow& y_scale_window,
ComputeDataType epsilon,
ck_tile::index_t row_size,
......@@ -190,7 +190,7 @@ struct Layernorm2dFwdPipelineOnePass
if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::DYNAMIC_QUANT ||
kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
{
Epilogue{}(y_window_, x_scale_window_, y_scale_window, ln, smem);
Epilogue{}(y_window_, sm_scale_window_, y_scale_window, ln, smem);
}
else
Epilogue{}(y_window_, ln);
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -15,23 +15,23 @@ template <typename XDataType_,
typename YDataType_,
typename MeanDataType_,
typename InvStdDataType_,
typename XScaleDataType_,
typename SmoothScaleDataType_,
typename YScaleDataType_,
typename BlockShape_,
typename Traits_>
struct Layernorm2dFwdPipelineProblem
{
using XDataType = remove_cvref_t<XDataType_>;
using XBiasDataType = remove_cvref_t<XBiasDataType_>;
using GammaDataType = remove_cvref_t<GammaDataType_>;
using BetaDataType = remove_cvref_t<BetaDataType_>;
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
using YDataType = remove_cvref_t<YDataType_>;
using MeanDataType = remove_cvref_t<MeanDataType_>;
using InvStdDataType = remove_cvref_t<InvStdDataType_>;
using XScaleDataType = remove_cvref_t<XScaleDataType_>;
using YScaleDataType = remove_cvref_t<YScaleDataType_>;
using BlockShape = remove_cvref_t<BlockShape_>;
using XDataType = remove_cvref_t<XDataType_>;
using XBiasDataType = remove_cvref_t<XBiasDataType_>;
using GammaDataType = remove_cvref_t<GammaDataType_>;
using BetaDataType = remove_cvref_t<BetaDataType_>;
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
using YDataType = remove_cvref_t<YDataType_>;
using MeanDataType = remove_cvref_t<MeanDataType_>;
using InvStdDataType = remove_cvref_t<InvStdDataType_>;
using SmoothScaleDataType = remove_cvref_t<SmoothScaleDataType_>;
using YScaleDataType = remove_cvref_t<YScaleDataType_>;
using BlockShape = remove_cvref_t<BlockShape_>;
static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -63,7 +63,7 @@ struct Layernorm2dFwdPipelineTwoPass
typename YResidualWindow,
typename MeanWindow,
typename InvStdWindow,
typename XScaleWindow,
typename SmoothScaleWindow,
typename YScaleWindow,
typename Epilogue>
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
......@@ -75,7 +75,7 @@ struct Layernorm2dFwdPipelineTwoPass
const YResidualWindow& y_residual_window_,
MeanWindow& mean_window,
InvStdWindow& inv_std_window,
const XScaleWindow& /*x_scale_window*/,
const SmoothScaleWindow& /*sm_scale_window*/,
YScaleWindow& /*y_scale_window*/,
ComputeDataType epsilon,
ck_tile::index_t row_size,
......
......@@ -8,5 +8,6 @@
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp"
namespace ck_tile {
// host side args
struct Rmsnorm2dFwdHostArgs
{
const void* p_x; // [m ,n], input, fp16/bf16
const void* p_gamma; // [1, n], gamma, prec same as input
const void* p_x; // [m ,n], input, fp16/bf16
const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used
const void* p_sm_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used
const void* p_gamma; // [1, n], gamma, prec same as input
void* p_y; // [m, n], output, fp16/bf16
void* p_invRms; // [m, 1], output inv-rms, prec same as input, nullptr if not used
void* p_y; // [m, n], output, fp16/bf16
void* p_y_residual; // [m, n], shortcut output, prec same as input, nullptr if not used
void* p_y_scale; // [m, 1], output a dynamic quant per row, nullptr if not used
void* p_invRms; // [m, 1], output inv-rms, prec same as input, nullptr if not used
float epsilon;
index_t m;
index_t n;
index_t stride; // row_stride
index_t x_stride; // x row_stride
index_t xr_stride; // x residule row stride
index_t y_stride; // y row stride
index_t yr_stride; // y residule row stride
};
// TODO: Extract some type to wrapper class
template <typename Pipeline_>
template <typename Pipeline_, typename Epilogue_>
struct Rmsnorm2dFwd
{
using Pipeline = remove_cvref_t<Pipeline_>;
using Epilogue = remove_cvref_t<Epilogue_>;
using Problem = typename Pipeline::Problem;
using XDataType = remove_cvref_t<typename Problem::XDataType>;
using GammaDataType = remove_cvref_t<typename Problem::GammaDataType>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using YDataType = remove_cvref_t<typename Problem::YDataType>;
using InvRmsDataType = remove_cvref_t<typename Problem::InvRmsDataType>;
using XDataType = remove_cvref_t<typename Problem::XDataType>;
using GammaDataType = remove_cvref_t<typename Problem::GammaDataType>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using YDataType = remove_cvref_t<typename Problem::YDataType>;
using InvRmsDataType = remove_cvref_t<typename Problem::InvRmsDataType>;
using SmoothScaleDataType = remove_cvref_t<typename Problem::SmoothScaleDataType>;
using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>;
// for simplicity, shortcut input/output type is same as X
using XResidualDataType = XDataType;
using YResidualDataType = XDataType;
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, null_type>;
static constexpr bool kSaveInvRms = Problem::kSaveInvRms;
static constexpr bool kSaveInvRms = Problem::Traits::kSaveInvRms;
static constexpr index_t Block_M = Problem::BlockShape::Block_M;
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
static constexpr bool kPadM = false; // always no need to pad along M
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kTwoPass = Problem::kTwoPass;
static constexpr index_t Block_M = Problem::BlockShape::Block_M;
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
static constexpr bool kPadM = false; // always no need to pad along M
static constexpr bool kPadN = Problem::Traits::kPadN;
static constexpr bool kTwoPass = Problem::Traits::kTwoPass;
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
static constexpr index_t Vector_N = Problem::BlockShape::Vector_N;
......@@ -56,29 +73,43 @@ struct Rmsnorm2dFwd
struct Kargs
{
const void* p_x;
const void* p_x_residual;
const void* p_sm_scale;
const void* p_gamma;
void* p_y;
void* p_y_residual;
void* p_y_scale;
void* p_invRms;
float epsilon;
index_t m;
index_t n;
index_t stride; // row_stride
index_t x_stride; // x row_stride
index_t xr_stride; // x residule row stride
index_t y_stride; // y row stride
index_t yr_stride; // y residule row stride
};
using Hargs = Rmsnorm2dFwdHostArgs;
CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
{
return Kargs{hargs.p_x,
hargs.p_x_residual,
hargs.p_sm_scale,
hargs.p_gamma,
hargs.p_y,
hargs.p_y_residual,
hargs.p_y_scale,
hargs.p_invRms,
hargs.epsilon,
hargs.m,
hargs.n,
hargs.stride};
hargs.x_stride,
hargs.xr_stride,
hargs.y_stride,
hargs.yr_stride};
}
CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
......@@ -95,6 +126,7 @@ struct Rmsnorm2dFwd
template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
template <> struct t2s<ck_tile::int8_t> { static constexpr const char * name = "int8"; };
// clang-format on
// in byte
......@@ -102,24 +134,41 @@ struct Rmsnorm2dFwd
CK_TILE_HOST static std::string GetName()
{
#define _SS_ std::string
#define _TS_ std::to_string
// clang-format off
using S_ = typename Problem::BlockShape;
auto surfix = [&] () {
std::string n;
if (kFusedAdd != Rmsnorm2dFusedAddEnum::NO_ADD) n += _SS_("_") + Rmsnorm2dFusedAddEnumName<kFusedAdd>::name;
if (kFusedQuant != Rmsnorm2dFusedQuantEnum::NO_SWEEP) n += _SS_("_") + Rmsnorm2dFusedQuantEnumName<kFusedQuant>::name;
if (kPadN) n += "_pn";
if (kSaveInvRms) n += "_rms";
if (kTwoPass) n += "_2p";
return n; }();
#define _SS_ std::string
#define _TS_ std::to_string
return _SS_("rmsnorm2d_fwd_") + _SS_(t2s<XDataType>::name) + "_" +
auto prec_str = [&] () {
std::string base_str = _SS_(t2s<XDataType>::name);
if (!std::is_same_v<XDataType, YDataType>) {
base_str += _SS_("_") + _SS_(t2s<YDataType>::name);
}
if (kFusedQuant == Rmsnorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT) {
base_str += _SS_("_sx") + _SS_(t2s<SmoothScaleDataType>::name);
base_str += _SS_("_sy") + _SS_(t2s<YScaleDataType>::name);
}
if (kFusedQuant == Rmsnorm2dFusedQuantEnum::DYNAMIC_QUANT) {
base_str += _SS_("_sy") + _SS_(t2s<YScaleDataType>::name);
}
return base_str;
}();
return _SS_("rmsnorm2d_fwd_") + _SS_(prec_str) + "_" +
_TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" +
_TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" +
_SS_(Pipeline::name) + surfix;
#undef _SS_
#undef _TS_
// clang-format on
#undef _SS_
#undef _TS_
}
CK_TILE_DEVICE void operator()(Kargs kargs) const
......@@ -130,7 +179,7 @@ struct Rmsnorm2dFwd
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const XDataType*>(kargs.p_x),
make_tuple(kargs.m, kargs.n),
make_tuple(kargs.stride, 1),
make_tuple(kargs.x_stride, 1),
number<Vector_N>{},
number<1>{});
......@@ -140,6 +189,29 @@ struct Rmsnorm2dFwd
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
}();
const auto x_residual_window = [&]() {
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD ||
kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
{
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const XResidualDataType*>(kargs.p_x_residual),
make_tuple(kargs.m, kargs.n),
make_tuple(kargs.xr_stride, 1),
number<Vector_N>{},
number<1>{});
const auto tmp2_ = pad_tensor_view(tmp_,
make_tuple(number<Block_M>{}, number<Block_N>{}),
sequence<kPadM, kPadN>{});
return make_tile_window(
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
}
else
{
return make_null_tile_window(make_tuple(number<Block_M>{}, number<Block_N>{}));
}
}();
const auto gamma_window = [&]() {
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const GammaDataType*>(kargs.p_gamma),
......@@ -158,7 +230,7 @@ struct Rmsnorm2dFwd
auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<YDataType*>(kargs.p_y),
make_tuple(kargs.m, kargs.n),
make_tuple(kargs.stride, 1),
make_tuple(kargs.y_stride, 1),
number<Vector_N>{},
number<1>{});
......@@ -168,6 +240,28 @@ struct Rmsnorm2dFwd
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
}();
auto y_residual_window = [&]() {
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
{
auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<YResidualDataType*>(kargs.p_y_residual),
make_tuple(kargs.m, kargs.n),
make_tuple(kargs.yr_stride, 1),
number<Vector_N>{},
number<1>{});
auto tmp2_ = pad_tensor_view(tmp_,
make_tuple(number<Block_M>{}, number<Block_N>{}),
sequence<kPadM, kPadN>{});
return make_tile_window(
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
}
else
{
return make_null_tile_window(make_tuple(number<Block_M>{}, number<Block_N>{}));
}
}();
auto inv_rms_window = [&]() {
if constexpr(kSaveInvRms)
{
......@@ -187,15 +281,62 @@ struct Rmsnorm2dFwd
return make_null_tile_window(make_tuple(number<Block_M>{}));
}();
auto sm_scale_window = [&]() {
if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
{
const auto win_ = [&]() {
const auto tmp_0_ = make_naive_tensor_view_packed<address_space_enum::global>(
static_cast<const SmoothScaleDataType*>(kargs.p_sm_scale),
make_tuple(kargs.n),
number<Vector_N>{});
return pad_tensor_view(tmp_0_,
make_tuple(number<Block_N>{}),
sequence<false>{}); // sm_scale no need pad
}();
return make_tile_window(win_, make_tuple(number<Block_N>{}), {0});
}
else
{
return make_null_tile_window(make_tuple(number<Block_N>{}));
}
}();
auto y_scale_window = [&]() {
if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT ||
kFusedQuant == Rmsnorm2dFusedQuantEnum::DYNAMIC_QUANT)
{
const auto win_ = [&]() {
const auto tmp_0_ = make_naive_tensor_view_packed<address_space_enum::global>(
static_cast<YScaleDataType*>(kargs.p_y_scale),
make_tuple(kargs.m),
number<1>{});
return pad_tensor_view(
tmp_0_, make_tuple(number<Block_M>{}), sequence<kPadM>{});
}();
return make_tile_window(win_, make_tuple(number<Block_M>{}), {iM});
}
else
{
return make_null_tile_window(make_tuple(number<Block_M>{}));
}
}();
__shared__ char smem[GetSmemSize()];
Pipeline{}(x_window,
x_residual_window,
gamma_window,
y_window,
y_residual_window,
inv_rms_window,
sm_scale_window,
y_scale_window,
static_cast<const ComputeDataType>(kargs.epsilon),
kargs.n,
smem);
smem,
Epilogue{});
}
};
......
......@@ -45,7 +45,7 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2d()
{
using P_ = BlockReduce2dProblem<typename Problem::XDataType,
using P_ = BlockReduce2dProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape>;
return BlockReduce2d<P_>{};
......@@ -54,7 +54,7 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dSync()
{
using P_ = BlockReduce2dProblem<typename Problem::XDataType,
using P_ = BlockReduce2dProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape>;
return BlockReduce2dSync<P_>{};
......@@ -63,7 +63,7 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dCrossWarpSync()
{
using P_ = BlockReduce2dProblem<typename Problem::XDataType,
using P_ = BlockReduce2dProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape>;
return BlockReduce2dCrossWarpSync<P_>{};
......@@ -74,13 +74,13 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy
{
if constexpr(Problem::kNeedCrossWarpSync)
{
using P_ = BlockReduce2dProblem<typename Problem::XDataType,
using P_ = BlockReduce2dProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape>;
using block_reduce2d = BlockReduce2d<P_>;
using x_block_tile =
decltype(make_static_distributed_tensor<typename Problem::XDataType>(
decltype(make_static_distributed_tensor<typename Problem::ComputeDataType>(
MakeXBlockTileDistribution<Problem>()));
using y_block_tile = decltype(block_reduce2d::template MakeYBlockTile<x_block_tile>());
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -22,12 +22,17 @@ struct Rmsnorm2dFwdPipelineOnePass
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
using InvRmsDataType = ck_tile::remove_cvref_t<typename Problem::InvRmsDataType>;
using XResidualDataType = XDataType;
using YResidualDataType = XDataType;
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>;
static constexpr bool kSaveInvRms = Problem::kSaveInvRms;
static constexpr bool kSaveInvRms = Problem::Traits::kSaveInvRms;
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
static constexpr bool kPadM = false; // TODO - BlockRmsnorm2dFwdProblem::kPadM
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kPadN = Problem::Traits::kPadN;
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
static constexpr const char* name = []() {
if constexpr(kNeedCrossWarpSync)
......@@ -41,19 +46,36 @@ struct Rmsnorm2dFwdPipelineOnePass
return Policy::template GetSmemSize<Problem>();
}
template <typename XWindow, typename GammaWindow, typename YWindow, typename InvRmsWindow>
template <typename XWindow,
typename XResidualWindow,
typename GammaWindow,
typename YWindow,
typename YResidualWindow,
typename InvRmsWindow,
typename SmoothScaleWindow,
typename YScaleWindow,
typename Epilogue>
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
const XResidualWindow& x_residual_window_,
const GammaWindow& gamma_window_,
YWindow& y_window,
YWindow& y_window_,
const YResidualWindow& y_residual_window_,
InvRmsWindow& inv_rms_window,
const SmoothScaleWindow& sm_scale_window_,
YScaleWindow& y_scale_window_,
ComputeDataType epsilon,
ck_tile::index_t row_size,
void* smem) const
void* smem,
Epilogue) const
{
const auto x_window =
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
const auto gamma_window = make_tile_window(
gamma_window_, Policy::template MakeGammaBlockTileDistribution<Problem>());
const auto x_residual_window = make_tile_window(
x_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
auto y_residual_window = make_tile_window(
y_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
auto reduce_square_sum_func = ReduceOp::SquareAdd{};
auto reduce_sum_func = ReduceOp::Add{};
......@@ -62,13 +84,31 @@ struct Rmsnorm2dFwdPipelineOnePass
auto block_reduce2d_cross_warp_sync =
Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
const auto x = load_tile(x_window);
auto x = load_tile(x_window);
auto x_resi = load_tile(x_residual_window);
// load gamma (TODO: support no gamma?)
const auto gamma = load_tile(gamma_window);
auto acc = cast_tile<ComputeDataType>(x);
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD ||
kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
{
sweep_tile(x_resi, [&](auto idx) {
// compute x = x_resi + x
acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
});
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
{
store_tile(y_residual_window, cast_tile<YResidualDataType>(acc));
}
}
// compute mean square each-thread->cross-lane->cross-warp
auto square_sum = block_reduce2d(
x, reduce_square_sum_func.GetIdentityValue<ComputeDataType>(), reduce_square_sum_func);
auto square_sum = block_reduce2d(acc,
reduce_square_sum_func.GetIdentityValue<ComputeDataType>(),
reduce_square_sum_func);
block_reduce2d_sync(square_sum, reduce_sum_func);
block_reduce2d_cross_warp_sync(square_sum, smem, reduce_sum_func);
......@@ -83,19 +123,30 @@ struct Rmsnorm2dFwdPipelineOnePass
store_tile(inv_rms_window, cast_tile<InvRmsDataType>(inv_rms));
// rmsnorm computation
auto y = make_static_distributed_tensor<YDataType>(x.get_tile_distribution());
sweep_tile(y, [&, inv_rms_ = inv_rms](auto idx) {
auto rmsn = make_static_distributed_tensor<ComputeDataType>(x.get_tile_distribution());
sweep_tile(rmsn, [&, inv_rms_ = inv_rms](auto idx) {
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
const auto x_ = type_convert<ComputeDataType>(x[idx]);
auto y_ = x_ * inv_rms_[i_idx] * gamma_;
auto rmsn_ = acc[idx] * inv_rms_[i_idx] * gamma_;
y(idx) = type_convert<YDataType>(y_);
rmsn(idx) = rmsn_;
});
store_tile(y_window, y);
if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
{
Epilogue{}(y_window_, sm_scale_window_, y_scale_window_, rmsn, smem);
}
else if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::DYNAMIC_QUANT)
{
Epilogue{}(y_window_, y_scale_window_, rmsn, smem);
}
else
{
Epilogue{}(y_window_, rmsn);
}
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -12,25 +12,25 @@ template <typename XDataType_,
typename ComputeDataType_,
typename YDataType_,
typename InvRmsDataType_,
typename SmoothScaleDataType_,
typename YScaleDataType_,
typename BlockShape_,
bool kPadN_,
bool kSaveInvRms_,
bool kTwoPass_>
typename Traits_>
struct Rmsnorm2dFwdPipelineProblem
{
using XDataType = remove_cvref_t<XDataType_>;
using GammaDataType = remove_cvref_t<GammaDataType_>;
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
using YDataType = remove_cvref_t<YDataType_>;
using InvRmsDataType = remove_cvref_t<InvRmsDataType_>;
using BlockShape = remove_cvref_t<BlockShape_>;
using XDataType = remove_cvref_t<XDataType_>;
using GammaDataType = remove_cvref_t<GammaDataType_>;
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
using YDataType = remove_cvref_t<YDataType_>;
using InvRmsDataType = remove_cvref_t<InvRmsDataType_>;
using SmoothScaleDataType = remove_cvref_t<SmoothScaleDataType_>;
using YScaleDataType = remove_cvref_t<YScaleDataType_>;
using BlockShape = remove_cvref_t<BlockShape_>;
static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;
static constexpr bool kPadN = kPadN_;
static constexpr bool kSaveInvRms = kSaveInvRms_;
static constexpr bool kTwoPass = kTwoPass_;
using Traits = remove_cvref_t<Traits_>;
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -22,12 +22,17 @@ struct Rmsnorm2dFwdPipelineTwoPass
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
using InvRmsDataType = ck_tile::remove_cvref_t<typename Problem::InvRmsDataType>;
using XResidualDataType = XDataType;
using YResidualDataType = XDataType;
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>;
static constexpr bool kSaveInvRms = Problem::kSaveInvRms;
static constexpr bool kSaveInvRms = Problem::Traits::kSaveInvRms;
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
static constexpr bool kPadM = false; // TODO - BlockRmsnorm2dFwdProblem::kPadM
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kPadN = Problem::Traits::kPadN;
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
static constexpr const char* name = []() {
if constexpr(kNeedCrossWarpSync)
......@@ -41,19 +46,36 @@ struct Rmsnorm2dFwdPipelineTwoPass
return Policy::template GetSmemSize<Problem>();
}
template <typename XWindow, typename GammaWindow, typename YWindow, typename InvRmsWindow>
template <typename XWindow,
typename XResidualWindow,
typename GammaWindow,
typename YWindow,
typename YResidualWindow,
typename InvRmsWindow,
typename SmoothScaleWindow,
typename YScaleWindow,
typename Epilogue>
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
const XResidualWindow& x_residual_window_,
const GammaWindow& gamma_window_,
YWindow& y_window,
const YResidualWindow& y_residual_window_,
InvRmsWindow& inv_rms_window,
const SmoothScaleWindow& /*sm_scale_window_*/,
YScaleWindow& /*y_scale_window*/,
ComputeDataType epsilon,
ck_tile::index_t row_size,
void* smem) const
void* smem,
Epilogue) const
{
auto x_window =
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
auto gamma_window = make_tile_window(
gamma_window_, Policy::template MakeGammaBlockTileDistribution<Problem>());
auto x_residual_window = make_tile_window(
x_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
auto y_residual_window = make_tile_window(
y_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
// Problem::BlockShape
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
......@@ -67,15 +89,34 @@ struct Rmsnorm2dFwdPipelineTwoPass
auto block_reduce2d_cross_warp_sync =
Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
using XTensorType = decltype(load_tile(x_window));
auto square_sum = block_reduce2d.template MakeYBlockTile<XTensorType>();
using ComputeTensorType = decltype(cast_tile<ComputeDataType>(load_tile(x_window)));
auto square_sum = block_reduce2d.template MakeYBlockTile<ComputeTensorType>();
set_tile(square_sum, reduce_square_sum_func.GetIdentityValue<ComputeDataType>());
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{
const auto x = load_tile(x_window);
block_reduce2d(x, square_sum, reduce_square_sum_func);
auto x = load_tile(x_window);
auto x_resi = load_tile(x_residual_window);
move_tile_window(x_window, {0, Block_N});
move_tile_window(x_residual_window, {0, Block_N});
auto acc = cast_tile<ComputeDataType>(x);
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD ||
kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
{
sweep_tile(x_resi, [&](auto idx) {
// compute x = x_resi + x
acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
});
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
{
store_tile(y_residual_window, cast_tile<YResidualDataType>(acc));
move_tile_window(y_residual_window, {0, Block_N});
}
}
block_reduce2d(acc, square_sum, reduce_square_sum_func);
}
block_reduce2d_sync(square_sum, reduce_sum_func);
......@@ -96,33 +137,47 @@ struct Rmsnorm2dFwdPipelineTwoPass
row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N;
move_tile_window(x_window, {0, -Block_N});
move_tile_window(x_residual_window, {0, -Block_N});
move_tile_window(gamma_window, {stride_to_right_most_window});
move_tile_window(y_window, {0, stride_to_right_most_window});
// rmsnorm computation
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{
const auto x = load_tile(x_window);
// load gamma/beta (TODO: support no gamma/beta?)
auto x = load_tile(x_window);
auto x_resi = load_tile(x_residual_window);
auto acc = cast_tile<ComputeDataType>(x);
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE ||
kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD)
{
sweep_tile(x_resi, [&](auto idx) {
// compute x = x_resi + x
acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
});
}
// load gamma (TODO: support no gamma?)
const auto gamma = load_tile(gamma_window);
auto y = make_static_distributed_tensor<YDataType>(x.get_tile_distribution());
sweep_tile(y, [&, inv_rms_ = inv_rms](auto idx) {
// rmsnorm computation
auto rmsn = make_static_distributed_tensor<ComputeDataType>(x.get_tile_distribution());
sweep_tile(rmsn, [&, inv_rms_ = inv_rms](auto idx) {
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
const auto x_ = type_convert<ComputeDataType>(x[idx]);
auto y_ = x_ * inv_rms_[i_idx] * gamma_;
auto rmsn_ = acc(idx) * inv_rms_[i_idx] * gamma_;
y(idx) = type_convert<YDataType>(y_);
rmsn(idx) = rmsn_;
});
store_tile(y_window, y);
static_assert(kFusedQuant == Rmsnorm2dFusedQuantEnum::NO_SWEEP);
Epilogue{}(y_window, rmsn);
move_tile_window(x_window, {0, -Block_N});
move_tile_window(x_residual_window, {0, -Block_N});
move_tile_window(gamma_window, {-Block_N});
move_tile_window(y_window, {0, -Block_N});
}
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
enum class Rmsnorm2dFusedAddEnum
{
NO_ADD = 0,
// fused add before RMSNorm and store result to global
PRE_ADD_STORE = 1,
// fused add before RMSNorm, but not store result
PRE_ADD = 2,
};
// clang-format off
template<Rmsnorm2dFusedAddEnum> struct Rmsnorm2dFusedAddEnumName;
template<> struct Rmsnorm2dFusedAddEnumName<Rmsnorm2dFusedAddEnum::NO_ADD> { static constexpr const char * name = "no"; };
template<> struct Rmsnorm2dFusedAddEnumName<Rmsnorm2dFusedAddEnum::PRE_ADD_STORE> { static constexpr const char * name = "pras"; };
template<> struct Rmsnorm2dFusedAddEnumName<Rmsnorm2dFusedAddEnum::PRE_ADD> { static constexpr const char * name = "pra"; };
// clang-format on
enum class Rmsnorm2dFusedQuantEnum
{
NO_SWEEP = 0,
SMOOTH_DYNAMIC_QUANT = 1, // smooth oulier + rowwise quant, need input x-scale and store y_scale
DYNAMIC_QUANT = 2, // rowwise quant, store out a y-scale
};
// clang-format off
template<Rmsnorm2dFusedQuantEnum> struct Rmsnorm2dFusedQuantEnumName;
template<> struct Rmsnorm2dFusedQuantEnumName<Rmsnorm2dFusedQuantEnum::NO_SWEEP> { static constexpr const char * name = "no"; };
template<> struct Rmsnorm2dFusedQuantEnumName<Rmsnorm2dFusedQuantEnum::DYNAMIC_QUANT> { static constexpr const char * name = "dqt"; };
template<> struct Rmsnorm2dFusedQuantEnumName<Rmsnorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT> { static constexpr const char * name = "smdqt"; };
// clang-format on
template <bool kPadN_,
bool kSaveInvRms_,
bool kTwoPass_,
Rmsnorm2dFusedAddEnum kFusedAdd_,
Rmsnorm2dFusedQuantEnum kFusedQuant_>
struct Rmsnorm2dFwdTraits
{
static constexpr bool kPadN = kPadN_;
static constexpr bool kSaveInvRms = kSaveInvRms_;
static constexpr bool kTwoPass = kTwoPass_;
static constexpr Rmsnorm2dFusedAddEnum kFusedAdd = kFusedAdd_;
static constexpr Rmsnorm2dFusedQuantEnum kFusedQuant = kFusedQuant_;
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -12,7 +12,7 @@ namespace ck_tile {
struct MoeSmoothquantHostArgs
{
const void* p_x; // [tokens ,hidden_size], input, fp16/bf16
const void* p_xscale; // [experts, hidden_size], input, columnwise scale, fp32
const void* p_smscale; // [experts, hidden_size], input, columnwise scale, fp32
const void* p_topk_ids; // [tokens, topk]
void* p_yscale; // [topk * tokens, 1], output, rowwise quant scale
......@@ -33,11 +33,11 @@ struct MoeSmoothquant
using Pipeline = remove_cvref_t<Pipeline_>;
using Problem = typename Pipeline::Problem;
using XDataType = remove_cvref_t<typename Problem::XDataType>;
using XScaleDataType = remove_cvref_t<typename Problem::XScaleDataType>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>;
using QYDataType = remove_cvref_t<typename Problem::QYDataType>;
using XDataType = remove_cvref_t<typename Problem::XDataType>;
using SmoothScaleDataType = remove_cvref_t<typename Problem::SmoothScaleDataType>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>;
using QYDataType = remove_cvref_t<typename Problem::QYDataType>;
static constexpr index_t Block_M = Problem::BlockShape::Block_M;
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
......@@ -57,7 +57,7 @@ struct MoeSmoothquant
struct Kargs
{
const void* p_x; // [tokens ,hidden_size], input, fp16/bf16
const void* p_xscale; // [experts, hidden_size], input, columnwise scale, fp32
const void* p_smscale; // [experts, hidden_size], input, columnwise scale, fp32
const void* p_topk_ids; // [tokens, topk]
void* p_yscale; // [topk, tokens, 1], output, rowwise quant scale
......@@ -75,7 +75,7 @@ struct MoeSmoothquant
CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
{
return Kargs{hargs.p_x,
hargs.p_xscale,
hargs.p_smscale,
hargs.p_topk_ids,
hargs.p_yscale,
hargs.p_qy,
......@@ -101,6 +101,7 @@ struct MoeSmoothquant
template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
template <> struct t2s<ck_tile::int8_t> { static constexpr const char * name = "i8"; };
// clang-format on
// in byte
......@@ -118,7 +119,7 @@ struct MoeSmoothquant
#define _SS_ std::string
#define _TS_ std::to_string
return _SS_("moe_smoothquant_") + _SS_(t2s<XDataType>::name) + "_" +
return _SS_("moe_smoothquant_") + _SS_(t2s<XDataType>::name) + "_" + _SS_(t2s<QYDataType>::name) + "_" +
_TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" +
_TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" +
_SS_(Pipeline::name) + surfix;
......@@ -153,9 +154,10 @@ struct MoeSmoothquant
}();
// [experts, hidden_size],
const auto xscale_window = [&]() {
const auto smscale_window = [&]() {
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const XScaleDataType*>(kargs.p_xscale) + i_expert * kargs.hidden_size,
static_cast<const SmoothScaleDataType*>(kargs.p_smscale) +
i_expert * kargs.hidden_size,
make_tuple(kargs.hidden_size),
make_tuple(1),
number<Vector_N>{},
......@@ -198,7 +200,7 @@ struct MoeSmoothquant
__shared__ char smem[GetSmemSize()];
Pipeline{}(x_window, xscale_window, yscale_window, qy_window, kargs.hidden_size, smem);
Pipeline{}(x_window, smscale_window, yscale_window, qy_window, kargs.hidden_size, smem);
}
};
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -11,11 +11,11 @@ namespace ck_tile {
// host side args
struct SmoothquantHostArgs
{
const void* p_x; // [m ,n], input, fp16/bf16
const void* p_xscale; // [1, n], input, columnwise scale, fp32
const void* p_x; // [m ,n], input, fp16/bf16
const void* p_smscale; // [1, n], input, columnwise scale, fp32
void* p_yscale; // [m, 1], output, rowwise quant scale (amax / 127) of (p_x * p_xscale)
void* p_qy; // [m, n], output, p_x * p_xscale / p_yscale
void* p_yscale; // [m, 1], output, rowwise quant scale (amax / 127) of (p_x * p_smscale)
void* p_qy; // [m, n], output, p_x * p_smscale / p_yscale
index_t m;
index_t n;
......@@ -30,11 +30,11 @@ struct Smoothquant
using Pipeline = remove_cvref_t<Pipeline_>;
using Problem = typename Pipeline::Problem;
using XDataType = remove_cvref_t<typename Problem::XDataType>;
using XScaleDataType = remove_cvref_t<typename Problem::XScaleDataType>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>;
using QYDataType = remove_cvref_t<typename Problem::QYDataType>;
using XDataType = remove_cvref_t<typename Problem::XDataType>;
using SmoothScaleDataType = remove_cvref_t<typename Problem::SmoothScaleDataType>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>;
using QYDataType = remove_cvref_t<typename Problem::QYDataType>;
static constexpr index_t Block_M = Problem::BlockShape::Block_M;
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
......@@ -52,7 +52,7 @@ struct Smoothquant
struct Kargs
{
const void* p_x;
const void* p_xscale;
const void* p_smscale;
void* p_yscale;
void* p_qy;
......@@ -67,7 +67,7 @@ struct Smoothquant
CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
{
return Kargs{hargs.p_x,
hargs.p_xscale,
hargs.p_smscale,
hargs.p_yscale,
hargs.p_qy,
hargs.m,
......@@ -134,9 +134,9 @@ struct Smoothquant
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
}();
const auto xscale_window = [&]() {
const auto smscale_window = [&]() {
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const XScaleDataType*>(kargs.p_xscale),
static_cast<const SmoothScaleDataType*>(kargs.p_smscale),
make_tuple(kargs.n),
make_tuple(1),
number<Vector_N>{},
......@@ -177,7 +177,7 @@ struct Smoothquant
__shared__ char smem[GetSmemSize()];
Pipeline{}(x_window, xscale_window, yscale_window, qy_window, kargs.n, smem);
Pipeline{}(x_window, smscale_window, yscale_window, qy_window, kargs.n, smem);
}
};
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -28,7 +28,7 @@ struct SmoothquantPipelineDefaultPolicy
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeXScaleBlockTileDistribution()
CK_TILE_DEVICE static constexpr auto MakeSmoothScaleBlockTileDistribution()
{
using S = typename Problem::BlockShape;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -16,11 +16,11 @@ struct SmoothquantPipelineOnePass
using Problem = ck_tile::remove_cvref_t<Problem_>;
using Policy = ck_tile::remove_cvref_t<Policy_>;
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
using XScaleDataType = ck_tile::remove_cvref_t<typename Problem::XScaleDataType>;
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
using QYDataType = ck_tile::remove_cvref_t<typename Problem::QYDataType>;
using YScaleDataType = ck_tile::remove_cvref_t<typename Problem::YScaleDataType>;
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
using SmoothScaleDataType = ck_tile::remove_cvref_t<typename Problem::SmoothScaleDataType>;
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
using QYDataType = ck_tile::remove_cvref_t<typename Problem::QYDataType>;
using YScaleDataType = ck_tile::remove_cvref_t<typename Problem::YScaleDataType>;
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
static constexpr bool kPadM = false; // TODO - BlockSmoothquantProblem::kPadM
......@@ -39,9 +39,12 @@ struct SmoothquantPipelineOnePass
return Policy::template GetSmemSize<Problem>();
}
template <typename XWindow, typename XScaleWindow, typename QYWindow, typename YScaleWindow>
template <typename XWindow,
typename SmoothScaleWindow,
typename QYWindow,
typename YScaleWindow>
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
const XScaleWindow& xscale_window_,
const SmoothScaleWindow& smscale_window_,
YScaleWindow& yscale_window,
QYWindow& qy_window,
ck_tile::index_t,
......@@ -49,8 +52,8 @@ struct SmoothquantPipelineOnePass
{
auto x_window =
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
auto xscale_window = make_tile_window(
xscale_window_, Policy::template MakeXScaleBlockTileDistribution<Problem>());
auto smscale_window = make_tile_window(
smscale_window_, Policy::template MakeSmoothScaleBlockTileDistribution<Problem>());
auto reduce_absmax_func = ReduceOp::AbsMax{};
auto reduce_absmax3_func = [](auto acc_, auto v_0_, auto v_1_) {
......@@ -67,14 +70,14 @@ struct SmoothquantPipelineOnePass
auto block_reduce2d_cross_warp_sync =
Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
const auto x = load_tile(x_window);
const auto xscale = load_tile(xscale_window);
auto y = tile_elementwise_in(
const auto x = load_tile(x_window);
const auto smscale = load_tile(smscale_window);
auto y = tile_elementwise_in(
[&](const auto& a, const auto& b) {
return type_convert<ComputeDataType>(a) * type_convert<ComputeDataType>(b);
},
x,
xscale);
smscale);
// compute absmax, cross-lane->cross-warp
auto absmax = [&]() {
......@@ -110,7 +113,7 @@ struct SmoothquantPipelineOnePass
sweep_tile(qy, [&](auto idx) {
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
auto qy_ = y[idx] / yscale[i_idx];
qy(idx) = saturates<QYDataType>{}(qy_);
qy(idx) = type_convert<QYDataType>(saturates<QYDataType>{}(qy_));
});
store_tile(qy_window, qy);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment