Commit 35ba0864 authored by aska-0096's avatar aska-0096
Browse files

fp8 add_rmsnorm_dynamic_dequant

parent 487a05d6
...@@ -6,8 +6,12 @@ ...@@ -6,8 +6,12 @@
// clang-format off // clang-format off
// rm rn tm tn vn pd x 3p // rm rn tm tn vn pd x 3p
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 8, true , true, false>>(const S&, A); template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 1, 4, 64, 8, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 4, true , true, false>>(const S&, A); template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 2, 4, 64, 4, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 4, 4, 64, 2, true , true, false>>(const S&, A); template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 4, 4, 64, 2, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 8, 4, 64, 1, true , true, false>>(const S&, A); template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 8, 4, 64, 1, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 1, 4, 64, 8, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 2, 4, 64, 4, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 4, 4, 64, 2, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 8, 4, 64, 1, true , true, false>>(const S&, A);
// clang-format on // clang-format on
...@@ -6,7 +6,10 @@ ...@@ -6,7 +6,10 @@
// clang-format off // clang-format off
// rm rn tm tn vn pd x 3p // rm rn tm tn vn pd x 3p
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 1, true , true, false>>(const S&, A); template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 1, 4, 64, 1, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 2, true , true, false>>(const S&, A); template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 1, 4, 64, 2, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 1, true , true, false>>(const S&, A); template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 2, 4, 64, 1, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 1, 4, 64, 1, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 1, 4, 64, 2, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 2, 4, 64, 1, true , true, false>>(const S&, A);
// clang-format on // clang-format on
...@@ -6,7 +6,10 @@ ...@@ -6,7 +6,10 @@
// clang-format off // clang-format off
// rm rn tm tn vn pd x 3p // rm rn tm tn vn pd x 3p
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 3, 4, 64, 4, true , true, false>>(const S&, A); template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 3, 4, 64, 4, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 6, 4, 64, 2, true , true, false>>(const S&, A); template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 6, 4, 64, 2, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 12, 4, 64, 1, true , true, false>>(const S&, A); template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 12, 4, 64, 1, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 3, 4, 64, 4, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 6, 4, 64, 2, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 12, 4, 64, 1, true , true, false>>(const S&, A);
// clang-format on // clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd x 3p
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 2, 1, 512, 8, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 4, 1, 512, 4, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 4, 1, 1024, 2, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 8, 1, 1024, 1, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 2, 1, 512, 8, true, false, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 4, 1, 512, 4, true, false, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 4, 1, 1024, 2, true, false, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 8, 1, 1024, 1, true, false, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 2, 1, 512, 8, false, false, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 4, 1, 512, 4, false, false, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 4, 1, 1024, 2, false, false, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 8, 1, 1024, 1, false, false, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 2, 1, 512, 8, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 4, 1, 512, 4, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 4, 1, 1024, 2, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 8, 1, 1024, 1, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 2, 1, 512, 8, true, false, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 4, 1, 512, 4, true, false, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 4, 1, 1024, 2, true, false, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 8, 1, 1024, 1, true, false, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 2, 1, 512, 8, false, false, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 4, 1, 512, 4, false, false, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 4, 1, 1024, 2, false, false, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 8, 1, 1024, 1, false, false, false>>(const S&, A);
// clang-format on
...@@ -6,9 +6,12 @@ ...@@ -6,9 +6,12 @@
// clang-format off // clang-format off
// rm rn tm tn vn pd x 3p // rm rn tm tn vn pd x 3p
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 2, 1, 256, 8, true, true, true>>(const S&, A); template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 2, 1, 512, 8, true, true, true>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 256, 4, true, true, true>>(const S&, A); template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 4, 1, 512, 4, true, true, true>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 2, 1, 1024, 2, true, true, true>>(const S&, A); template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 4, 1, 1024, 2, true, true, true>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 1024, 1, true, true, true>>(const S&, A); template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 8, 1, 1024, 1, true, true, true>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 2, 1, 512, 8, true, true, true>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 4, 1, 512, 4, true, true, true>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 4, 1, 1024, 2, true, true, true>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 8, 1, 1024, 1, true, true, true>>(const S&, A);
// clang-format on // clang-format on
...@@ -11,7 +11,8 @@ ...@@ -11,7 +11,8 @@
using S = ck_tile::stream_config; using S = ck_tile::stream_config;
using A = add_rmsnorm2d_rdquant_fwd_args; using A = add_rmsnorm2d_rdquant_fwd_args;
template <typename DataType_, template <typename InputDataType_,
typename QuantizedDataType_,
ck_tile::index_t Repeat_M_, // each thread repeat along M ck_tile::index_t Repeat_M_, // each thread repeat along M
ck_tile::index_t Repeat_N_, // each thread repeat along N ck_tile::index_t Repeat_N_, // each thread repeat along N
ck_tile::index_t ThreadPerBlock_M_, // num threads along M ck_tile::index_t ThreadPerBlock_M_, // num threads along M
...@@ -20,7 +21,8 @@ template <typename DataType_, ...@@ -20,7 +21,8 @@ template <typename DataType_,
bool kPadN_, bool kPadN_,
bool kSaveInvRms_, bool kSaveInvRms_,
bool kTwoPass_> bool kTwoPass_>
using trait_ = add_rmsnorm2d_rdquant_fwd_traits_<DataType_, using trait_ = add_rmsnorm2d_rdquant_fwd_traits_<InputDataType_,
QuantizedDataType_,
Repeat_M_, Repeat_M_,
Repeat_N_, Repeat_N_,
ThreadPerBlock_M_, ThreadPerBlock_M_,
...@@ -33,16 +35,17 @@ using trait_ = add_rmsnorm2d_rdquant_fwd_traits_<DataType_, ...@@ -33,16 +35,17 @@ using trait_ = add_rmsnorm2d_rdquant_fwd_traits_<DataType_,
template <typename Traits_> template <typename Traits_>
float add_rmsnorm2d_rdquant_fwd_(const S& s, A a) float add_rmsnorm2d_rdquant_fwd_(const S& s, A a)
{ {
using DataType = typename Traits_::DataType; using InputDataType = typename Traits_::InputDataType;
using QuantizedDataType = typename Traits_::QuantizedDataType;
using PipelineProblem = ck_tile::AddRmsnorm2dRdquantFwdPipelineProblem< using PipelineProblem = ck_tile::AddRmsnorm2dRdquantFwdPipelineProblem<
typename AddRmsnormRdquantTypeConfig<DataType>::ADataType, typename AddRmsnormRdquantTypeConfig<InputDataType, QuantizedDataType>::ADataType,
typename AddRmsnormRdquantTypeConfig<DataType>::BDataType, typename AddRmsnormRdquantTypeConfig<InputDataType, QuantizedDataType>::BDataType,
typename AddRmsnormRdquantTypeConfig<DataType>::GammaDataType, typename AddRmsnormRdquantTypeConfig<InputDataType, QuantizedDataType>::GammaDataType,
typename AddRmsnormRdquantTypeConfig<DataType>::ComputeDataType, typename AddRmsnormRdquantTypeConfig<InputDataType, QuantizedDataType>::ComputeDataType,
typename AddRmsnormRdquantTypeConfig<DataType>::XDataType, typename AddRmsnormRdquantTypeConfig<InputDataType, QuantizedDataType>::XDataType,
typename AddRmsnormRdquantTypeConfig<DataType>::YScaleDataType, typename AddRmsnormRdquantTypeConfig<InputDataType, QuantizedDataType>::YScaleDataType,
typename AddRmsnormRdquantTypeConfig<DataType>::QYDataType, typename AddRmsnormRdquantTypeConfig<InputDataType, QuantizedDataType>::QYDataType,
typename Traits_::Shape, typename Traits_::Shape,
Traits_::kPadN, Traits_::kPadN,
Traits_::kSaveX, Traits_::kSaveX,
......
...@@ -22,7 +22,7 @@ CK_TILE_HOST void reference_rowwise_quantization2d(const HostTensor<XDataType>& ...@@ -22,7 +22,7 @@ CK_TILE_HOST void reference_rowwise_quantization2d(const HostTensor<XDataType>&
// scale = amax / 127 for int8 // scale = amax / 127 for int8
auto v_scale = type_convert<XDataType>(scale_m(m)); auto v_scale = type_convert<XDataType>(scale_m(m));
auto v_qx = v_x / v_scale; auto v_qx = v_x / v_scale;
qx_m_n(m, n) = saturates<QXDataType>{}(v_qx); qx_m_n(m, n) = type_convert<QXDataType>(saturates<QXDataType>{}(v_qx));
} }
}; };
......
...@@ -157,7 +157,7 @@ struct AddRmsnorm2dRdquantFwdPipelineOnePass ...@@ -157,7 +157,7 @@ struct AddRmsnorm2dRdquantFwdPipelineOnePass
sweep_tile(qy, [&, yscale_ = yscale](auto idx) { sweep_tile(qy, [&, yscale_ = yscale](auto idx) {
constexpr auto i_idx = make_tuple(idx[number<0>{}]); constexpr auto i_idx = make_tuple(idx[number<0>{}]);
auto qy_ = y[idx] / yscale_[i_idx]; auto qy_ = y[idx] / yscale_[i_idx];
qy(idx) = saturates<QYDataType>{}(qy_); qy(idx) = type_convert<QYDataType>(saturates<QYDataType>{}(qy_));
}); });
store_tile(qy_window, qy); store_tile(qy_window, qy);
} }
......
...@@ -260,7 +260,7 @@ struct AddRmsnorm2dRdquantFwdPipelineThreePass ...@@ -260,7 +260,7 @@ struct AddRmsnorm2dRdquantFwdPipelineThreePass
const auto x_ = type_convert<ComputeDataType>(x[idx]); const auto x_ = type_convert<ComputeDataType>(x[idx]);
auto y_ = x_ * inv_rms[i_idx] * gamma_; auto y_ = x_ * inv_rms[i_idx] * gamma_;
auto qy_ = y_ / yscale[i_idx]; auto qy_ = y_ / yscale[i_idx];
qy(idx) = saturates<QYDataType>{}(qy_); qy(idx) = type_convert<QYDataType>(saturates<QYDataType>{}(qy_));
}); });
store_tile(qy_window, qy); store_tile(qy_window, qy);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment