Unverified Commit 04dd3148 authored by ruanjm's avatar ruanjm Committed by GitHub
Browse files

[CK_TILE] Add Various Fusion Functions to RMSNorm (#1802)



* Add shortcut to RMSNorm

* Modify test for adding shortcut for RMSNorm

* Add fused parameter into tests

* 1. Add YDataType. 2. rmsnorm2d_fwd_traits_ from rmsnorm2d_fwd.hpp to rmsnorm2d_fwd_api.cpp and rmsnorm2d_fwd_instance_common.hpp

* 1. Supports various stride and percisions.

* Add support of Epilogue

* Add fuse and epilogue support to rmsnorm ref

* Modify rmsnorm example

* Refactor tests/examples

* Bug fix for newly added tests/examples

* Bug fix for new tests 2

* Modify smoke test scripts

remove dbg code

* Supports non-smooth dyanmic quant

* Update Rmsnorm2dFwd::GetName()

* rename xscale and prec_sx to smoothscale and prec_sm

Bug fix after rename

Remove files

* change example_rmsnorm2d_fwd.cpp

* update performance calculator

* Fix issue in two-pass when fuse add is enabled

* Remove comment of beta

---------
Co-authored-by: default avatarrocking <ChunYu.Lai@amd.com>
parent c0b90f13
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -24,7 +24,7 @@ struct DynamicQuantEpilogueTraits ...@@ -24,7 +24,7 @@ struct DynamicQuantEpilogueTraits
// this epilogue just store out a M*N matrix, row major // this epilogue just store out a M*N matrix, row major
template <typename AccDataType_, template <typename AccDataType_,
typename XScaleDataType_, typename SmoothScaleDataType_,
typename YScaleDataType_, typename YScaleDataType_,
typename ODataType_, typename ODataType_,
typename BlockShape_, typename BlockShape_,
...@@ -32,7 +32,7 @@ template <typename AccDataType_, ...@@ -32,7 +32,7 @@ template <typename AccDataType_,
struct DynamicQuantEpilogueProblem struct DynamicQuantEpilogueProblem
{ {
using AccDataType = remove_cvref_t<AccDataType_>; using AccDataType = remove_cvref_t<AccDataType_>;
using XScaleDataType = remove_cvref_t<XScaleDataType_>; using SmoothScaleDataType = remove_cvref_t<SmoothScaleDataType_>;
using YScaleDataType = remove_cvref_t<YScaleDataType_>; using YScaleDataType = remove_cvref_t<YScaleDataType_>;
using ODataType = remove_cvref_t<ODataType_>; using ODataType = remove_cvref_t<ODataType_>;
using BlockShape = remove_cvref_t<BlockShape_>; // can consum generic 2d shape using BlockShape = remove_cvref_t<BlockShape_>; // can consum generic 2d shape
...@@ -45,7 +45,7 @@ struct DynamicQuantEpilogue ...@@ -45,7 +45,7 @@ struct DynamicQuantEpilogue
{ {
using Problem = remove_cvref_t<Problem_>; using Problem = remove_cvref_t<Problem_>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>; using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using XScaleDataType = remove_cvref_t<typename Problem::XScaleDataType>; using SmoothScaleDataType = remove_cvref_t<typename Problem::SmoothScaleDataType>;
using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>; using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>; using ODataType = remove_cvref_t<typename Problem::ODataType>;
using BlockShape = remove_cvref_t<typename Problem::BlockShape>; using BlockShape = remove_cvref_t<typename Problem::BlockShape>;
...@@ -78,7 +78,7 @@ struct DynamicQuantEpilogue ...@@ -78,7 +78,7 @@ struct DynamicQuantEpilogue
#if 0 #if 0
// don't remove this // don't remove this
// Note that if we set encoding purposely like this, you will result in compile fail // Note that if we set encoding purposely like this, you will result in compile fail
// TODO: x_scale create local-scratch to accept arbitrary acc input (with same length) // TODO: sm_scale create local-scratch to accept arbitrary acc input (with same length)
return make_static_tile_distribution( return make_static_tile_distribution(
tile_distribution_encoding< tile_distribution_encoding<
sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M>, sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M>,
...@@ -105,14 +105,8 @@ struct DynamicQuantEpilogue ...@@ -105,14 +105,8 @@ struct DynamicQuantEpilogue
return reduce_crosswarp_sync.GetSmemSize(); return reduce_crosswarp_sync.GetSmemSize();
} }
// TODO: this function assume store out vector size is the same as OAccTile last dimension size template <typename ODramWindowTmp, typename YScaleWindow, typename OAccTile>
// how do we fix this ? CK_TILE_DEVICE auto Impl(ODramWindowTmp& o_dram_window_tmp,
template <typename ODramWindowTmp,
typename XScaleWindow,
typename YScaleWindow,
typename OAccTile>
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
const XScaleWindow& x_scale_window_,
YScaleWindow& y_scale_window, YScaleWindow& y_scale_window,
const OAccTile& o_acc_tile, const OAccTile& o_acc_tile,
void* smem) void* smem)
...@@ -120,19 +114,9 @@ struct DynamicQuantEpilogue ...@@ -120,19 +114,9 @@ struct DynamicQuantEpilogue
auto reduce = GetBlockReduce2d(); auto reduce = GetBlockReduce2d();
auto reduce_sync = GetBlockReduce2dSync(); auto reduce_sync = GetBlockReduce2dSync();
auto reduce_crosswarp_sync = GetBlockReduce2dCrossWarpSync(); auto reduce_crosswarp_sync = GetBlockReduce2dCrossWarpSync();
const auto x_scale_window =
make_tile_window(x_scale_window_, MakeSmoothInputScaleTileDistribution());
auto x_scale = load_tile(x_scale_window);
auto o_acc_tmp = o_acc_tile; auto o_acc_tmp = o_acc_tile;
sweep_tile(o_acc_tmp, [&](auto idx) {
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
const auto xs_ = type_convert<AccDataType>(x_scale[j_idx]);
o_acc_tmp(idx) = o_acc_tmp(idx) * xs_;
});
const auto f_absmax = [](auto acc_, auto v_0_) { return max(acc_, abs(v_0_)); }; const auto f_absmax = [](auto acc_, auto v_0_) { return max(acc_, abs(v_0_)); };
auto row_absmax = [&]() { auto row_absmax = [&]() {
...@@ -184,5 +168,45 @@ struct DynamicQuantEpilogue ...@@ -184,5 +168,45 @@ struct DynamicQuantEpilogue
store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tmp)); store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tmp));
} }
} }
// TODO: this function assume store out vector size is the same as OAccTile last dimension size
// how do we fix this ?
// Smooth Dynamic Quant
template <typename ODramWindowTmp,
typename SmoothScaleWindow,
typename YScaleWindow,
typename OAccTile>
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
const SmoothScaleWindow& sm_scale_window_,
YScaleWindow& y_scale_window,
const OAccTile& o_acc_tile,
void* smem)
{
const auto sm_scale_window =
make_tile_window(sm_scale_window_, MakeSmoothInputScaleTileDistribution());
auto sm_scale = load_tile(sm_scale_window);
auto o_acc_tmp = o_acc_tile;
sweep_tile(o_acc_tmp, [&](auto idx) {
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
const auto xs_ = type_convert<AccDataType>(sm_scale[j_idx]);
o_acc_tmp(idx) = o_acc_tmp(idx) * xs_;
});
Impl(o_dram_window_tmp, y_scale_window, o_acc_tmp, smem);
}
// Dynamic Quant
template <typename ODramWindowTmp, typename YScaleWindow, typename OAccTile>
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
YScaleWindow& y_scale_window,
const OAccTile& o_acc_tile,
void* smem)
{
Impl(o_dram_window_tmp, y_scale_window, o_acc_tile, smem);
}
}; };
} // namespace ck_tile } // namespace ck_tile
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -14,7 +14,7 @@ struct Layernorm2dFwdHostArgs ...@@ -14,7 +14,7 @@ struct Layernorm2dFwdHostArgs
{ {
const void* p_x; // [m ,n], input, fp16/bf16 const void* p_x; // [m ,n], input, fp16/bf16
const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used
const void* p_x_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used const void* p_sm_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used
const void* p_x_bias; // [1, n], bias, prec same as input const void* p_x_bias; // [1, n], bias, prec same as input
const void* p_gamma; // [1, n], gamma, prec same as input const void* p_gamma; // [1, n], gamma, prec same as input
const void* p_beta; // [1, n], beta, prec same as input const void* p_beta; // [1, n], beta, prec same as input
...@@ -51,7 +51,7 @@ struct Layernorm2dFwd ...@@ -51,7 +51,7 @@ struct Layernorm2dFwd
using YDataType = remove_cvref_t<typename Problem::YDataType>; using YDataType = remove_cvref_t<typename Problem::YDataType>;
using MeanDataType = remove_cvref_t<typename Problem::MeanDataType>; using MeanDataType = remove_cvref_t<typename Problem::MeanDataType>;
using InvStdDataType = remove_cvref_t<typename Problem::InvStdDataType>; using InvStdDataType = remove_cvref_t<typename Problem::InvStdDataType>;
using XScaleDataType = remove_cvref_t<typename Problem::XScaleDataType>; using SmoothScaleDataType = remove_cvref_t<typename Problem::SmoothScaleDataType>;
using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>; using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>;
// for simplicity, shortcut input/output type is same as X // for simplicity, shortcut input/output type is same as X
...@@ -84,7 +84,7 @@ struct Layernorm2dFwd ...@@ -84,7 +84,7 @@ struct Layernorm2dFwd
{ {
const void* p_x; // [m ,n], input, fp16/bf16 const void* p_x; // [m ,n], input, fp16/bf16
const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used
const void* p_x_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used const void* p_sm_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used
const void* p_x_bias; // [1, n], bias, prec same as input const void* p_x_bias; // [1, n], bias, prec same as input
const void* p_gamma; // [1, n], gamma, prec same as input const void* p_gamma; // [1, n], gamma, prec same as input
const void* p_beta; // [1, n], beta, prec same as input const void* p_beta; // [1, n], beta, prec same as input
...@@ -111,7 +111,7 @@ struct Layernorm2dFwd ...@@ -111,7 +111,7 @@ struct Layernorm2dFwd
{ {
return Kargs{hargs.p_x, return Kargs{hargs.p_x,
hargs.p_x_residual, hargs.p_x_residual,
hargs.p_x_scale, hargs.p_sm_scale,
hargs.p_x_bias, hargs.p_x_bias,
hargs.p_gamma, hargs.p_gamma,
hargs.p_beta, hargs.p_beta,
...@@ -171,7 +171,7 @@ struct Layernorm2dFwd ...@@ -171,7 +171,7 @@ struct Layernorm2dFwd
base_str += _SS_("_") + _SS_(t2s<YDataType>::name); base_str += _SS_("_") + _SS_(t2s<YDataType>::name);
} }
if (kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT) { if (kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT) {
base_str += _SS_("_sx") + _SS_(t2s<XScaleDataType>::name); base_str += _SS_("_sx") + _SS_(t2s<SmoothScaleDataType>::name);
base_str += _SS_("_sy") + _SS_(t2s<YScaleDataType>::name); base_str += _SS_("_sy") + _SS_(t2s<YScaleDataType>::name);
} }
if (kFusedQuant == Layernorm2dFusedQuantEnum::DYNAMIC_QUANT) { if (kFusedQuant == Layernorm2dFusedQuantEnum::DYNAMIC_QUANT) {
...@@ -356,18 +356,18 @@ struct Layernorm2dFwd ...@@ -356,18 +356,18 @@ struct Layernorm2dFwd
return make_null_tile_window(make_tuple(number<Block_M>{})); return make_null_tile_window(make_tuple(number<Block_M>{}));
}(); }();
auto x_scale_window = [&]() { auto sm_scale_window = [&]() {
if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT) if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
{ {
const auto win_ = [&]() { const auto win_ = [&]() {
const auto tmp_0_ = make_naive_tensor_view_packed<address_space_enum::global>( const auto tmp_0_ = make_naive_tensor_view_packed<address_space_enum::global>(
static_cast<const XScaleDataType*>(kargs.p_x_scale), static_cast<const SmoothScaleDataType*>(kargs.p_sm_scale),
make_tuple(kargs.n), make_tuple(kargs.n),
number<Vector_N>{}); number<Vector_N>{});
return pad_tensor_view(tmp_0_, return pad_tensor_view(tmp_0_,
make_tuple(number<Block_N>{}), make_tuple(number<Block_N>{}),
sequence<false>{}); // x_scale no need pad sequence<false>{}); // sm_scale no need pad
}(); }();
return make_tile_window(win_, make_tuple(number<Block_N>{}), {0}); return make_tile_window(win_, make_tuple(number<Block_N>{}), {0});
} }
...@@ -405,7 +405,7 @@ struct Layernorm2dFwd ...@@ -405,7 +405,7 @@ struct Layernorm2dFwd
y_residual_window, y_residual_window,
mean_window, mean_window,
inv_std_window, inv_std_window,
x_scale_window, sm_scale_window,
y_scale_window, y_scale_window,
static_cast<const ComputeDataType>(kargs.epsilon), static_cast<const ComputeDataType>(kargs.epsilon),
kargs.n, kargs.n,
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -64,7 +64,7 @@ struct Layernorm2dFwdPipelineOnePass ...@@ -64,7 +64,7 @@ struct Layernorm2dFwdPipelineOnePass
typename YResidualWindow, typename YResidualWindow,
typename MeanWindow, typename MeanWindow,
typename InvStdWindow, typename InvStdWindow,
typename XScaleWindow, typename SmoothScaleWindow,
typename YScaleWindow, typename YScaleWindow,
typename Epilogue> typename Epilogue>
CK_TILE_DEVICE auto operator()(const XWindow& x_window_, CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
...@@ -76,7 +76,7 @@ struct Layernorm2dFwdPipelineOnePass ...@@ -76,7 +76,7 @@ struct Layernorm2dFwdPipelineOnePass
const YResidualWindow& y_residual_window_, const YResidualWindow& y_residual_window_,
MeanWindow& mean_window, MeanWindow& mean_window,
InvStdWindow& inv_std_window, InvStdWindow& inv_std_window,
const XScaleWindow& x_scale_window_, const SmoothScaleWindow& sm_scale_window_,
YScaleWindow& y_scale_window, YScaleWindow& y_scale_window,
ComputeDataType epsilon, ComputeDataType epsilon,
ck_tile::index_t row_size, ck_tile::index_t row_size,
...@@ -190,7 +190,7 @@ struct Layernorm2dFwdPipelineOnePass ...@@ -190,7 +190,7 @@ struct Layernorm2dFwdPipelineOnePass
if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::DYNAMIC_QUANT || if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::DYNAMIC_QUANT ||
kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT) kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
{ {
Epilogue{}(y_window_, x_scale_window_, y_scale_window, ln, smem); Epilogue{}(y_window_, sm_scale_window_, y_scale_window, ln, smem);
} }
else else
Epilogue{}(y_window_, ln); Epilogue{}(y_window_, ln);
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -15,7 +15,7 @@ template <typename XDataType_, ...@@ -15,7 +15,7 @@ template <typename XDataType_,
typename YDataType_, typename YDataType_,
typename MeanDataType_, typename MeanDataType_,
typename InvStdDataType_, typename InvStdDataType_,
typename XScaleDataType_, typename SmoothScaleDataType_,
typename YScaleDataType_, typename YScaleDataType_,
typename BlockShape_, typename BlockShape_,
typename Traits_> typename Traits_>
...@@ -29,7 +29,7 @@ struct Layernorm2dFwdPipelineProblem ...@@ -29,7 +29,7 @@ struct Layernorm2dFwdPipelineProblem
using YDataType = remove_cvref_t<YDataType_>; using YDataType = remove_cvref_t<YDataType_>;
using MeanDataType = remove_cvref_t<MeanDataType_>; using MeanDataType = remove_cvref_t<MeanDataType_>;
using InvStdDataType = remove_cvref_t<InvStdDataType_>; using InvStdDataType = remove_cvref_t<InvStdDataType_>;
using XScaleDataType = remove_cvref_t<XScaleDataType_>; using SmoothScaleDataType = remove_cvref_t<SmoothScaleDataType_>;
using YScaleDataType = remove_cvref_t<YScaleDataType_>; using YScaleDataType = remove_cvref_t<YScaleDataType_>;
using BlockShape = remove_cvref_t<BlockShape_>; using BlockShape = remove_cvref_t<BlockShape_>;
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -63,7 +63,7 @@ struct Layernorm2dFwdPipelineTwoPass ...@@ -63,7 +63,7 @@ struct Layernorm2dFwdPipelineTwoPass
typename YResidualWindow, typename YResidualWindow,
typename MeanWindow, typename MeanWindow,
typename InvStdWindow, typename InvStdWindow,
typename XScaleWindow, typename SmoothScaleWindow,
typename YScaleWindow, typename YScaleWindow,
typename Epilogue> typename Epilogue>
CK_TILE_DEVICE auto operator()(const XWindow& x_window_, CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
...@@ -75,7 +75,7 @@ struct Layernorm2dFwdPipelineTwoPass ...@@ -75,7 +75,7 @@ struct Layernorm2dFwdPipelineTwoPass
const YResidualWindow& y_residual_window_, const YResidualWindow& y_residual_window_,
MeanWindow& mean_window, MeanWindow& mean_window,
InvStdWindow& inv_std_window, InvStdWindow& inv_std_window,
const XScaleWindow& /*x_scale_window*/, const SmoothScaleWindow& /*sm_scale_window*/,
YScaleWindow& /*y_scale_window*/, YScaleWindow& /*y_scale_window*/,
ComputeDataType epsilon, ComputeDataType epsilon,
ck_tile::index_t row_size, ck_tile::index_t row_size,
......
...@@ -8,5 +8,6 @@ ...@@ -8,5 +8,6 @@
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp" #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp" #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp" #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/tensor_layout.hpp"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp" #include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp"
namespace ck_tile { namespace ck_tile {
...@@ -12,23 +13,31 @@ namespace ck_tile { ...@@ -12,23 +13,31 @@ namespace ck_tile {
struct Rmsnorm2dFwdHostArgs struct Rmsnorm2dFwdHostArgs
{ {
const void* p_x; // [m ,n], input, fp16/bf16 const void* p_x; // [m ,n], input, fp16/bf16
const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used
const void* p_sm_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used
const void* p_gamma; // [1, n], gamma, prec same as input const void* p_gamma; // [1, n], gamma, prec same as input
void* p_y; // [m, n], output, fp16/bf16 void* p_y; // [m, n], output, fp16/bf16
void* p_y_residual; // [m, n], shortcut output, prec same as input, nullptr if not used
void* p_y_scale; // [m, 1], output a dynamic quant per row, nullptr if not used
void* p_invRms; // [m, 1], output inv-rms, prec same as input, nullptr if not used void* p_invRms; // [m, 1], output inv-rms, prec same as input, nullptr if not used
float epsilon; float epsilon;
index_t m; index_t m;
index_t n; index_t n;
index_t stride; // row_stride index_t x_stride; // x row_stride
index_t xr_stride; // x residule row stride
index_t y_stride; // y row stride
index_t yr_stride; // y residule row stride
}; };
// TODO: Extract some type to wrapper class // TODO: Extract some type to wrapper class
template <typename Pipeline_> template <typename Pipeline_, typename Epilogue_>
struct Rmsnorm2dFwd struct Rmsnorm2dFwd
{ {
using Pipeline = remove_cvref_t<Pipeline_>; using Pipeline = remove_cvref_t<Pipeline_>;
using Epilogue = remove_cvref_t<Epilogue_>;
using Problem = typename Pipeline::Problem; using Problem = typename Pipeline::Problem;
using XDataType = remove_cvref_t<typename Problem::XDataType>; using XDataType = remove_cvref_t<typename Problem::XDataType>;
...@@ -36,15 +45,23 @@ struct Rmsnorm2dFwd ...@@ -36,15 +45,23 @@ struct Rmsnorm2dFwd
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>; using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using YDataType = remove_cvref_t<typename Problem::YDataType>; using YDataType = remove_cvref_t<typename Problem::YDataType>;
using InvRmsDataType = remove_cvref_t<typename Problem::InvRmsDataType>; using InvRmsDataType = remove_cvref_t<typename Problem::InvRmsDataType>;
using SmoothScaleDataType = remove_cvref_t<typename Problem::SmoothScaleDataType>;
using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>;
// for simplicity, shortcut input/output type is same as X
using XResidualDataType = XDataType;
using YResidualDataType = XDataType;
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, null_type>; static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, null_type>;
static constexpr bool kSaveInvRms = Problem::kSaveInvRms; static constexpr bool kSaveInvRms = Problem::Traits::kSaveInvRms;
static constexpr index_t Block_M = Problem::BlockShape::Block_M; static constexpr index_t Block_M = Problem::BlockShape::Block_M;
static constexpr index_t Block_N = Problem::BlockShape::Block_N; static constexpr index_t Block_N = Problem::BlockShape::Block_N;
static constexpr bool kPadM = false; // always no need to pad along M static constexpr bool kPadM = false; // always no need to pad along M
static constexpr bool kPadN = Problem::kPadN; static constexpr bool kPadN = Problem::Traits::kPadN;
static constexpr bool kTwoPass = Problem::kTwoPass; static constexpr bool kTwoPass = Problem::Traits::kTwoPass;
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N; static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
static constexpr index_t Vector_N = Problem::BlockShape::Vector_N; static constexpr index_t Vector_N = Problem::BlockShape::Vector_N;
...@@ -56,29 +73,43 @@ struct Rmsnorm2dFwd ...@@ -56,29 +73,43 @@ struct Rmsnorm2dFwd
struct Kargs struct Kargs
{ {
const void* p_x; const void* p_x;
const void* p_x_residual;
const void* p_sm_scale;
const void* p_gamma; const void* p_gamma;
void* p_y; void* p_y;
void* p_y_residual;
void* p_y_scale;
void* p_invRms; void* p_invRms;
float epsilon; float epsilon;
index_t m; index_t m;
index_t n; index_t n;
index_t stride; // row_stride index_t x_stride; // x row_stride
index_t xr_stride; // x residule row stride
index_t y_stride; // y row stride
index_t yr_stride; // y residule row stride
}; };
using Hargs = Rmsnorm2dFwdHostArgs; using Hargs = Rmsnorm2dFwdHostArgs;
CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs) CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
{ {
return Kargs{hargs.p_x, return Kargs{hargs.p_x,
hargs.p_x_residual,
hargs.p_sm_scale,
hargs.p_gamma, hargs.p_gamma,
hargs.p_y, hargs.p_y,
hargs.p_y_residual,
hargs.p_y_scale,
hargs.p_invRms, hargs.p_invRms,
hargs.epsilon, hargs.epsilon,
hargs.m, hargs.m,
hargs.n, hargs.n,
hargs.stride}; hargs.x_stride,
hargs.xr_stride,
hargs.y_stride,
hargs.yr_stride};
} }
CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs) CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
...@@ -95,6 +126,7 @@ struct Rmsnorm2dFwd ...@@ -95,6 +126,7 @@ struct Rmsnorm2dFwd
template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; }; template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; }; template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; }; template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
template <> struct t2s<ck_tile::int8_t> { static constexpr const char * name = "int8"; };
// clang-format on // clang-format on
// in byte // in byte
...@@ -102,24 +134,41 @@ struct Rmsnorm2dFwd ...@@ -102,24 +134,41 @@ struct Rmsnorm2dFwd
CK_TILE_HOST static std::string GetName() CK_TILE_HOST static std::string GetName()
{ {
#define _SS_ std::string
#define _TS_ std::to_string
// clang-format off // clang-format off
using S_ = typename Problem::BlockShape; using S_ = typename Problem::BlockShape;
auto surfix = [&] () { auto surfix = [&] () {
std::string n; std::string n;
if (kFusedAdd != Rmsnorm2dFusedAddEnum::NO_ADD) n += _SS_("_") + Rmsnorm2dFusedAddEnumName<kFusedAdd>::name;
if (kFusedQuant != Rmsnorm2dFusedQuantEnum::NO_SWEEP) n += _SS_("_") + Rmsnorm2dFusedQuantEnumName<kFusedQuant>::name;
if (kPadN) n += "_pn"; if (kPadN) n += "_pn";
if (kSaveInvRms) n += "_rms"; if (kSaveInvRms) n += "_rms";
if (kTwoPass) n += "_2p"; if (kTwoPass) n += "_2p";
return n; }(); return n; }();
#define _SS_ std::string auto prec_str = [&] () {
#define _TS_ std::to_string std::string base_str = _SS_(t2s<XDataType>::name);
return _SS_("rmsnorm2d_fwd_") + _SS_(t2s<XDataType>::name) + "_" + if (!std::is_same_v<XDataType, YDataType>) {
base_str += _SS_("_") + _SS_(t2s<YDataType>::name);
}
if (kFusedQuant == Rmsnorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT) {
base_str += _SS_("_sx") + _SS_(t2s<SmoothScaleDataType>::name);
base_str += _SS_("_sy") + _SS_(t2s<YScaleDataType>::name);
}
if (kFusedQuant == Rmsnorm2dFusedQuantEnum::DYNAMIC_QUANT) {
base_str += _SS_("_sy") + _SS_(t2s<YScaleDataType>::name);
}
return base_str;
}();
return _SS_("rmsnorm2d_fwd_") + _SS_(prec_str) + "_" +
_TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" + _TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" +
_TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" + _TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" +
_SS_(Pipeline::name) + surfix; _SS_(Pipeline::name) + surfix;
#undef _SS_
#undef _TS_
// clang-format on // clang-format on
#undef _SS_
#undef _TS_
} }
CK_TILE_DEVICE void operator()(Kargs kargs) const CK_TILE_DEVICE void operator()(Kargs kargs) const
...@@ -130,7 +179,7 @@ struct Rmsnorm2dFwd ...@@ -130,7 +179,7 @@ struct Rmsnorm2dFwd
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>( const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const XDataType*>(kargs.p_x), static_cast<const XDataType*>(kargs.p_x),
make_tuple(kargs.m, kargs.n), make_tuple(kargs.m, kargs.n),
make_tuple(kargs.stride, 1), make_tuple(kargs.x_stride, 1),
number<Vector_N>{}, number<Vector_N>{},
number<1>{}); number<1>{});
...@@ -140,6 +189,29 @@ struct Rmsnorm2dFwd ...@@ -140,6 +189,29 @@ struct Rmsnorm2dFwd
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0}); tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
}(); }();
const auto x_residual_window = [&]() {
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD ||
kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
{
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const XResidualDataType*>(kargs.p_x_residual),
make_tuple(kargs.m, kargs.n),
make_tuple(kargs.xr_stride, 1),
number<Vector_N>{},
number<1>{});
const auto tmp2_ = pad_tensor_view(tmp_,
make_tuple(number<Block_M>{}, number<Block_N>{}),
sequence<kPadM, kPadN>{});
return make_tile_window(
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
}
else
{
return make_null_tile_window(make_tuple(number<Block_M>{}, number<Block_N>{}));
}
}();
const auto gamma_window = [&]() { const auto gamma_window = [&]() {
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>( const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const GammaDataType*>(kargs.p_gamma), static_cast<const GammaDataType*>(kargs.p_gamma),
...@@ -158,7 +230,7 @@ struct Rmsnorm2dFwd ...@@ -158,7 +230,7 @@ struct Rmsnorm2dFwd
auto tmp_ = make_naive_tensor_view<address_space_enum::global>( auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<YDataType*>(kargs.p_y), static_cast<YDataType*>(kargs.p_y),
make_tuple(kargs.m, kargs.n), make_tuple(kargs.m, kargs.n),
make_tuple(kargs.stride, 1), make_tuple(kargs.y_stride, 1),
number<Vector_N>{}, number<Vector_N>{},
number<1>{}); number<1>{});
...@@ -168,6 +240,28 @@ struct Rmsnorm2dFwd ...@@ -168,6 +240,28 @@ struct Rmsnorm2dFwd
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0}); tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
}(); }();
auto y_residual_window = [&]() {
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
{
auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<YResidualDataType*>(kargs.p_y_residual),
make_tuple(kargs.m, kargs.n),
make_tuple(kargs.yr_stride, 1),
number<Vector_N>{},
number<1>{});
auto tmp2_ = pad_tensor_view(tmp_,
make_tuple(number<Block_M>{}, number<Block_N>{}),
sequence<kPadM, kPadN>{});
return make_tile_window(
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
}
else
{
return make_null_tile_window(make_tuple(number<Block_M>{}, number<Block_N>{}));
}
}();
auto inv_rms_window = [&]() { auto inv_rms_window = [&]() {
if constexpr(kSaveInvRms) if constexpr(kSaveInvRms)
{ {
...@@ -187,15 +281,62 @@ struct Rmsnorm2dFwd ...@@ -187,15 +281,62 @@ struct Rmsnorm2dFwd
return make_null_tile_window(make_tuple(number<Block_M>{})); return make_null_tile_window(make_tuple(number<Block_M>{}));
}(); }();
auto sm_scale_window = [&]() {
if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
{
const auto win_ = [&]() {
const auto tmp_0_ = make_naive_tensor_view_packed<address_space_enum::global>(
static_cast<const SmoothScaleDataType*>(kargs.p_sm_scale),
make_tuple(kargs.n),
number<Vector_N>{});
return pad_tensor_view(tmp_0_,
make_tuple(number<Block_N>{}),
sequence<false>{}); // sm_scale no need pad
}();
return make_tile_window(win_, make_tuple(number<Block_N>{}), {0});
}
else
{
return make_null_tile_window(make_tuple(number<Block_N>{}));
}
}();
auto y_scale_window = [&]() {
if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT ||
kFusedQuant == Rmsnorm2dFusedQuantEnum::DYNAMIC_QUANT)
{
const auto win_ = [&]() {
const auto tmp_0_ = make_naive_tensor_view_packed<address_space_enum::global>(
static_cast<YScaleDataType*>(kargs.p_y_scale),
make_tuple(kargs.m),
number<1>{});
return pad_tensor_view(
tmp_0_, make_tuple(number<Block_M>{}), sequence<kPadM>{});
}();
return make_tile_window(win_, make_tuple(number<Block_M>{}), {iM});
}
else
{
return make_null_tile_window(make_tuple(number<Block_M>{}));
}
}();
__shared__ char smem[GetSmemSize()]; __shared__ char smem[GetSmemSize()];
Pipeline{}(x_window, Pipeline{}(x_window,
x_residual_window,
gamma_window, gamma_window,
y_window, y_window,
y_residual_window,
inv_rms_window, inv_rms_window,
sm_scale_window,
y_scale_window,
static_cast<const ComputeDataType>(kargs.epsilon), static_cast<const ComputeDataType>(kargs.epsilon),
kargs.n, kargs.n,
smem); smem,
Epilogue{});
} }
}; };
......
...@@ -45,7 +45,7 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy ...@@ -45,7 +45,7 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2d() CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2d()
{ {
using P_ = BlockReduce2dProblem<typename Problem::XDataType, using P_ = BlockReduce2dProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType, typename Problem::ComputeDataType,
typename Problem::BlockShape>; typename Problem::BlockShape>;
return BlockReduce2d<P_>{}; return BlockReduce2d<P_>{};
...@@ -54,7 +54,7 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy ...@@ -54,7 +54,7 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dSync() CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dSync()
{ {
using P_ = BlockReduce2dProblem<typename Problem::XDataType, using P_ = BlockReduce2dProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType, typename Problem::ComputeDataType,
typename Problem::BlockShape>; typename Problem::BlockShape>;
return BlockReduce2dSync<P_>{}; return BlockReduce2dSync<P_>{};
...@@ -63,7 +63,7 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy ...@@ -63,7 +63,7 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dCrossWarpSync() CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dCrossWarpSync()
{ {
using P_ = BlockReduce2dProblem<typename Problem::XDataType, using P_ = BlockReduce2dProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType, typename Problem::ComputeDataType,
typename Problem::BlockShape>; typename Problem::BlockShape>;
return BlockReduce2dCrossWarpSync<P_>{}; return BlockReduce2dCrossWarpSync<P_>{};
...@@ -74,13 +74,13 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy ...@@ -74,13 +74,13 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy
{ {
if constexpr(Problem::kNeedCrossWarpSync) if constexpr(Problem::kNeedCrossWarpSync)
{ {
using P_ = BlockReduce2dProblem<typename Problem::XDataType, using P_ = BlockReduce2dProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType, typename Problem::ComputeDataType,
typename Problem::BlockShape>; typename Problem::BlockShape>;
using block_reduce2d = BlockReduce2d<P_>; using block_reduce2d = BlockReduce2d<P_>;
using x_block_tile = using x_block_tile =
decltype(make_static_distributed_tensor<typename Problem::XDataType>( decltype(make_static_distributed_tensor<typename Problem::ComputeDataType>(
MakeXBlockTileDistribution<Problem>())); MakeXBlockTileDistribution<Problem>()));
using y_block_tile = decltype(block_reduce2d::template MakeYBlockTile<x_block_tile>()); using y_block_tile = decltype(block_reduce2d::template MakeYBlockTile<x_block_tile>());
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -22,12 +22,17 @@ struct Rmsnorm2dFwdPipelineOnePass ...@@ -22,12 +22,17 @@ struct Rmsnorm2dFwdPipelineOnePass
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>; using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
using InvRmsDataType = ck_tile::remove_cvref_t<typename Problem::InvRmsDataType>; using InvRmsDataType = ck_tile::remove_cvref_t<typename Problem::InvRmsDataType>;
using XResidualDataType = XDataType;
using YResidualDataType = XDataType;
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>; static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>;
static constexpr bool kSaveInvRms = Problem::kSaveInvRms; static constexpr bool kSaveInvRms = Problem::Traits::kSaveInvRms;
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync; static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
static constexpr bool kPadM = false; // TODO - BlockRmsnorm2dFwdProblem::kPadM static constexpr bool kPadM = false; // TODO - BlockRmsnorm2dFwdProblem::kPadM
static constexpr bool kPadN = Problem::kPadN; static constexpr bool kPadN = Problem::Traits::kPadN;
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
static constexpr const char* name = []() { static constexpr const char* name = []() {
if constexpr(kNeedCrossWarpSync) if constexpr(kNeedCrossWarpSync)
...@@ -41,19 +46,36 @@ struct Rmsnorm2dFwdPipelineOnePass ...@@ -41,19 +46,36 @@ struct Rmsnorm2dFwdPipelineOnePass
return Policy::template GetSmemSize<Problem>(); return Policy::template GetSmemSize<Problem>();
} }
template <typename XWindow, typename GammaWindow, typename YWindow, typename InvRmsWindow> template <typename XWindow,
typename XResidualWindow,
typename GammaWindow,
typename YWindow,
typename YResidualWindow,
typename InvRmsWindow,
typename SmoothScaleWindow,
typename YScaleWindow,
typename Epilogue>
CK_TILE_DEVICE auto operator()(const XWindow& x_window_, CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
const XResidualWindow& x_residual_window_,
const GammaWindow& gamma_window_, const GammaWindow& gamma_window_,
YWindow& y_window, YWindow& y_window_,
const YResidualWindow& y_residual_window_,
InvRmsWindow& inv_rms_window, InvRmsWindow& inv_rms_window,
const SmoothScaleWindow& sm_scale_window_,
YScaleWindow& y_scale_window_,
ComputeDataType epsilon, ComputeDataType epsilon,
ck_tile::index_t row_size, ck_tile::index_t row_size,
void* smem) const void* smem,
Epilogue) const
{ {
const auto x_window = const auto x_window =
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>()); make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
const auto gamma_window = make_tile_window( const auto gamma_window = make_tile_window(
gamma_window_, Policy::template MakeGammaBlockTileDistribution<Problem>()); gamma_window_, Policy::template MakeGammaBlockTileDistribution<Problem>());
const auto x_residual_window = make_tile_window(
x_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
auto y_residual_window = make_tile_window(
y_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
auto reduce_square_sum_func = ReduceOp::SquareAdd{}; auto reduce_square_sum_func = ReduceOp::SquareAdd{};
auto reduce_sum_func = ReduceOp::Add{}; auto reduce_sum_func = ReduceOp::Add{};
...@@ -62,13 +84,31 @@ struct Rmsnorm2dFwdPipelineOnePass ...@@ -62,13 +84,31 @@ struct Rmsnorm2dFwdPipelineOnePass
auto block_reduce2d_cross_warp_sync = auto block_reduce2d_cross_warp_sync =
Policy::template GetBlockReduce2dCrossWarpSync<Problem>(); Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
const auto x = load_tile(x_window); auto x = load_tile(x_window);
auto x_resi = load_tile(x_residual_window);
// load gamma (TODO: support no gamma?) // load gamma (TODO: support no gamma?)
const auto gamma = load_tile(gamma_window); const auto gamma = load_tile(gamma_window);
auto acc = cast_tile<ComputeDataType>(x);
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD ||
kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
{
sweep_tile(x_resi, [&](auto idx) {
// compute x = x_resi + x
acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
});
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
{
store_tile(y_residual_window, cast_tile<YResidualDataType>(acc));
}
}
// compute mean square each-thread->cross-lane->cross-warp // compute mean square each-thread->cross-lane->cross-warp
auto square_sum = block_reduce2d( auto square_sum = block_reduce2d(acc,
x, reduce_square_sum_func.GetIdentityValue<ComputeDataType>(), reduce_square_sum_func); reduce_square_sum_func.GetIdentityValue<ComputeDataType>(),
reduce_square_sum_func);
block_reduce2d_sync(square_sum, reduce_sum_func); block_reduce2d_sync(square_sum, reduce_sum_func);
block_reduce2d_cross_warp_sync(square_sum, smem, reduce_sum_func); block_reduce2d_cross_warp_sync(square_sum, smem, reduce_sum_func);
...@@ -83,19 +123,30 @@ struct Rmsnorm2dFwdPipelineOnePass ...@@ -83,19 +123,30 @@ struct Rmsnorm2dFwdPipelineOnePass
store_tile(inv_rms_window, cast_tile<InvRmsDataType>(inv_rms)); store_tile(inv_rms_window, cast_tile<InvRmsDataType>(inv_rms));
// rmsnorm computation // rmsnorm computation
auto y = make_static_distributed_tensor<YDataType>(x.get_tile_distribution()); auto rmsn = make_static_distributed_tensor<ComputeDataType>(x.get_tile_distribution());
sweep_tile(y, [&, inv_rms_ = inv_rms](auto idx) { sweep_tile(rmsn, [&, inv_rms_ = inv_rms](auto idx) {
constexpr auto i_idx = make_tuple(idx[number<0>{}]); constexpr auto i_idx = make_tuple(idx[number<0>{}]);
constexpr auto j_idx = make_tuple(idx[number<1>{}]); constexpr auto j_idx = make_tuple(idx[number<1>{}]);
const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]); const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
const auto x_ = type_convert<ComputeDataType>(x[idx]); auto rmsn_ = acc[idx] * inv_rms_[i_idx] * gamma_;
auto y_ = x_ * inv_rms_[i_idx] * gamma_;
y(idx) = type_convert<YDataType>(y_); rmsn(idx) = rmsn_;
}); });
store_tile(y_window, y);
if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
{
Epilogue{}(y_window_, sm_scale_window_, y_scale_window_, rmsn, smem);
}
else if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::DYNAMIC_QUANT)
{
Epilogue{}(y_window_, y_scale_window_, rmsn, smem);
}
else
{
Epilogue{}(y_window_, rmsn);
}
} }
}; };
} // namespace ck_tile } // namespace ck_tile
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -12,10 +12,10 @@ template <typename XDataType_, ...@@ -12,10 +12,10 @@ template <typename XDataType_,
typename ComputeDataType_, typename ComputeDataType_,
typename YDataType_, typename YDataType_,
typename InvRmsDataType_, typename InvRmsDataType_,
typename SmoothScaleDataType_,
typename YScaleDataType_,
typename BlockShape_, typename BlockShape_,
bool kPadN_, typename Traits_>
bool kSaveInvRms_,
bool kTwoPass_>
struct Rmsnorm2dFwdPipelineProblem struct Rmsnorm2dFwdPipelineProblem
{ {
using XDataType = remove_cvref_t<XDataType_>; using XDataType = remove_cvref_t<XDataType_>;
...@@ -23,14 +23,14 @@ struct Rmsnorm2dFwdPipelineProblem ...@@ -23,14 +23,14 @@ struct Rmsnorm2dFwdPipelineProblem
using ComputeDataType = remove_cvref_t<ComputeDataType_>; using ComputeDataType = remove_cvref_t<ComputeDataType_>;
using YDataType = remove_cvref_t<YDataType_>; using YDataType = remove_cvref_t<YDataType_>;
using InvRmsDataType = remove_cvref_t<InvRmsDataType_>; using InvRmsDataType = remove_cvref_t<InvRmsDataType_>;
using SmoothScaleDataType = remove_cvref_t<SmoothScaleDataType_>;
using YScaleDataType = remove_cvref_t<YScaleDataType_>;
using BlockShape = remove_cvref_t<BlockShape_>; using BlockShape = remove_cvref_t<BlockShape_>;
static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1; static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1; static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;
static constexpr bool kPadN = kPadN_; using Traits = remove_cvref_t<Traits_>;
static constexpr bool kSaveInvRms = kSaveInvRms_;
static constexpr bool kTwoPass = kTwoPass_;
}; };
} // namespace ck_tile } // namespace ck_tile
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -22,12 +22,17 @@ struct Rmsnorm2dFwdPipelineTwoPass ...@@ -22,12 +22,17 @@ struct Rmsnorm2dFwdPipelineTwoPass
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>; using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
using InvRmsDataType = ck_tile::remove_cvref_t<typename Problem::InvRmsDataType>; using InvRmsDataType = ck_tile::remove_cvref_t<typename Problem::InvRmsDataType>;
using XResidualDataType = XDataType;
using YResidualDataType = XDataType;
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>; static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>;
static constexpr bool kSaveInvRms = Problem::kSaveInvRms; static constexpr bool kSaveInvRms = Problem::Traits::kSaveInvRms;
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync; static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
static constexpr bool kPadM = false; // TODO - BlockRmsnorm2dFwdProblem::kPadM static constexpr bool kPadM = false; // TODO - BlockRmsnorm2dFwdProblem::kPadM
static constexpr bool kPadN = Problem::kPadN; static constexpr bool kPadN = Problem::Traits::kPadN;
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
static constexpr const char* name = []() { static constexpr const char* name = []() {
if constexpr(kNeedCrossWarpSync) if constexpr(kNeedCrossWarpSync)
...@@ -41,19 +46,36 @@ struct Rmsnorm2dFwdPipelineTwoPass ...@@ -41,19 +46,36 @@ struct Rmsnorm2dFwdPipelineTwoPass
return Policy::template GetSmemSize<Problem>(); return Policy::template GetSmemSize<Problem>();
} }
template <typename XWindow, typename GammaWindow, typename YWindow, typename InvRmsWindow> template <typename XWindow,
typename XResidualWindow,
typename GammaWindow,
typename YWindow,
typename YResidualWindow,
typename InvRmsWindow,
typename SmoothScaleWindow,
typename YScaleWindow,
typename Epilogue>
CK_TILE_DEVICE auto operator()(const XWindow& x_window_, CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
const XResidualWindow& x_residual_window_,
const GammaWindow& gamma_window_, const GammaWindow& gamma_window_,
YWindow& y_window, YWindow& y_window,
const YResidualWindow& y_residual_window_,
InvRmsWindow& inv_rms_window, InvRmsWindow& inv_rms_window,
const SmoothScaleWindow& /*sm_scale_window_*/,
YScaleWindow& /*y_scale_window*/,
ComputeDataType epsilon, ComputeDataType epsilon,
ck_tile::index_t row_size, ck_tile::index_t row_size,
void* smem) const void* smem,
Epilogue) const
{ {
auto x_window = auto x_window =
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>()); make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
auto gamma_window = make_tile_window( auto gamma_window = make_tile_window(
gamma_window_, Policy::template MakeGammaBlockTileDistribution<Problem>()); gamma_window_, Policy::template MakeGammaBlockTileDistribution<Problem>());
auto x_residual_window = make_tile_window(
x_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
auto y_residual_window = make_tile_window(
y_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
// Problem::BlockShape // Problem::BlockShape
static constexpr index_t Block_N = Problem::BlockShape::Block_N; static constexpr index_t Block_N = Problem::BlockShape::Block_N;
...@@ -67,15 +89,34 @@ struct Rmsnorm2dFwdPipelineTwoPass ...@@ -67,15 +89,34 @@ struct Rmsnorm2dFwdPipelineTwoPass
auto block_reduce2d_cross_warp_sync = auto block_reduce2d_cross_warp_sync =
Policy::template GetBlockReduce2dCrossWarpSync<Problem>(); Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
using XTensorType = decltype(load_tile(x_window)); using ComputeTensorType = decltype(cast_tile<ComputeDataType>(load_tile(x_window)));
auto square_sum = block_reduce2d.template MakeYBlockTile<XTensorType>(); auto square_sum = block_reduce2d.template MakeYBlockTile<ComputeTensorType>();
set_tile(square_sum, reduce_square_sum_func.GetIdentityValue<ComputeDataType>()); set_tile(square_sum, reduce_square_sum_func.GetIdentityValue<ComputeDataType>());
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{ {
const auto x = load_tile(x_window); auto x = load_tile(x_window);
block_reduce2d(x, square_sum, reduce_square_sum_func); auto x_resi = load_tile(x_residual_window);
move_tile_window(x_window, {0, Block_N}); move_tile_window(x_window, {0, Block_N});
move_tile_window(x_residual_window, {0, Block_N});
auto acc = cast_tile<ComputeDataType>(x);
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD ||
kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
{
sweep_tile(x_resi, [&](auto idx) {
// compute x = x_resi + x
acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
});
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
{
store_tile(y_residual_window, cast_tile<YResidualDataType>(acc));
move_tile_window(y_residual_window, {0, Block_N});
}
}
block_reduce2d(acc, square_sum, reduce_square_sum_func);
} }
block_reduce2d_sync(square_sum, reduce_sum_func); block_reduce2d_sync(square_sum, reduce_sum_func);
...@@ -96,33 +137,47 @@ struct Rmsnorm2dFwdPipelineTwoPass ...@@ -96,33 +137,47 @@ struct Rmsnorm2dFwdPipelineTwoPass
row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N; row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N;
move_tile_window(x_window, {0, -Block_N}); move_tile_window(x_window, {0, -Block_N});
move_tile_window(x_residual_window, {0, -Block_N});
move_tile_window(gamma_window, {stride_to_right_most_window}); move_tile_window(gamma_window, {stride_to_right_most_window});
move_tile_window(y_window, {0, stride_to_right_most_window}); move_tile_window(y_window, {0, stride_to_right_most_window});
// rmsnorm computation // rmsnorm computation
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{ {
const auto x = load_tile(x_window); auto x = load_tile(x_window);
// load gamma/beta (TODO: support no gamma/beta?) auto x_resi = load_tile(x_residual_window);
const auto gamma = load_tile(gamma_window); auto acc = cast_tile<ComputeDataType>(x);
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE ||
kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD)
{
sweep_tile(x_resi, [&](auto idx) {
// compute x = x_resi + x
acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
});
}
auto y = make_static_distributed_tensor<YDataType>(x.get_tile_distribution()); // load gamma (TODO: support no gamma?)
const auto gamma = load_tile(gamma_window);
sweep_tile(y, [&, inv_rms_ = inv_rms](auto idx) { // rmsnorm computation
auto rmsn = make_static_distributed_tensor<ComputeDataType>(x.get_tile_distribution());
sweep_tile(rmsn, [&, inv_rms_ = inv_rms](auto idx) {
constexpr auto i_idx = make_tuple(idx[number<0>{}]); constexpr auto i_idx = make_tuple(idx[number<0>{}]);
constexpr auto j_idx = make_tuple(idx[number<1>{}]); constexpr auto j_idx = make_tuple(idx[number<1>{}]);
const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]); const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
const auto x_ = type_convert<ComputeDataType>(x[idx]); auto rmsn_ = acc(idx) * inv_rms_[i_idx] * gamma_;
auto y_ = x_ * inv_rms_[i_idx] * gamma_;
y(idx) = type_convert<YDataType>(y_); rmsn(idx) = rmsn_;
}); });
store_tile(y_window, y); static_assert(kFusedQuant == Rmsnorm2dFusedQuantEnum::NO_SWEEP);
Epilogue{}(y_window, rmsn);
move_tile_window(x_window, {0, -Block_N}); move_tile_window(x_window, {0, -Block_N});
move_tile_window(x_residual_window, {0, -Block_N});
move_tile_window(gamma_window, {-Block_N}); move_tile_window(gamma_window, {-Block_N});
move_tile_window(y_window, {0, -Block_N}); move_tile_window(y_window, {0, -Block_N});
} }
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
enum class Rmsnorm2dFusedAddEnum
{
NO_ADD = 0,
// fused add before RMSNorm and store result to global
PRE_ADD_STORE = 1,
// fused add before RMSNorm, but not store result
PRE_ADD = 2,
};
// clang-format off
template<Rmsnorm2dFusedAddEnum> struct Rmsnorm2dFusedAddEnumName;
template<> struct Rmsnorm2dFusedAddEnumName<Rmsnorm2dFusedAddEnum::NO_ADD> { static constexpr const char * name = "no"; };
template<> struct Rmsnorm2dFusedAddEnumName<Rmsnorm2dFusedAddEnum::PRE_ADD_STORE> { static constexpr const char * name = "pras"; };
template<> struct Rmsnorm2dFusedAddEnumName<Rmsnorm2dFusedAddEnum::PRE_ADD> { static constexpr const char * name = "pra"; };
// clang-format on
enum class Rmsnorm2dFusedQuantEnum
{
NO_SWEEP = 0,
SMOOTH_DYNAMIC_QUANT = 1, // smooth oulier + rowwise quant, need input x-scale and store y_scale
DYNAMIC_QUANT = 2, // rowwise quant, store out a y-scale
};
// clang-format off
template<Rmsnorm2dFusedQuantEnum> struct Rmsnorm2dFusedQuantEnumName;
template<> struct Rmsnorm2dFusedQuantEnumName<Rmsnorm2dFusedQuantEnum::NO_SWEEP> { static constexpr const char * name = "no"; };
template<> struct Rmsnorm2dFusedQuantEnumName<Rmsnorm2dFusedQuantEnum::DYNAMIC_QUANT> { static constexpr const char * name = "dqt"; };
template<> struct Rmsnorm2dFusedQuantEnumName<Rmsnorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT> { static constexpr const char * name = "smdqt"; };
// clang-format on
template <bool kPadN_,
bool kSaveInvRms_,
bool kTwoPass_,
Rmsnorm2dFusedAddEnum kFusedAdd_,
Rmsnorm2dFusedQuantEnum kFusedQuant_>
struct Rmsnorm2dFwdTraits
{
static constexpr bool kPadN = kPadN_;
static constexpr bool kSaveInvRms = kSaveInvRms_;
static constexpr bool kTwoPass = kTwoPass_;
static constexpr Rmsnorm2dFusedAddEnum kFusedAdd = kFusedAdd_;
static constexpr Rmsnorm2dFusedQuantEnum kFusedQuant = kFusedQuant_;
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -12,7 +12,7 @@ namespace ck_tile { ...@@ -12,7 +12,7 @@ namespace ck_tile {
struct MoeSmoothquantHostArgs struct MoeSmoothquantHostArgs
{ {
const void* p_x; // [tokens ,hidden_size], input, fp16/bf16 const void* p_x; // [tokens ,hidden_size], input, fp16/bf16
const void* p_xscale; // [experts, hidden_size], input, columnwise scale, fp32 const void* p_smscale; // [experts, hidden_size], input, columnwise scale, fp32
const void* p_topk_ids; // [tokens, topk] const void* p_topk_ids; // [tokens, topk]
void* p_yscale; // [topk * tokens, 1], output, rowwise quant scale void* p_yscale; // [topk * tokens, 1], output, rowwise quant scale
...@@ -34,7 +34,7 @@ struct MoeSmoothquant ...@@ -34,7 +34,7 @@ struct MoeSmoothquant
using Problem = typename Pipeline::Problem; using Problem = typename Pipeline::Problem;
using XDataType = remove_cvref_t<typename Problem::XDataType>; using XDataType = remove_cvref_t<typename Problem::XDataType>;
using XScaleDataType = remove_cvref_t<typename Problem::XScaleDataType>; using SmoothScaleDataType = remove_cvref_t<typename Problem::SmoothScaleDataType>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>; using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>; using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>;
using QYDataType = remove_cvref_t<typename Problem::QYDataType>; using QYDataType = remove_cvref_t<typename Problem::QYDataType>;
...@@ -57,7 +57,7 @@ struct MoeSmoothquant ...@@ -57,7 +57,7 @@ struct MoeSmoothquant
struct Kargs struct Kargs
{ {
const void* p_x; // [tokens ,hidden_size], input, fp16/bf16 const void* p_x; // [tokens ,hidden_size], input, fp16/bf16
const void* p_xscale; // [experts, hidden_size], input, columnwise scale, fp32 const void* p_smscale; // [experts, hidden_size], input, columnwise scale, fp32
const void* p_topk_ids; // [tokens, topk] const void* p_topk_ids; // [tokens, topk]
void* p_yscale; // [topk, tokens, 1], output, rowwise quant scale void* p_yscale; // [topk, tokens, 1], output, rowwise quant scale
...@@ -75,7 +75,7 @@ struct MoeSmoothquant ...@@ -75,7 +75,7 @@ struct MoeSmoothquant
CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs) CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
{ {
return Kargs{hargs.p_x, return Kargs{hargs.p_x,
hargs.p_xscale, hargs.p_smscale,
hargs.p_topk_ids, hargs.p_topk_ids,
hargs.p_yscale, hargs.p_yscale,
hargs.p_qy, hargs.p_qy,
...@@ -153,9 +153,10 @@ struct MoeSmoothquant ...@@ -153,9 +153,10 @@ struct MoeSmoothquant
}(); }();
// [experts, hidden_size], // [experts, hidden_size],
const auto xscale_window = [&]() { const auto smscale_window = [&]() {
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>( const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const XScaleDataType*>(kargs.p_xscale) + i_expert * kargs.hidden_size, static_cast<const SmoothScaleDataType*>(kargs.p_smscale) +
i_expert * kargs.hidden_size,
make_tuple(kargs.hidden_size), make_tuple(kargs.hidden_size),
make_tuple(1), make_tuple(1),
number<Vector_N>{}, number<Vector_N>{},
...@@ -198,7 +199,7 @@ struct MoeSmoothquant ...@@ -198,7 +199,7 @@ struct MoeSmoothquant
__shared__ char smem[GetSmemSize()]; __shared__ char smem[GetSmemSize()];
Pipeline{}(x_window, xscale_window, yscale_window, qy_window, kargs.hidden_size, smem); Pipeline{}(x_window, smscale_window, yscale_window, qy_window, kargs.hidden_size, smem);
} }
}; };
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -12,10 +12,10 @@ namespace ck_tile { ...@@ -12,10 +12,10 @@ namespace ck_tile {
struct SmoothquantHostArgs struct SmoothquantHostArgs
{ {
const void* p_x; // [m ,n], input, fp16/bf16 const void* p_x; // [m ,n], input, fp16/bf16
const void* p_xscale; // [1, n], input, columnwise scale, fp32 const void* p_smscale; // [1, n], input, columnwise scale, fp32
void* p_yscale; // [m, 1], output, rowwise quant scale (amax / 127) of (p_x * p_xscale) void* p_yscale; // [m, 1], output, rowwise quant scale (amax / 127) of (p_x * p_smscale)
void* p_qy; // [m, n], output, p_x * p_xscale / p_yscale void* p_qy; // [m, n], output, p_x * p_smscale / p_yscale
index_t m; index_t m;
index_t n; index_t n;
...@@ -31,7 +31,7 @@ struct Smoothquant ...@@ -31,7 +31,7 @@ struct Smoothquant
using Problem = typename Pipeline::Problem; using Problem = typename Pipeline::Problem;
using XDataType = remove_cvref_t<typename Problem::XDataType>; using XDataType = remove_cvref_t<typename Problem::XDataType>;
using XScaleDataType = remove_cvref_t<typename Problem::XScaleDataType>; using SmoothScaleDataType = remove_cvref_t<typename Problem::SmoothScaleDataType>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>; using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>; using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>;
using QYDataType = remove_cvref_t<typename Problem::QYDataType>; using QYDataType = remove_cvref_t<typename Problem::QYDataType>;
...@@ -52,7 +52,7 @@ struct Smoothquant ...@@ -52,7 +52,7 @@ struct Smoothquant
struct Kargs struct Kargs
{ {
const void* p_x; const void* p_x;
const void* p_xscale; const void* p_smscale;
void* p_yscale; void* p_yscale;
void* p_qy; void* p_qy;
...@@ -67,7 +67,7 @@ struct Smoothquant ...@@ -67,7 +67,7 @@ struct Smoothquant
CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs) CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
{ {
return Kargs{hargs.p_x, return Kargs{hargs.p_x,
hargs.p_xscale, hargs.p_smscale,
hargs.p_yscale, hargs.p_yscale,
hargs.p_qy, hargs.p_qy,
hargs.m, hargs.m,
...@@ -134,9 +134,9 @@ struct Smoothquant ...@@ -134,9 +134,9 @@ struct Smoothquant
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0}); tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
}(); }();
const auto xscale_window = [&]() { const auto smscale_window = [&]() {
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>( const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const XScaleDataType*>(kargs.p_xscale), static_cast<const SmoothScaleDataType*>(kargs.p_smscale),
make_tuple(kargs.n), make_tuple(kargs.n),
make_tuple(1), make_tuple(1),
number<Vector_N>{}, number<Vector_N>{},
...@@ -177,7 +177,7 @@ struct Smoothquant ...@@ -177,7 +177,7 @@ struct Smoothquant
__shared__ char smem[GetSmemSize()]; __shared__ char smem[GetSmemSize()];
Pipeline{}(x_window, xscale_window, yscale_window, qy_window, kargs.n, smem); Pipeline{}(x_window, smscale_window, yscale_window, qy_window, kargs.n, smem);
} }
}; };
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -28,7 +28,7 @@ struct SmoothquantPipelineDefaultPolicy ...@@ -28,7 +28,7 @@ struct SmoothquantPipelineDefaultPolicy
} }
template <typename Problem> template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeXScaleBlockTileDistribution() CK_TILE_DEVICE static constexpr auto MakeSmoothScaleBlockTileDistribution()
{ {
using S = typename Problem::BlockShape; using S = typename Problem::BlockShape;
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -17,7 +17,7 @@ struct SmoothquantPipelineOnePass ...@@ -17,7 +17,7 @@ struct SmoothquantPipelineOnePass
using Policy = ck_tile::remove_cvref_t<Policy_>; using Policy = ck_tile::remove_cvref_t<Policy_>;
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>; using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
using XScaleDataType = ck_tile::remove_cvref_t<typename Problem::XScaleDataType>; using SmoothScaleDataType = ck_tile::remove_cvref_t<typename Problem::SmoothScaleDataType>;
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>; using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
using QYDataType = ck_tile::remove_cvref_t<typename Problem::QYDataType>; using QYDataType = ck_tile::remove_cvref_t<typename Problem::QYDataType>;
using YScaleDataType = ck_tile::remove_cvref_t<typename Problem::YScaleDataType>; using YScaleDataType = ck_tile::remove_cvref_t<typename Problem::YScaleDataType>;
...@@ -39,9 +39,12 @@ struct SmoothquantPipelineOnePass ...@@ -39,9 +39,12 @@ struct SmoothquantPipelineOnePass
return Policy::template GetSmemSize<Problem>(); return Policy::template GetSmemSize<Problem>();
} }
template <typename XWindow, typename XScaleWindow, typename QYWindow, typename YScaleWindow> template <typename XWindow,
typename SmoothScaleWindow,
typename QYWindow,
typename YScaleWindow>
CK_TILE_DEVICE auto operator()(const XWindow& x_window_, CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
const XScaleWindow& xscale_window_, const SmoothScaleWindow& smscale_window_,
YScaleWindow& yscale_window, YScaleWindow& yscale_window,
QYWindow& qy_window, QYWindow& qy_window,
ck_tile::index_t, ck_tile::index_t,
...@@ -49,8 +52,8 @@ struct SmoothquantPipelineOnePass ...@@ -49,8 +52,8 @@ struct SmoothquantPipelineOnePass
{ {
auto x_window = auto x_window =
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>()); make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
auto xscale_window = make_tile_window( auto smscale_window = make_tile_window(
xscale_window_, Policy::template MakeXScaleBlockTileDistribution<Problem>()); smscale_window_, Policy::template MakeSmoothScaleBlockTileDistribution<Problem>());
auto reduce_absmax_func = ReduceOp::AbsMax{}; auto reduce_absmax_func = ReduceOp::AbsMax{};
auto reduce_absmax3_func = [](auto acc_, auto v_0_, auto v_1_) { auto reduce_absmax3_func = [](auto acc_, auto v_0_, auto v_1_) {
...@@ -68,13 +71,13 @@ struct SmoothquantPipelineOnePass ...@@ -68,13 +71,13 @@ struct SmoothquantPipelineOnePass
Policy::template GetBlockReduce2dCrossWarpSync<Problem>(); Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
const auto x = load_tile(x_window); const auto x = load_tile(x_window);
const auto xscale = load_tile(xscale_window); const auto smscale = load_tile(smscale_window);
auto y = tile_elementwise_in( auto y = tile_elementwise_in(
[&](const auto& a, const auto& b) { [&](const auto& a, const auto& b) {
return type_convert<ComputeDataType>(a) * type_convert<ComputeDataType>(b); return type_convert<ComputeDataType>(a) * type_convert<ComputeDataType>(b);
}, },
x, x,
xscale); smscale);
// compute absmax, cross-lane->cross-warp // compute absmax, cross-lane->cross-warp
auto absmax = [&]() { auto absmax = [&]() {
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -7,9 +7,9 @@ ...@@ -7,9 +7,9 @@
namespace ck_tile { namespace ck_tile {
// Y = X * XScale, QY = RowwiseDynamicQuant(Y) = SaturateCast(Y / YScale) // Y = X * SmoothScale, QY = RowwiseDynamicQuant(Y) = SaturateCast(Y / YScale)
template <typename XDataType_, template <typename XDataType_,
typename XScaleDataType_, typename SmoothScaleDataType_,
typename ComputeDataType_, typename ComputeDataType_,
typename YScaleDataType_, typename YScaleDataType_,
typename QYDataType_, typename QYDataType_,
...@@ -19,7 +19,7 @@ template <typename XDataType_, ...@@ -19,7 +19,7 @@ template <typename XDataType_,
struct SmoothquantPipelineProblem struct SmoothquantPipelineProblem
{ {
using XDataType = remove_cvref_t<XDataType_>; using XDataType = remove_cvref_t<XDataType_>;
using XScaleDataType = remove_cvref_t<XScaleDataType_>; using SmoothScaleDataType = remove_cvref_t<SmoothScaleDataType_>;
using ComputeDataType = remove_cvref_t<ComputeDataType_>; using ComputeDataType = remove_cvref_t<ComputeDataType_>;
using YScaleDataType = remove_cvref_t<YScaleDataType_>; using YScaleDataType = remove_cvref_t<YScaleDataType_>;
using QYDataType = remove_cvref_t<QYDataType_>; using QYDataType = remove_cvref_t<QYDataType_>;
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -17,7 +17,7 @@ struct SmoothquantPipelineTwoPass ...@@ -17,7 +17,7 @@ struct SmoothquantPipelineTwoPass
using Policy = ck_tile::remove_cvref_t<Policy_>; using Policy = ck_tile::remove_cvref_t<Policy_>;
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>; using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
using XScaleDataType = ck_tile::remove_cvref_t<typename Problem::XScaleDataType>; using SmoothScaleDataType = ck_tile::remove_cvref_t<typename Problem::SmoothScaleDataType>;
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>; using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
using QYDataType = ck_tile::remove_cvref_t<typename Problem::QYDataType>; using QYDataType = ck_tile::remove_cvref_t<typename Problem::QYDataType>;
using YScaleDataType = ck_tile::remove_cvref_t<typename Problem::YScaleDataType>; using YScaleDataType = ck_tile::remove_cvref_t<typename Problem::YScaleDataType>;
...@@ -39,9 +39,12 @@ struct SmoothquantPipelineTwoPass ...@@ -39,9 +39,12 @@ struct SmoothquantPipelineTwoPass
return Policy::template GetSmemSize<Problem>(); return Policy::template GetSmemSize<Problem>();
} }
template <typename XWindow, typename XScaleWindow, typename QYWindow, typename YScaleWindow> template <typename XWindow,
typename SmoothScaleWindow,
typename QYWindow,
typename YScaleWindow>
CK_TILE_DEVICE auto operator()(const XWindow& x_window_, CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
const XScaleWindow& xscale_window_, const SmoothScaleWindow& smscale_window_,
YScaleWindow& yscale_window, YScaleWindow& yscale_window,
QYWindow& qy_window, QYWindow& qy_window,
ck_tile::index_t row_size, ck_tile::index_t row_size,
...@@ -49,8 +52,8 @@ struct SmoothquantPipelineTwoPass ...@@ -49,8 +52,8 @@ struct SmoothquantPipelineTwoPass
{ {
auto x_window = auto x_window =
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>()); make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
auto xscale_window = make_tile_window( auto smscale_window = make_tile_window(
xscale_window_, Policy::template MakeXScaleBlockTileDistribution<Problem>()); smscale_window_, Policy::template MakeSmoothScaleBlockTileDistribution<Problem>());
static constexpr index_t Block_N = Problem::BlockShape::Block_N; static constexpr index_t Block_N = Problem::BlockShape::Block_N;
index_t num_n_tile_iteration = index_t num_n_tile_iteration =
...@@ -77,13 +80,13 @@ struct SmoothquantPipelineTwoPass ...@@ -77,13 +80,13 @@ struct SmoothquantPipelineTwoPass
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{ {
const auto x = load_tile(x_window); const auto x = load_tile(x_window);
const auto xscale = load_tile(xscale_window); const auto smscale = load_tile(smscale_window);
const auto y = tile_elementwise_in( const auto y = tile_elementwise_in(
[&](const auto& a, const auto& b) { [&](const auto& a, const auto& b) {
return type_convert<ComputeDataType>(a) * type_convert<ComputeDataType>(b); return type_convert<ComputeDataType>(a) * type_convert<ComputeDataType>(b);
}, },
x, x,
xscale); smscale);
constexpr auto x_size_per_row = constexpr auto x_size_per_row =
x.get_tile_distribution().get_ys_to_d_descriptor().get_lengths().at(number<1>{}); x.get_tile_distribution().get_ys_to_d_descriptor().get_lengths().at(number<1>{});
...@@ -94,7 +97,7 @@ struct SmoothquantPipelineTwoPass ...@@ -94,7 +97,7 @@ struct SmoothquantPipelineTwoPass
block_reduce2d(y, absmax, reduce_absmax_func); block_reduce2d(y, absmax, reduce_absmax_func);
move_tile_window(x_window, {0, Block_N}); move_tile_window(x_window, {0, Block_N});
move_tile_window(xscale_window, {Block_N}); move_tile_window(smscale_window, {Block_N});
} }
// compute absmax, cross-lane->cross-warp // compute absmax, cross-lane->cross-warp
...@@ -114,20 +117,20 @@ struct SmoothquantPipelineTwoPass ...@@ -114,20 +117,20 @@ struct SmoothquantPipelineTwoPass
row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N; row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N;
move_tile_window(x_window, {0, -Block_N}); move_tile_window(x_window, {0, -Block_N});
move_tile_window(xscale_window, {-Block_N}); move_tile_window(smscale_window, {-Block_N});
move_tile_window(qy_window, {0, stride_to_right_most_window}); move_tile_window(qy_window, {0, stride_to_right_most_window});
// recompute y and quantize y to qy // recompute y and quantize y to qy
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{ {
const auto x = load_tile(x_window); const auto x = load_tile(x_window);
const auto xscale = load_tile(xscale_window); const auto smscale = load_tile(smscale_window);
const auto y = tile_elementwise_in( const auto y = tile_elementwise_in(
[&](const auto& a, const auto& b) { [&](const auto& a, const auto& b) {
return type_convert<ComputeDataType>(a) * type_convert<ComputeDataType>(b); return type_convert<ComputeDataType>(a) * type_convert<ComputeDataType>(b);
}, },
x, x,
xscale); smscale);
auto qy = make_static_distributed_tensor<QYDataType>(y.get_tile_distribution()); auto qy = make_static_distributed_tensor<QYDataType>(y.get_tile_distribution());
sweep_tile(qy, [&](auto idx) { sweep_tile(qy, [&](auto idx) {
...@@ -138,7 +141,7 @@ struct SmoothquantPipelineTwoPass ...@@ -138,7 +141,7 @@ struct SmoothquantPipelineTwoPass
store_tile(qy_window, qy); store_tile(qy_window, qy);
move_tile_window(x_window, {0, -Block_N}); move_tile_window(x_window, {0, -Block_N});
move_tile_window(xscale_window, {0, -Block_N}); move_tile_window(smscale_window, {0, -Block_N});
move_tile_window(qy_window, {0, -Block_N}); move_tile_window(qy_window, {0, -Block_N});
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment