Unverified Commit 04dd3148 authored by ruanjm, committed by GitHub

[CK_TILE] Add Various Fusion Functions to RMSNorm (#1802)



* Add shortcut to RMSNorm

* Modify test for adding shortcut for RMSNorm

* Add fused parameter into tests

* 1. Add YDataType. 2. Move rmsnorm2d_fwd_traits_ from rmsnorm2d_fwd.hpp to rmsnorm2d_fwd_api.cpp and rmsnorm2d_fwd_instance_common.hpp

* Support various strides and precisions.

* Add support for epilogue

* Add fuse and epilogue support to rmsnorm ref

* Modify rmsnorm example

* Refactor tests/examples

* Bug fix for newly added tests/examples

* Bug fix for new tests 2

* Modify smoke test scripts

remove dbg code

* Support non-smooth dynamic quant

* Update Rmsnorm2dFwd::GetName()

* rename xscale and prec_sx to smoothscale and prec_sm

Bug fix after rename

Remove files

* change example_rmsnorm2d_fwd.cpp

* update performance calculator

* Fix issue in two-pass when fuse add is enabled

* Remove comment of beta

---------
Co-authored-by: rocking <ChunYu.Lai@amd.com>
parent c0b90f13
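In short, the fused forward path added here computes, per row: optionally accumulate a residual into x (fadd), RMS-normalize against gamma, then optionally apply a per-column smooth scale and emit a per-row dynamic int8 scale (fquant). The following host-side scalar sketch mirrors the reference path and dquant_functor in the example diff below; the function name and signature are illustrative only, not part of the ck_tile API.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Illustrative only: scalar per-row model of the fused path (pre-add residual ->
// RMSNorm -> optional smooth scale -> per-row dynamic int8 quant). The name
// rmsnorm_fused_row and this signature are hypothetical, not part of ck_tile.
std::vector<std::int8_t> rmsnorm_fused_row(std::vector<float> x,
                                           const std::vector<float>& residual, // used when fadd != 0
                                           const std::vector<float>& gamma,
                                           const std::vector<float>& smooth,   // used when fquant == 1
                                           bool fused_add,
                                           int fused_quant,
                                           float epsilon,
                                           float& y_scale)
{
    const std::size_t n = x.size();
    if(fused_add) // fadd=1 additionally stores x back as the next layer's residual
        for(std::size_t i = 0; i < n; ++i)
            x[i] += residual[i];

    float mean_sq = 0.f; // RMSNorm: y = x * gamma / sqrt(mean(x^2) + eps)
    for(float v : x)
        mean_sq += v * v;
    mean_sq /= static_cast<float>(n);
    const float inv_rms = 1.f / std::sqrt(mean_sq + epsilon);

    std::vector<float> acc(n);
    for(std::size_t i = 0; i < n; ++i)
        acc[i] = x[i] * inv_rms * gamma[i];

    if(fused_quant == 1) // smooth quant multiplies by a per-column [1*N] scale first
        for(std::size_t i = 0; i < n; ++i)
            acc[i] *= smooth[i];

    float absmax = 0.f; // symmetric dynamic quant: per-row absmax maps to int8 range
    for(float v : acc)
        absmax = std::max(absmax, std::fabs(v));
    y_scale = absmax > 0.f ? absmax / 127.f : 1.f; // guard all-zero rows

    std::vector<std::int8_t> y(n);
    for(std::size_t i = 0; i < n; ++i)
        y[i] = static_cast<std::int8_t>(std::lround(acc[i] / y_scale));
    return y;
}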
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
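// key: rm=Repeat_M, rn=Repeat_N, tm=ThreadPerBlock_M, tn=ThreadPerBlock_N, vn=Vector_N, pd=kPadN, rms=kSaveInvRms, 2p=kTwoPass (matches the trait_ parameter order in rmsnorm2d_fwd_instance_common.hpp)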
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 1, 256, 8, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 1, 256, 4, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 256, 2, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 8, 1, 256, 1, true, false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 4, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 2, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 4, 64, 1, true , false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 3, 1, 128, 8, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 3, 1, 256, 4, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 6, 1, 256, 2, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 3, 1, 1024, 1, true, false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 1, 256, 8, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 256, 4, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 1, 1024, 2, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 1024, 1, true, false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 1, 256, 8, true, false, true>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 256, 4, true, false, true>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 1, 1024, 2, true, false, true>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 1024, 1, true, false, true>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 8, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 4, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 4, 64, 2, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 8, 4, 64, 1, true , false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 1, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 2, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 1, true , false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 3, 4, 64, 4, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 6, 4, 64, 2, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 12, 4, 64, 1, true , false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <ck_tile/core.hpp>
#include "rmsnorm2d_fwd.hpp"
#include <iostream>
using S = ck_tile::stream_config;
using A = rmsnorm2d_fwd_args;
template <typename DataType_,
ck_tile::index_t Repeat_M_, // each thread repeat along M
ck_tile::index_t Repeat_N_, // each thread repeat along N
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
ck_tile::index_t Vector_N_, // vector size along N
bool kPadN_,
bool kSaveInvRms_,
bool kTwoPass_>
using trait_ = rmsnorm2d_fwd_traits_<DataType_,
Repeat_M_,
Repeat_N_,
ThreadPerBlock_M_,
ThreadPerBlock_N_,
Vector_N_,
kPadN_,
kSaveInvRms_,
kTwoPass_>;
template <typename Traits_>
float rmsnorm2d_fwd_(const S& s, A a)
{
using DataType = typename Traits_::DataType;
using PipelineProblem =
ck_tile::Rmsnorm2dFwdPipelineProblem<typename RmsnormTypeConfig<DataType>::XDataType,
typename RmsnormTypeConfig<DataType>::GammaDataType,
typename RmsnormTypeConfig<DataType>::ComputeDataType,
typename RmsnormTypeConfig<DataType>::YDataType,
typename RmsnormTypeConfig<DataType>::InvRmsDataType,
typename Traits_::Shape,
Traits_::kPadN,
Traits_::kSaveInvRms,
Traits_::kTwoPass>;
using OnePassPipeline = ck_tile::Rmsnorm2dFwdPipelineOnePass<PipelineProblem>;
using TwoPassPipeline = ck_tile::Rmsnorm2dFwdPipelineTwoPass<PipelineProblem>;
using Pipeline = std::conditional_t<Traits_::kTwoPass, TwoPassPipeline, OnePassPipeline>;
using Kernel = ck_tile::Rmsnorm2dFwd<Pipeline>;
const dim3 grids = Kernel::GridSize(a);
constexpr dim3 blocks = Kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 1;
auto kargs = Kernel::MakeKargs(a);
if(s.log_level_ > 0)
std::cout << ", " << Kernel::GetName() << std::flush;
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
@@ -19,17 +19,37 @@ auto get_elimit<ck_tile::bf16_t>()
return ck_tile::make_tuple(rtol, atol);
}
template <>
auto get_elimit<ck_tile::int8_t>()
{
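// int8 results are quantized, so tolerate one quantization step (atol = 1.0) on top of 1% rtol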
double rtol = 1e-02;
double atol = 1.0;
return ck_tile::make_tuple(rtol, atol);
}
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "3328", "m dimension")
.insert("n", "4096", "n dimension")
.insert("x_stride", "-1", "x row_stride, if -1 then equal to n")
.insert("xr_stride", "-1", "x residual row_stride, if -1 then equal to n")
.insert("y_stride", "-1", "y row_stride, if -1 then equal to n")
.insert("yr_stride", "-1", "y residual row_stride, if -1 then equal to n")
.insert("e", "1e-5", "epsilon")
.insert("save_rms", "0", "save rms(invrms) or not. set to 1 in training case")
.insert("v", "1", "cpu validation or not")
.insert("kname", "1", "print kernel name or not")
.insert("prec_i", "fp16", "input precision")
.insert("prec_o", "auto", "output precision, auto means same as input")
.insert("prec_sm",
"auto",
"smooth-scale (input) data type, auto means fp32. used when fquant=1")
.insert("prec_sy",
"auto",
"output quant scale type, auto means fp32. used when fquant=1 or 2")
.insert("fadd", "0", "fused-add, 0:no fused add, 1:preadd+store, 2:preadd only")
.insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant")
.insert("warmup", "5", "cold iter")
.insert("repeat", "20", "hot iter");
@@ -37,28 +57,68 @@ auto create_args(int argc, char* argv[])
return std::make_tuple(result, arg_parser);
}
template <typename InDataType,
typename OutDataType,
typename SmoothScaleDataType,
typename YScaleDataType,
bool SaveRms>
bool run(const ck_tile::ArgParser& arg_parser)
{
ck_tile::index_t m = arg_parser.get_int("m");
ck_tile::index_t n = arg_parser.get_int("n");
float epsilon = arg_parser.get_float("e");
int kname = arg_parser.get_int("kname");
int do_validation = arg_parser.get_int("v");
int fused_add = arg_parser.get_int("fadd");
int fused_quant = arg_parser.get_int("fquant");
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
ck_tile::index_t x_stride = arg_parser.get_int("x_stride");
if(x_stride < 0)
x_stride = n;
ck_tile::index_t xr_stride = arg_parser.get_int("xr_stride");
if(xr_stride < 0)
xr_stride = n;
ck_tile::index_t y_stride = arg_parser.get_int("y_stride");
if(y_stride < 0)
y_stride = n;
ck_tile::index_t yr_stride = arg_parser.get_int("yr_stride");
if(yr_stride < 0)
yr_stride = n;
assert(x_stride >= n);
std::string prec_i = arg_parser.get_str("prec_i");
std::string prec_o = arg_parser.get_str("prec_o");
std::string prec_sm = arg_parser.get_str("prec_sm");
std::string prec_sy = arg_parser.get_str("prec_sy");
if(prec_o == "auto")
{
prec_o = prec_i;
}
if(prec_sm == "auto")
{
prec_sm = "fp32";
}
if(prec_sy == "auto")
{
prec_sy = "fp32";
}
if((fused_quant == 1 || fused_quant == 2) && prec_o != "int8")
{
std::cout << "if fused_quant is 1 or 2, only the \"-prec_o=int8\" case is supported" << std::endl;
return false;
}
using TypeConfig =
RmsnormTypeConfig<InDataType, OutDataType, SmoothScaleDataType, YScaleDataType>;
using XDataType = typename TypeConfig::XDataType;
using YDataType = typename TypeConfig::YDataType;
using GammaDataType = typename TypeConfig::GammaDataType;
using XResidualDataType = XDataType;
using YResidualDataType = XDataType;
using InvRmsDataType =
std::conditional_t<SaveRms, typename TypeConfig::InvRmsDataType, ck_tile::null_type>;
@@ -66,43 +126,84 @@ bool run(const ck_tile::ArgParser& arg_parser)
using ComputeDataType = typename TypeConfig::ComputeDataType;
// host verify
ck_tile::HostTensor<XDataType> x_host({m, n}, {x_stride, 1});
ck_tile::HostTensor<GammaDataType> gamma_host({n});
ck_tile::HostTensor<SmoothScaleDataType> sm_scale_host({n});
ck_tile::HostTensor<SmoothScaleDataType> sm_scale_host_dev({n});
ck_tile::HostTensor<XResidualDataType> x_residual_host({m, n}, {xr_stride, 1});
ck_tile::HostTensor<YResidualDataType> y_residual_host({m, n}, {yr_stride, 1});
ck_tile::HostTensor<YDataType> y_host_ref({m, n}, {y_stride, 1});
ck_tile::HostTensor<YDataType> y_host_dev({m, n}, {y_stride, 1});
ck_tile::HostTensor<YScaleDataType> y_scale_host_ref({m});
ck_tile::HostTensor<YScaleDataType> y_scale_host_dev({m});
ck_tile::HostTensor<InvRmsDataType> invRms_host_ref({m});
ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
ck_tile::FillUniformDistribution<XResidualDataType>{-.5f, .5f}(x_residual_host);
ck_tile::FillUniformDistribution<SmoothScaleDataType>{-1.f, 1.f}(sm_scale_host);
ck_tile::FillUniformDistribution<GammaDataType>{-.5f, .5f}(gamma_host);
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes());
ck_tile::DeviceMem y_scale_buf(y_scale_host_dev.get_element_space_size_in_bytes());
ck_tile::DeviceMem sm_scale_buf(sm_scale_host_dev.get_element_space_size_in_bytes());
ck_tile::DeviceMem x_residual_buf(x_residual_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem y_residual_buf(y_residual_host.get_element_space_size_in_bytes());
x_buf.ToDevice(x_host.data());
gamma_buf.ToDevice(gamma_host.data());
x_residual_buf.ToDevice(x_residual_host.data());
sm_scale_buf.ToDevice(sm_scale_host.data());
auto prec_str = [&]() {
auto base_str = prec_i;
if(prec_i != prec_o)
{
base_str += "|" + prec_o;
}
if(fused_quant == 1)
{
base_str += std::string("(") + prec_sy + ")";
}
return base_str;
}();
std::cout << "[" << prec_str << "]"
<< " m:" << m << ", n:" << n << ", x_stride:" << x_stride
<< ", xr_stride:" << xr_stride << ", y_stride:" << y_stride
<< ", yr_stride:" << yr_stride << std::flush;
rmsnorm2d_fwd_traits traits{prec_i, prec_o, prec_sm, prec_sy, SaveRms, fused_add, fused_quant};
rmsnorm2d_fwd_args args{x_buf.GetDeviceBuffer(),
fused_add != 0 ? x_residual_buf.GetDeviceBuffer() : nullptr,
fused_quant == 1 ? sm_scale_buf.GetDeviceBuffer() : nullptr,
gamma_buf.GetDeviceBuffer(),
y_buf.GetDeviceBuffer(),
fused_add == 1 ? y_residual_buf.GetDeviceBuffer() : nullptr,
fused_quant != 0 ? y_scale_buf.GetDeviceBuffer() : nullptr,
nullptr, // p_invRms, unsupported yet
epsilon,
m,
n,
x_stride, // x row stride
xr_stride, // x residual row stride
y_stride, // y row stride
yr_stride}; // y residual row stride
float ave_time = rmsnorm2d_fwd(
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
std::size_t num_byte =
sizeof(XDataType) * m * n + sizeof(GammaDataType) * n + sizeof(YDataType) * m * n;
num_byte += SaveRms ? sizeof(InvRmsDataType) * m : 0; // invRms is one value per row
num_byte += fused_add ? sizeof(XResidualDataType) * m * n : 0;
num_byte += ((fused_quant == 1) || (fused_quant == 2)) ? sizeof(YScaleDataType) * m : 0;
num_byte += (fused_quant == 1) ? sizeof(SmoothScaleDataType) * n : 0;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush;
@@ -112,38 +213,131 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(do_validation)
{
// reference
if(fused_add != 0)
{
// fused pre_add/pre_add_store
// TODO: we accumulate directly into x_host for simplicity here...
std::transform(x_host.mData.cbegin(),
x_host.mData.cend(),
x_residual_host.mData.cbegin(),
x_host.mData.begin(),
[](auto x_, auto r_) {
auto o_ = ck_tile::type_convert<ComputeDataType>(x_) +
ck_tile::type_convert<ComputeDataType>(r_);
return ck_tile::type_convert<XDataType>(o_);
});
}
if(fused_quant != 0)
{
auto dquant_functor = [&](int m_, auto& o_, auto& acc_) {
int N_ = acc_.mDesc.get_lengths()[1];
if(fused_quant == 1)
{
for(int n_ = 0; n_ < N_; n_++)
{
// input smooth outlier
acc_(m_, n_) = acc_(m_, n_) *
ck_tile::type_convert<ComputeDataType>(sm_scale_host(n_));
}
}
ComputeDataType absmax = static_cast<ComputeDataType>(0);
for(int n_ = 0; n_ < N_; n_++)
{
const auto a = ck_tile::abs(acc_(m_, n_));
absmax = a > absmax ? a : absmax;
}
// printf("cpu:absmax:%f\n", absmax);
ComputeDataType y_scale = absmax / static_cast<ComputeDataType>(127.0);
y_scale_host_ref(m_) = ck_tile::type_convert<YScaleDataType>(y_scale);
for(int n_ = 0; n_ < N_; n_++)
{
o_(m_, n_) = ck_tile::type_convert<YDataType>(acc_(m_, n_) / y_scale);
}
};
ck_tile::reference_rmsnorm2d_fwd<XDataType,
GammaDataType,
ComputeDataType,
YDataType,
InvRmsDataType>(
x_host, gamma_host, y_host_ref, invRms_host_ref, epsilon, dquant_functor);
}
else
{
ck_tile::reference_rmsnorm2d_fwd<XDataType,
GammaDataType,
ComputeDataType,
YDataType,
InvRmsDataType>(
x_host, gamma_host, y_host_ref, invRms_host_ref, epsilon);
}
y_buf.FromDevice(y_host_dev.data());
ck_tile::HostTensor<YResidualDataType> y_residual_host_dev({m, n}, {yr_stride, 1});
if(fused_add == 1)
{
y_residual_buf.FromDevice(y_residual_host_dev.data());
}
auto [rtol, atol] = get_elimit<YDataType>();
if(x_stride == n)
{
pass = ck_tile::check_err(
y_host_dev, y_host_ref, std::string("\nOUT Error: Incorrect results!"), rtol, atol);
if(fused_add == 1)
{
pass &= ck_tile::check_err(y_residual_host_dev,
x_host,
std::string("\nADD Error: Incorrect results!"),
rtol,
atol);
}
}
else
{
for(int i_r = 0; i_r < m; i_r++)
{
std::vector<YDataType> y_host_dev_row(y_host_dev.begin() + i_r * y_stride,
y_host_dev.begin() + i_r * y_stride + n);
std::vector<YDataType> y_host_ref_row(y_host_ref.begin() + i_r * y_stride,
y_host_ref.begin() + i_r * y_stride + n);
pass &= ck_tile::check_err(y_host_dev_row,
y_host_ref_row,
std::string("\nOUT[") + std::to_string(i_r) +
std::string("] Error: Incorrect results!"),
rtol,
atol);
if(fused_add == 1)
{
std::vector<YResidualDataType> y_residual_host_dev_row(
y_residual_host_dev.begin() + i_r * yr_stride,
y_residual_host_dev.begin() + i_r * yr_stride + n);
std::vector<YResidualDataType> y_residual_host_ref_row(
x_host.begin() + i_r * yr_stride, x_host.begin() + i_r * yr_stride + n);
pass &= ck_tile::check_err(y_residual_host_dev_row,
y_residual_host_ref_row,
std::string("\nADD[") + std::to_string(i_r) +
std::string("] Error: Incorrect results!"),
rtol,
atol);
}
}
}
if(fused_quant == 1)
{
y_scale_buf.FromDevice(y_scale_host_dev.data());
pass &= ck_tile::check_err(y_scale_host_dev,
y_scale_host_ref,
std::string("\nSCALE Error: Incorrect results!"),
rtol,
atol);
}
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
} }
@@ -156,23 +350,55 @@ int main(int argc, char* argv[])
if(!result)
return -1;
std::string prec_i = arg_parser.get_str("prec_i");
std::string prec_o = arg_parser.get_str("prec_o");
std::string prec_sm = arg_parser.get_str("prec_sm");
std::string prec_sy = arg_parser.get_str("prec_sy");
if(prec_o == "auto")
{
prec_o = prec_i;
}
if(prec_sm == "auto")
{
prec_sm = "fp32";
}
if(prec_sy == "auto")
{
prec_sy = "fp32";
}
int save_rms = arg_parser.get_int("save_rms");
if(prec_i == "fp16" && prec_o == "fp16" && prec_sm == "fp32" && prec_sy == "fp32" && save_rms)
{
return run<ck_tile::half_t, ck_tile::half_t, float, float, true>(arg_parser) ? 0 : -2;
}
else if(prec_i == "fp16" && prec_o == "fp16" && prec_sm == "fp32" && prec_sy == "fp32" &&
!save_rms)
{
return run<ck_tile::half_t, ck_tile::half_t, float, float, false>(arg_parser) ? 0 : -2;
}
else if(prec_i == "bf16" && prec_o == "bf16" && prec_sm == "fp32" && prec_sy == "fp32" &&
save_rms)
{
return run<ck_tile::bf16_t, ck_tile::bf16_t, float, float, true>(arg_parser) ? 0 : -2;
}
else if(prec_i == "bf16" && prec_o == "bf16" && prec_sm == "fp32" && prec_sy == "fp32" &&
!save_rms)
{
return run<ck_tile::bf16_t, ck_tile::bf16_t, float, float, false>(arg_parser) ? 0 : -2;
}
// dynamic quant case, only in inference
else if(prec_i == "fp16" && prec_o == "int8" && prec_sm == "fp32" && prec_sy == "fp32" &&
!save_rms)
{
return run<ck_tile::half_t, ck_tile::int8_t, float, float, false>(arg_parser) ? 0 : -2;
}
else if(prec_i == "bf16" && prec_o == "int8" && prec_sm == "fp32" && prec_sy == "fp32" &&
!save_rms)
{
return run<ck_tile::bf16_t, ck_tile::int8_t, float, float, false>(arg_parser) ? 0 : -2;
}
return -3;
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -8,27 +8,34 @@
#include "ck_tile/ops/rmsnorm2d.hpp"
#include <string>
template <typename InType,
typename OutType,
typename SmoothScaleDataType_,
typename YScaleDataType_>
struct RmsnormTypeConfig;
template <typename OutType, typename SmoothScaleDataType_, typename YScaleDataType_>
struct RmsnormTypeConfig<ck_tile::half_t, OutType, SmoothScaleDataType_, YScaleDataType_>
{
using XDataType = ck_tile::half_t;
using YDataType = OutType;
using GammaDataType = ck_tile::half_t;
using InvRmsDataType = ck_tile::half_t;
using ComputeDataType = float;
using SmoothScaleDataType = SmoothScaleDataType_;
using YScaleDataType = YScaleDataType_;
};
template <typename OutType, typename SmoothScaleDataType_, typename YScaleDataType_>
struct RmsnormTypeConfig<ck_tile::bf16_t, OutType, SmoothScaleDataType_, YScaleDataType_>
{
using XDataType = ck_tile::bf16_t;
using YDataType = OutType;
using GammaDataType = ck_tile::bf16_t;
using InvRmsDataType = ck_tile::bf16_t;
using ComputeDataType = float;
using SmoothScaleDataType = SmoothScaleDataType_;
using YScaleDataType = YScaleDataType_;
};
// runtime args
@@ -36,82 +43,24 @@ struct rmsnorm2d_fwd_args : public ck_tile::Rmsnorm2dFwdHostArgs
{
};
// this is used to pattern-match the internal kernel implementation, not to instantiate kernels
template <typename DataType_,
ck_tile::index_t Repeat_M_, // each thread repeat along M
ck_tile::index_t Repeat_N_, // each thread repeat along N
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
ck_tile::index_t Vector_N_, // vector size along N
bool kPadN_,
bool kSaveInvRms_,
bool kTwoPass_>
struct rmsnorm2d_fwd_traits_
{
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize;
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0);
static constexpr ck_tile::index_t total_warps =
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize;
// num of warps along m
static constexpr ck_tile::index_t BlockWarps_M = []() {
if constexpr(is_warp_per_row)
{
static_assert(warpSize % ThreadPerBlock_N_ == 0);
return total_warps * (warpSize / ThreadPerBlock_N_);
}
else
{
// static_assert(warpSize % ThreadPerBlock_M_ == 0);
return total_warps / (ThreadPerBlock_N_ / warpSize);
}
}();
// num of warps along n
static constexpr ck_tile::index_t BlockWarps_N = []() {
if constexpr(is_warp_per_row)
{
static_assert(warpSize % ThreadPerBlock_N_ == 0);
return 1;
}
else
{
static_assert(ThreadPerBlock_N_ % warpSize == 0);
return ThreadPerBlock_N_ / warpSize;
}
}();
static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
static constexpr ck_tile::index_t Repeat_N = Repeat_N_;
static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_;
static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_;
static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M;
static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_;
using BlockTile = ck_tile::sequence<Block_M, Block_N>;
using BlockWarps = ck_tile::sequence<BlockWarps_M, BlockWarps_N>;
using WarpTile = ck_tile::sequence<Warp_M, Warp_N>;
using Vector = ck_tile::sequence<1, Vector_N_>;
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
static constexpr bool kPadN = kPadN_;
static constexpr bool kSaveInvRms = kSaveInvRms_;
static constexpr bool kTwoPass = kTwoPass_;
};
template <typename Traits_>
float rmsnorm2d_fwd_(const ck_tile::stream_config& s, rmsnorm2d_fwd_args a);
// This is the public API, will be generated by script
struct rmsnorm2d_fwd_traits
{
std::string prec_i; // input precision
std::string prec_o; // output precision
// if fused_quant == 1, prec_sm and prec_sy must be set properly; otherwise they are
// arbitrary (the check is skipped). if fused_quant == 2, only prec_sy must be set properly.
std::string prec_sm; // smooth scale, used for [1*N] input smooth quant
std::string prec_sy; // y-scale, used for [M*1] output for next layer
bool save_rms;
int fused_add; // 0:no-add, 1:pre-add-store, 2:pre-add
int fused_quant; // 0:no-quant, 1:smooth-dynamic-quant, 2:dynamic-quant
};
float rmsnorm2d_fwd(rmsnorm2d_fwd_traits, rmsnorm2d_fwd_args, const ck_tile::stream_config&);
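For orientation, a hypothetical call selecting the fused smooth-quant path could look like the following; the brace-initializer order follows the struct fields above and the stream_config shape follows the example's run() (variable names are illustrative, not generated API code):

// Hypothetical sketch: fp16 in, int8 out, smooth dynamic quant with pre-add-store;
// `args` is assumed to be a rmsnorm2d_fwd_args filled as in the example's run().
rmsnorm2d_fwd_traits traits{/*prec_i=*/"fp16", /*prec_o=*/"int8",
                            /*prec_sm=*/"fp32", /*prec_sy=*/"fp32",
                            /*save_rms=*/false, /*fused_add=*/1, /*fused_quant=*/1};
float avg_ms = rmsnorm2d_fwd(
    traits, args, ck_tile::stream_config{nullptr, true, /*log_level=*/1, /*warmup=*/5, /*repeat=*/20});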
#!/bin/sh
EXE="$(find . -name tile_rmsnorm2d_fwd -type f | head -n 1)"
for fquant in "" "-fquant=1 -prec_o=int8" "-fquant=2 -prec_o=int8"; do
for pr_i in "fp16" "bf16" ; do
for fadd in "0" "1"; do
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=99 -n=13
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=17 -n=16
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=100
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=4 -n=128
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=80 -n=127
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=22 -n=255 -x_stride=256
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=599
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=19 -n=512
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=33 -n=313 -x_stride=1000
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=11 -n=510
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=171 -n=676 -x_stride=818
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=91 -n=636
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=12 -n=768 -x_stride=800
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=100 -n=766 -x_stride=812
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=31 -n=1024
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=64 -n=1000 -x_stride=1004
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=8 -n=1501
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=1826
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=5 -n=2040
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=2734
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=3182
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=9 -n=4096
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=8192
#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=10547
#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134
done
done
done
@@ -63,17 +63,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
assert(x_stride >= n);
using XDataType = DataType;
using SmoothScaleDataType = float;
using YScaleDataType = float;
using QYDataType = ck_tile::int8_t;
using ComputeDataType = float;
// host verify
ck_tile::HostTensor<XDataType> x_host({m, n}, {x_stride, 1});
ck_tile::HostTensor<SmoothScaleDataType> smscale_host({n});
ck_tile::HostTensor<YScaleDataType> yscale_host_ref({m}, {1});
ck_tile::HostTensor<YScaleDataType> yscale_host_dev({m}, {1});
@@ -82,15 +82,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<QYDataType> qy_host_dev({m, n}, {y_stride, 1});
ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
ck_tile::FillUniformDistribution<SmoothScaleDataType>{1e-3, .5f}(smscale_host);
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem smscale_buf(smscale_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem yscale_buf(yscale_host_dev.get_element_space_size_in_bytes());
ck_tile::DeviceMem qy_buf(qy_host_dev.get_element_space_size_in_bytes());
x_buf.ToDevice(x_host.data());
smscale_buf.ToDevice(smscale_host.data());
constexpr bool kTwoPass = true;
@@ -101,7 +101,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
using Problem = ck_tile::SmoothquantPipelineProblem<XDataType,
SmoothScaleDataType,
ComputeDataType,
YScaleDataType,
QYDataType,
@@ -115,7 +115,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
using Kernel = ck_tile::Smoothquant<Pipeline>;
ck_tile::SmoothquantHostArgs args{x_buf.GetDeviceBuffer(),
smscale_buf.GetDeviceBuffer(),
yscale_buf.GetDeviceBuffer(),
qy_buf.GetDeviceBuffer(),
m,
@@ -142,16 +142,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
// smooth outlier
{
auto f = [&](auto n_) {
auto v_smscale = ck_tile::type_convert<ComputeDataType>(smscale_host(n_));
for(int m_ = 0; m_ < m; ++m_)
{
auto v_x = ck_tile::type_convert<ComputeDataType>(x_host(m_, n_));
y_host(m_, n_) = v_x * v_smscale;
}
};
ck_tile::make_ParallelTensorFunctor(f, smscale_host.get_element_space_size())(
std::thread::hardware_concurrency());
}
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "smoothquant.hpp"
@@ -35,7 +35,7 @@ float smoothquant_(const S& s, A a)
using PipelineProblem = ck_tile::SmoothquantPipelineProblem<
typename SmoothquantTypeConfig<DataType>::XDataType,
typename SmoothquantTypeConfig<DataType>::SmoothScaleDataType,
typename SmoothquantTypeConfig<DataType>::ComputeDataType,
typename SmoothquantTypeConfig<DataType>::YScaleDataType,
typename SmoothquantTypeConfig<DataType>::QYDataType,
...
@@ -66,15 +66,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
using TypeConfig = SmoothquantTypeConfig<DataType>;
using XDataType = typename TypeConfig::XDataType;
using SmoothScaleDataType = typename TypeConfig::SmoothScaleDataType;
using YScaleDataType = typename TypeConfig::YScaleDataType;
using QYDataType = typename TypeConfig::QYDataType;
using ComputeDataType = typename TypeConfig::ComputeDataType;
// host verify
ck_tile::HostTensor<XDataType> x_host({m, n}, {x_stride, 1});
ck_tile::HostTensor<SmoothScaleDataType> smscale_host({n});
ck_tile::HostTensor<YScaleDataType> yscale_host_ref({m}, {1});
ck_tile::HostTensor<YScaleDataType> yscale_host_dev({m}, {1});
@@ -83,15 +83,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<QYDataType> qy_host_dev({m, n}, {y_stride, 1});
ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
ck_tile::FillUniformDistribution<SmoothScaleDataType>{1e-3, .5f}(smscale_host);
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem smscale_buf(smscale_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem yscale_buf(yscale_host_dev.get_element_space_size_in_bytes());
ck_tile::DeviceMem qy_buf(qy_host_dev.get_element_space_size_in_bytes());
x_buf.ToDevice(x_host.data());
smscale_buf.ToDevice(smscale_host.data());
std::cout << "[" << data_type << "]"
<< " m:" << m << ", n:" << n << ", x_stride:" << x_stride << ", y_stride:" << y_stride
@@ -100,7 +100,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
smoothquant_traits traits{data_type};
smoothquant_args args{x_buf.GetDeviceBuffer(),
smscale_buf.GetDeviceBuffer(),
yscale_buf.GetDeviceBuffer(),
qy_buf.GetDeviceBuffer(),
m,
@@ -111,7 +111,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
float ave_time = smoothquant(
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(SmoothScaleDataType) * n +
sizeof(YScaleDataType) * m + sizeof(QYDataType) * m * n;
float gb_per_sec = num_byte / 1.E6 / ave_time;
@@ -126,16 +126,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
// smooth outlier
{
auto f = [&](auto n_) {
auto v_smscale = ck_tile::type_convert<ComputeDataType>(smscale_host(n_));
for(int m_ = 0; m_ < m; ++m_)
{
auto v_x = ck_tile::type_convert<ComputeDataType>(x_host(m_, n_));
y_host(m_, n_) = v_x * v_smscale;
}
};
ck_tile::make_ParallelTensorFunctor(f, smscale_host.get_element_space_size())(
std::thread::hardware_concurrency());
}
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -14,21 +14,21 @@ struct SmoothquantTypeConfig;
template <>
struct SmoothquantTypeConfig<ck_tile::half_t>
{
using XDataType = ck_tile::half_t;
using SmoothScaleDataType = float;
using YScaleDataType = float;
using QYDataType = ck_tile::int8_t;
using ComputeDataType = float;
};
template <>
struct SmoothquantTypeConfig<ck_tile::bf16_t>
{
using XDataType = ck_tile::bf16_t;
using SmoothScaleDataType = float;
using YScaleDataType = float;
using QYDataType = ck_tile::int8_t;
using ComputeDataType = float;
};
// runtime args
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "moe_smoothquant.hpp"
@@ -35,7 +35,7 @@ float moe_smoothquant_(const S& s, A a)
using PipelineProblem = ck_tile::SmoothquantPipelineProblem<
typename MoeSmoothquantTypeConfig<DataType>::XDataType,
typename MoeSmoothquantTypeConfig<DataType>::SmoothScaleDataType,
typename MoeSmoothquantTypeConfig<DataType>::ComputeDataType,
typename MoeSmoothquantTypeConfig<DataType>::YScaleDataType,
typename MoeSmoothquantTypeConfig<DataType>::QYDataType,
...
@@ -91,15 +91,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
using TypeConfig = MoeSmoothquantTypeConfig<DataType>;
using XDataType = typename TypeConfig::XDataType;
using SmoothScaleDataType = typename TypeConfig::SmoothScaleDataType;
using YScaleDataType = typename TypeConfig::YScaleDataType;
using QYDataType = typename TypeConfig::QYDataType;
using ComputeDataType = typename TypeConfig::ComputeDataType;
// host verify
ck_tile::HostTensor<XDataType> x_host({tokens, hidden_size}, {stride, 1});
ck_tile::HostTensor<SmoothScaleDataType> smscale_host({experts * hidden_size});
ck_tile::HostTensor<ck_tile::index_t> topk_ids_host({tokens, topk});
ck_tile::HostTensor<YScaleDataType> yscale_host_ref({topk * tokens}, {1});
@@ -110,16 +110,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
topid_unique_gen<ck_tile::index_t>(topk_ids_host.mData, tokens, topk, experts, 11937);
ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
ck_tile::FillUniformDistribution<SmoothScaleDataType>{1e-3, .5f}(smscale_host);
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem smscale_buf(smscale_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem topk_ids_buf(topk_ids_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem yscale_buf(yscale_host_dev.get_element_space_size_in_bytes());
ck_tile::DeviceMem qy_buf(qy_host_dev.get_element_space_size_in_bytes());
x_buf.ToDevice(x_host.data());
smscale_buf.ToDevice(smscale_host.data());
topk_ids_buf.ToDevice(topk_ids_host.data());
std::cout << "[" << data_type << "]"
@@ -129,7 +129,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
moe_smoothquant_traits traits{data_type};
moe_smoothquant_args args{x_buf.GetDeviceBuffer(),
smscale_buf.GetDeviceBuffer(),
topk_ids_buf.GetDeviceBuffer(),
yscale_buf.GetDeviceBuffer(),
qy_buf.GetDeviceBuffer(),
@@ -143,9 +143,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
float ave_time = moe_smoothquant(
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
std::size_t num_byte = sizeof(XDataType) * tokens * hidden_size +
sizeof(SmoothScaleDataType) * topk * hidden_size +
sizeof(YScaleDataType) * topk * tokens +
sizeof(QYDataType) * topk * tokens * hidden_size;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush;
@@ -165,11 +166,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
for(int i_h = 0; i_h < hidden_size; ++i_h)
{
auto v_smscale = ck_tile::type_convert<ComputeDataType>(
smscale_host(i_expert * hidden_size + i_h));
auto v_x = ck_tile::type_convert<ComputeDataType>(x_host(i_token, i_h));
// y_host(i_token * topk + i_topk, i_h) = v_x * v_smscale;
y_host(i_topk * tokens + i_token, i_h) = v_x * v_smscale;
}
}
};
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -14,21 +14,21 @@ struct MoeSmoothquantTypeConfig;
template <>
struct MoeSmoothquantTypeConfig<ck_tile::half_t>
{
using XDataType = ck_tile::half_t;
using SmoothScaleDataType = float;
using YScaleDataType = float;
using QYDataType = ck_tile::int8_t;
using ComputeDataType = float;
};
template <>
struct MoeSmoothquantTypeConfig<ck_tile::bf16_t>
{
using XDataType = ck_tile::bf16_t;
using SmoothScaleDataType = float;
using YScaleDataType = float;
using QYDataType = ck_tile::int8_t;
using ComputeDataType = float;
};
// runtime args
...
@@ -8,16 +8,40 @@
namespace ck_tile {
// Note: for simplicity, each functor only cares about a single row (m)
struct reference_rmsnorm2d_default_epilogue
{
template <typename OutDataType, typename AccDataType>
void operator()(int m, HostTensor<OutDataType>& o, const HostTensor<AccDataType>& acc)
{
const int N = acc.mDesc.get_lengths()[1];
for(int n = 0; n < N; ++n)
{
o(m, n) = ck_tile::type_convert<OutDataType>(acc(m, n));
}
}
template <typename OutDataType, typename AccDataType>
auto operator()(int m, const HostTensor<AccDataType>& acc)
{
HostTensor<OutDataType> o(acc.get_lengths(), acc.get_strides());
operator()(m, o, acc);
return o;
}
};
template <typename XDataType,
typename GammaDataType,
typename ComputeDataType,
typename YDataType,
typename InvRmsDataType,
typename Epilogue = reference_rmsnorm2d_default_epilogue>
void reference_rmsnorm2d_fwd(const HostTensor<XDataType>& x_m_n,
const HostTensor<GammaDataType>& gamma_n,
HostTensor<YDataType>& y_m_n,
HostTensor<InvRmsDataType>& invRms_m,
ComputeDataType epsilon,
Epilogue epilogue_functor = {})
{
auto rmsnorm2d_fwd_func = [&](auto m) {
const int N = x_m_n.mDesc.get_lengths()[1];
@@ -37,13 +61,15 @@ void reference_rmsnorm2d_fwd(const HostTensor<XDataType>& x_m_n,
if constexpr(!std::is_same_v<InvRmsDataType, ck_tile::null_type>)
invRms_m(m) = ck_tile::type_convert<InvRmsDataType>(divisor);
HostTensor<ComputeDataType> acc(x_m_n.get_lengths(), x_m_n.get_strides());
for(int n = 0; n < N; ++n)
{
ComputeDataType x = ck_tile::type_convert<ComputeDataType>(x_m_n(m, n));
ComputeDataType gamma = ck_tile::type_convert<ComputeDataType>(gamma_n(n));
acc(m, n) = x * divisor * gamma;
}
epilogue_functor(m, y_m_n, acc);
};
make_ParallelTensorFunctor(rmsnorm2d_fwd_func, invRms_m.mDesc.get_lengths()[0])(
...
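The epilogue hook above is what lets one reference implementation serve both the plain and the fused-quant paths: the default epilogue just converts acc to the output type, while the example's dquant_functor smooth-scales, computes the per-row absmax scale, and quantizes. A custom epilogue only needs the (m, out, acc) call shape; for instance (illustrative only, assuming the ck_tile host types shown above):

// Hypothetical custom epilogue: clamp the normalized row before the type convert.
struct clamp_epilogue
{
    template <typename OutDataType, typename AccDataType>
    void operator()(int m,
                    ck_tile::HostTensor<OutDataType>& o,
                    const ck_tile::HostTensor<AccDataType>& acc)
    {
        const int N = acc.mDesc.get_lengths()[1];
        for(int n = 0; n < N; ++n)
        {
            AccDataType v = acc(m, n);
            v = v > AccDataType{10} ? AccDataType{10}
                                    : (v < AccDataType{-10} ? AccDataType{-10} : v);
            o(m, n) = ck_tile::type_convert<OutDataType>(v);
        }
    }
};
// usage sketch:
// ck_tile::reference_rmsnorm2d_fwd<XDataType, GammaDataType, ComputeDataType,
//                                  YDataType, InvRmsDataType>(
//     x_host, gamma_host, y_host_ref, invRms_host_ref, epsilon, clamp_epilogue{});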