Unverified commit 04dd3148 authored by ruanjm, committed by GitHub

[CK_TILE] Add Various Fusion Functions to RMSNorm (#1802)



* Add shortcut to RMSNorm

* Modify tests for the newly added RMSNorm shortcut

* Add fused parameters to the tests

* 1. Add YDataType. 2. Move rmsnorm2d_fwd_traits_ from rmsnorm2d_fwd.hpp to rmsnorm2d_fwd_api.cpp and rmsnorm2d_fwd_instance_common.hpp

* Support various strides and precisions.

* Add support for Epilogue

* Add fusion and epilogue support to the rmsnorm reference

* Modify rmsnorm example

* Refactor tests/examples

* Bug fix for newly added tests/examples

* Bug fix for new tests 2

* Modify smoke test scripts

* Remove debug code

* Support non-smooth dynamic quant

* Update Rmsnorm2dFwd::GetName()

* Rename xscale and prec_sx to smoothscale and prec_sm

* Bug fix after rename

* Remove files

* Change example_rmsnorm2d_fwd.cpp

* Update performance calculator

* Fix issue in the two-pass pipeline when fused add is enabled

* Remove comment of beta

---------
Co-authored-by: rocking <ChunYu.Lai@amd.com>
parent c0b90f13
...@@ -59,7 +59,7 @@ args: ...@@ -59,7 +59,7 @@ args:
-kname print kernel name or not (default:1) -kname print kernel name or not (default:1)
-prec_i input precision (default:fp16) -prec_i input precision (default:fp16)
-prec_o output precision, set auto will be the same as input (default:auto) -prec_o output precision, set auto will be the same as input (default:auto)
-prec_sx output quant scale type, set auto will be the same as input. used when fquant=1 (default:auto) -prec_sm output quant scale type, set auto will be the same as input. used when fquant=1 (default:auto)
-prec_sy output quant scale type, set auto will be the same as input. used when fquant=1 or 2 (default:auto) -prec_sy output quant scale type, set auto will be the same as input. used when fquant=1 or 2 (default:auto)
-fadd fused-add, 0:no fused add, 1:preadd+store, 2:preadd only (default:0) -fadd fused-add, 0:no fused add, 1:preadd+store, 2:preadd only (default:0)
-fquant fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant (default:0) -fquant fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant (default:0)
...@@ -69,7 +69,7 @@ args: ...@@ -69,7 +69,7 @@ args:
```
## limitations
Note that `fquant=2`, `fadd=2`, and `prec_sm`/`prec_sy` other than `fp32` are not generated by default, though the kernel template supports them (TBD: add a flag in generate.py to generate those instances on demand). Besides, the `N>8192` case will by default use the two-pass pipeline, and `-fquant=1/2` is not supported there yet. If you need `N>8192` together with `fused+residual+store`, you can use this example together with `12_smoothquant` to construct layernorm+residual and smoothquant as two kernels.
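As a rough reference for what these flags compute, here is a minimal numpy sketch of the fused path (illustrative only, not the kernel's code; the int8 rule below assumes the usual per-row absmax/127 convention, and `x`, `residual`, `sm_scale` are placeholder names):
```python
import numpy as np

def layernorm2d_fused_ref(x, residual, gamma, beta, sm_scale, eps=1e-5, fadd=1, fquant=1):
    h = x + residual if fadd in (1, 2) else x  # -fadd: pre-add the residual before normalization
    y_residual = h if fadd == 1 else None      # fadd=1 additionally stores the pre-norm sum

    mean = h.mean(axis=1, keepdims=True)
    var = h.var(axis=1, keepdims=True)
    y = (h - mean) / np.sqrt(var + eps) * gamma + beta

    if fquant == 1:                            # smooth-dynamic-quant: [1, N] input smooth scale
        y = y * sm_scale
    if fquant in (1, 2):                       # dynamic per-row int8 quant, [M, 1] y_scale output
        y_scale = np.abs(y).max(axis=1, keepdims=True) / 127.0
        y = np.clip(np.round(y / y_scale), -128, 127).astype(np.int8)
        return y, y_scale, y_residual
    return y, None, y_residual
```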
```
# some case
......
# SPDX-License-Identifier: MIT # SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation # generate kernel instances to speed up compilation
import argparse import argparse
...@@ -52,7 +52,7 @@ class layernorm_fwd_codegen: ...@@ -52,7 +52,7 @@ class layernorm_fwd_codegen:
// this is used to pattern-match internal kernel implementation, not to instantiate kernel // this is used to pattern-match internal kernel implementation, not to instantiate kernel
template <typename XDataType_, template <typename XDataType_,
typename YDataType_, typename YDataType_,
typename XScaleDataType_, typename SmoothScaleDataType_,
typename YScaleDataType_, typename YScaleDataType_,
ck_tile::index_t Repeat_M_, // each thread repeat along M ck_tile::index_t Repeat_M_, // each thread repeat along M
ck_tile::index_t Repeat_N_, // each thread repeat along N ck_tile::index_t Repeat_N_, // each thread repeat along N
...@@ -71,7 +71,7 @@ struct layernorm2d_fwd_traits_ ...@@ -71,7 +71,7 @@ struct layernorm2d_fwd_traits_
{ {
using XDataType = ck_tile::remove_cvref_t<XDataType_>; using XDataType = ck_tile::remove_cvref_t<XDataType_>;
using YDataType = ck_tile::remove_cvref_t<YDataType_>; using YDataType = ck_tile::remove_cvref_t<YDataType_>;
using XScaleDataType = ck_tile::remove_cvref_t<XScaleDataType_>; using SmoothScaleDataType = ck_tile::remove_cvref_t<SmoothScaleDataType_>;
using YScaleDataType = ck_tile::remove_cvref_t<YScaleDataType_>; using YScaleDataType = ck_tile::remove_cvref_t<YScaleDataType_>;
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize; static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize;
...@@ -135,7 +135,7 @@ struct layernorm2d_fwd_traits_ ...@@ -135,7 +135,7 @@ struct layernorm2d_fwd_traits_
template <typename XDataType_, template <typename XDataType_,
typename YDataType_, typename YDataType_,
typename XScaleDataType_, typename SmoothScaleDataType_,
typename YScaleDataType_, typename YScaleDataType_,
ck_tile::index_t Repeat_M_, // each thread repeat along M ck_tile::index_t Repeat_M_, // each thread repeat along M
ck_tile::index_t Repeat_N_, // each thread repeat along N ck_tile::index_t Repeat_N_, // each thread repeat along N
...@@ -152,7 +152,7 @@ template <typename XDataType_, ...@@ -152,7 +152,7 @@ template <typename XDataType_,
int kFusedQuant_> int kFusedQuant_>
using traits_ = layernorm2d_fwd_traits_<XDataType_, using traits_ = layernorm2d_fwd_traits_<XDataType_,
YDataType_, YDataType_,
XScaleDataType_, SmoothScaleDataType_,
YScaleDataType_, YScaleDataType_,
Repeat_M_, Repeat_M_,
Repeat_N_, Repeat_N_,
...@@ -170,7 +170,7 @@ using traits_ = layernorm2d_fwd_traits_<XDataType_, ...@@ -170,7 +170,7 @@ using traits_ = layernorm2d_fwd_traits_<XDataType_,
""" """
API_COMMON_HEADER = """ API_COMMON_HEADER = """
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp> #include <ck_tile/core.hpp>
#include "layernorm2d_fwd.hpp" #include "layernorm2d_fwd.hpp"
...@@ -189,9 +189,9 @@ float layernorm2d_fwd_(const S& s, A a) ...@@ -189,9 +189,9 @@ float layernorm2d_fwd_(const S& s, A a)
{{ {{
using XDataType = typename Traits_::XDataType; using XDataType = typename Traits_::XDataType;
using YDataType = typename Traits_::YDataType; using YDataType = typename Traits_::YDataType;
using XScaleDataType = typename Traits_::XScaleDataType; using SmoothScaleDataType = typename Traits_::SmoothScaleDataType;
using YScaleDataType = typename Traits_::YScaleDataType; using YScaleDataType = typename Traits_::YScaleDataType;
using ComputeDataType = typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::ComputeDataType; using ComputeDataType = typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::ComputeDataType;
using PipelineTraits = ck_tile::Layernorm2dFwdTraits<Traits_::kPadN, using PipelineTraits = ck_tile::Layernorm2dFwdTraits<Traits_::kPadN,
Traits_::kSaveMeanInvStd, Traits_::kSaveMeanInvStd,
...@@ -202,16 +202,16 @@ float layernorm2d_fwd_(const S& s, A a) ...@@ -202,16 +202,16 @@ float layernorm2d_fwd_(const S& s, A a)
static_cast<ck_tile::Layernorm2dFusedAddEnum>(Traits_::kFusedAdd), static_cast<ck_tile::Layernorm2dFusedAddEnum>(Traits_::kFusedAdd),
static_cast<ck_tile::Layernorm2dFusedQuantEnum>(Traits_::kFusedQuant)>; static_cast<ck_tile::Layernorm2dFusedQuantEnum>(Traits_::kFusedQuant)>;
using PipelineProblem = ck_tile::Layernorm2dFwdPipelineProblem< using PipelineProblem = ck_tile::Layernorm2dFwdPipelineProblem<
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::XDataType, typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::XDataType,
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::XBiasDataType, typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::XBiasDataType,
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::GammaDataType, typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::GammaDataType,
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::BetaDataType, typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::BetaDataType,
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::ComputeDataType, typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::ComputeDataType,
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::YDataType, typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::YDataType,
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::MeanDataType, typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::MeanDataType,
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::InvStdDataType, typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::InvStdDataType,
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::XScaleDataType, typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::SmoothScaleDataType,
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::YScaleDataType, typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::YScaleDataType,
typename Traits_::Shape, typename Traits_::Shape,
PipelineTraits>; PipelineTraits>;
...@@ -224,7 +224,7 @@ float layernorm2d_fwd_(const S& s, A a) ...@@ -224,7 +224,7 @@ float layernorm2d_fwd_(const S& s, A a)
static constexpr bool UseSmoothInputScale = Traits_::kFusedQuant == 1; static constexpr bool UseSmoothInputScale = Traits_::kFusedQuant == 1;
static constexpr bool UseRawStore = sizeof(YDataType) == 4; static constexpr bool UseRawStore = sizeof(YDataType) == 4;
using DynamicQuantEpilogueProblem = ck_tile::DynamicQuantEpilogueProblem<ComputeDataType, XScaleDataType, YScaleDataType, YDataType, typename Traits_::Shape, using DynamicQuantEpilogueProblem = ck_tile::DynamicQuantEpilogueProblem<ComputeDataType, SmoothScaleDataType, YScaleDataType, YDataType, typename Traits_::Shape,
ck_tile::DynamicQuantEpilogueTraits<false, Traits_::kPadN, UseSmoothInputScale, UseRawStore, true/*max3*/>>; ck_tile::DynamicQuantEpilogueTraits<false, Traits_::kPadN, UseSmoothInputScale, UseRawStore, true/*max3*/>>;
using DynamicQuantEpilogue = ck_tile::DynamicQuantEpilogue<DynamicQuantEpilogueProblem>; using DynamicQuantEpilogue = ck_tile::DynamicQuantEpilogue<DynamicQuantEpilogueProblem>;
...@@ -249,7 +249,7 @@ float layernorm2d_fwd_(const S& s, A a) ...@@ -249,7 +249,7 @@ float layernorm2d_fwd_(const S& s, A a)
API_BASE = """ API_BASE = """
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp> #include <ck_tile/core.hpp>
#include "layernorm2d_fwd.hpp" #include "layernorm2d_fwd.hpp"
...@@ -285,7 +285,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, ...@@ -285,7 +285,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
INSTANCE_BASE = """ INSTANCE_BASE = """
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_api_common.hpp" #include "layernorm2d_fwd_api_common.hpp"
...@@ -374,7 +374,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, ...@@ -374,7 +374,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
class h_traits: class h_traits:
F_XDataType : str F_XDataType : str
F_YDataType : str F_YDataType : str
F_XScaleDataType : str F_SmoothScaleDataType : str
F_YScaleDataType : str F_YScaleDataType : str
F_Repeat_M : int F_Repeat_M : int
F_Repeat_N : int F_Repeat_N : int
...@@ -392,7 +392,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, ...@@ -392,7 +392,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
@property @property
def trait_name(self) ->str: def trait_name(self) ->str:
t_ = f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {DATA_TYPE_MAP[self.F_XScaleDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}' t_ = f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {DATA_TYPE_MAP[self.F_SmoothScaleDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}'
t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveMeanInvStd_):5}, {BOOL_MAP(self.F_kFastFDiv_):5}, {BOOL_MAP(self.F_kWelford_):5}' t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveMeanInvStd_):5}, {BOOL_MAP(self.F_kFastFDiv_):5}, {BOOL_MAP(self.F_kWelford_):5}'
t_ += f', {BOOL_MAP(self.F_kTwoPass_):5}, {self.F_kXbias:4}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}' t_ += f', {BOOL_MAP(self.F_kTwoPass_):5}, {self.F_kXbias:4}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}'
return t_ return t_
...@@ -477,8 +477,8 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, ...@@ -477,8 +477,8 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
if ins.F_kFusedQuant == 0: if ins.F_kFusedQuant == 0:
_sweep_cond = 't.fused_quant == {f_fused_sweep}'.format(f_fused_sweep = ins.F_kFusedQuant) _sweep_cond = 't.fused_quant == {f_fused_sweep}'.format(f_fused_sweep = ins.F_kFusedQuant)
elif ins.F_kFusedQuant == 1: elif ins.F_kFusedQuant == 1:
_sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sx == \"{f_sx_type}\" && t.prec_sy == \"{f_sy_type}\")'.format( _sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sm == \"{f_sx_type}\" && t.prec_sy == \"{f_sy_type}\")'.format(
f_fused_sweep = ins.F_kFusedQuant, f_sx_type=ins.F_XScaleDataType, f_sy_type=ins.F_YScaleDataType) f_fused_sweep = ins.F_kFusedQuant, f_sx_type=ins.F_SmoothScaleDataType, f_sy_type=ins.F_YScaleDataType)
elif ins.F_kFusedQuant == 2: elif ins.F_kFusedQuant == 2:
_sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sy == \"{f_sy_type}\")'.format( _sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sy == \"{f_sy_type}\")'.format(
f_fused_sweep = ins.F_kFusedQuant, f_sy_type=ins.F_YScaleDataType) f_fused_sweep = ins.F_kFusedQuant, f_sy_type=ins.F_YScaleDataType)
...@@ -572,7 +572,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, ...@@ -572,7 +572,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
current_n = hs[0].F_Repeat_N * hs[0].F_ThreadPerBlock_N * hs[0].F_Vector_N current_n = hs[0].F_Repeat_N * hs[0].F_ThreadPerBlock_N * hs[0].F_Vector_N
for dtype, scale_type, xbias, fused_add, fused_quant in itertools.product(dtype_list, scale_list, xbias_list, fused_add_list, fused_sweep_list): for dtype, scale_type, xbias, fused_add, fused_quant in itertools.product(dtype_list, scale_list, xbias_list, fused_add_list, fused_sweep_list):
prec_i, prec_o = dtype.split(',') prec_i, prec_o = dtype.split(',')
scale_x, scale_y = scale_type.split(',') scale_sm, scale_y = scale_type.split(',')
if prec_o in dynamic_quant_out_dtype and fused_quant != 1: if prec_o in dynamic_quant_out_dtype and fused_quant != 1:
continue # skip non dynamic quant case continue # skip non dynamic quant case
if fused_quant == 1 and hs_key == 'big': if fused_quant == 1 and hs_key == 'big':
...@@ -582,8 +582,8 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, ...@@ -582,8 +582,8 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
h_ = copy.copy(chs_) # copy the base instance out h_ = copy.copy(chs_) # copy the base instance out
h_.F_XDataType = prec_i h_.F_XDataType = prec_i
h_.F_YDataType = prec_o h_.F_YDataType = prec_o
h_.F_XScaleDataType = scale_y h_.F_SmoothScaleDataType = scale_sm
h_.F_YScaleDataType = scale_x h_.F_YScaleDataType = scale_y
h_.F_kXbias = xbias h_.F_kXbias = xbias
h_.F_kFusedAdd = fused_add h_.F_kFusedAdd = fused_add
h_.F_kFusedQuant = fused_quant h_.F_kFusedQuant = fused_quant
......
...@@ -35,7 +35,7 @@ auto create_args(int argc, char* argv[]) ...@@ -35,7 +35,7 @@ auto create_args(int argc, char* argv[])
.insert("kname", "1", "print kernel name or not") .insert("kname", "1", "print kernel name or not")
.insert("prec_i", "fp16", "input precision") .insert("prec_i", "fp16", "input precision")
.insert("prec_o", "auto", "output precision, set auto will be the same as input") .insert("prec_o", "auto", "output precision, set auto will be the same as input")
.insert("prec_sx", .insert("prec_sm",
"auto", "auto",
"output quant scale type, set auto will use fp32. used when fquant=1") "output quant scale type, set auto will use fp32. used when fquant=1")
.insert("prec_sy", .insert("prec_sy",
...@@ -53,7 +53,7 @@ auto create_args(int argc, char* argv[]) ...@@ -53,7 +53,7 @@ auto create_args(int argc, char* argv[])
template <typename InDataType, template <typename InDataType,
typename OutDataType, typename OutDataType,
typename XScaleDataType, typename SmoothScaleDataType,
typename YScaleDataType, typename YScaleDataType,
bool SaveMeanVar> bool SaveMeanVar>
bool run(const ck_tile::ArgParser& arg_parser) bool run(const ck_tile::ArgParser& arg_parser)
...@@ -75,15 +75,15 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -75,15 +75,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
float epsilon = arg_parser.get_float("e"); float epsilon = arg_parser.get_float("e");
std::string prec_i = arg_parser.get_str("prec_i"); std::string prec_i = arg_parser.get_str("prec_i");
std::string prec_o = arg_parser.get_str("prec_o"); std::string prec_o = arg_parser.get_str("prec_o");
std::string prec_sx = arg_parser.get_str("prec_sx"); std::string prec_sm = arg_parser.get_str("prec_sm");
std::string prec_sy = arg_parser.get_str("prec_sy"); std::string prec_sy = arg_parser.get_str("prec_sy");
if(prec_o == "auto") if(prec_o == "auto")
{ {
prec_o = prec_i; prec_o = prec_i;
} }
if(prec_sx == "auto") if(prec_sm == "auto")
{ {
prec_sx = "fp32"; prec_sm = "fp32";
} }
if(prec_sy == "auto") if(prec_sy == "auto")
{ {
...@@ -105,7 +105,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -105,7 +105,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
assert(x_stride >= n); assert(x_stride >= n);
using TypeConfig = LayerNormTypeConfig<InDataType, OutDataType, XScaleDataType, YScaleDataType>; using TypeConfig =
LayerNormTypeConfig<InDataType, OutDataType, SmoothScaleDataType, YScaleDataType>;
using XDataType = typename TypeConfig::XDataType; using XDataType = typename TypeConfig::XDataType;
using YDataType = typename TypeConfig::YDataType; using YDataType = typename TypeConfig::YDataType;
...@@ -139,12 +140,12 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -139,12 +140,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<YScaleDataType> y_scale_host_ref({m}); ck_tile::HostTensor<YScaleDataType> y_scale_host_ref({m});
ck_tile::HostTensor<YScaleDataType> y_scale_host_dev({m}); ck_tile::HostTensor<YScaleDataType> y_scale_host_dev({m});
ck_tile::HostTensor<XScaleDataType> x_scale_host({n}); ck_tile::HostTensor<SmoothScaleDataType> sm_scale_host({n});
ck_tile::HostTensor<XScaleDataType> x_scale_host_dev({n}); ck_tile::HostTensor<SmoothScaleDataType> sm_scale_host_dev({n});
ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host); ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
ck_tile::FillUniformDistribution<XResidualDataType>{-.5f, .5f}(x_residual_host); ck_tile::FillUniformDistribution<XResidualDataType>{-.5f, .5f}(x_residual_host);
ck_tile::FillUniformDistribution<XScaleDataType>{-1.f, 1.f}(x_scale_host); ck_tile::FillUniformDistribution<SmoothScaleDataType>{-1.f, 1.f}(sm_scale_host);
ck_tile::FillUniformDistribution<XBiasDataType>{-.5f, .5f}(x_bias_host); ck_tile::FillUniformDistribution<XBiasDataType>{-.5f, .5f}(x_bias_host);
ck_tile::FillUniformDistribution<GammaDataType>{-.5f, .5f}(gamma_host); ck_tile::FillUniformDistribution<GammaDataType>{-.5f, .5f}(gamma_host);
ck_tile::FillUniformDistribution<BetaDataType>{-.5f, .5f}(beta_host); ck_tile::FillUniformDistribution<BetaDataType>{-.5f, .5f}(beta_host);
...@@ -155,7 +156,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -155,7 +156,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::DeviceMem beta_buf(beta_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem beta_buf(beta_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes()); ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes());
ck_tile::DeviceMem y_scale_buf(y_scale_host_dev.get_element_space_size_in_bytes()); ck_tile::DeviceMem y_scale_buf(y_scale_host_dev.get_element_space_size_in_bytes());
ck_tile::DeviceMem x_scale_buf(x_scale_host_dev.get_element_space_size_in_bytes()); ck_tile::DeviceMem sm_scale_buf(sm_scale_host_dev.get_element_space_size_in_bytes());
ck_tile::DeviceMem x_residual_buf(x_residual_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem x_residual_buf(x_residual_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem y_residual_buf(y_residual_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem y_residual_buf(y_residual_host.get_element_space_size_in_bytes());
...@@ -165,7 +166,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -165,7 +166,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
gamma_buf.ToDevice(gamma_host.data()); gamma_buf.ToDevice(gamma_host.data());
beta_buf.ToDevice(beta_host.data()); beta_buf.ToDevice(beta_host.data());
x_residual_buf.ToDevice(x_residual_host.data()); x_residual_buf.ToDevice(x_residual_host.data());
x_scale_buf.ToDevice(x_scale_host.data()); sm_scale_buf.ToDevice(sm_scale_host.data());
auto prec_str = [&]() { auto prec_str = [&]() {
auto base_str = prec_i; auto base_str = prec_i;
...@@ -186,11 +187,11 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -186,11 +187,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
<< ", yr_stride:" << yr_stride << std::flush; << ", yr_stride:" << yr_stride << std::flush;
layernorm2d_fwd_traits traits{ layernorm2d_fwd_traits traits{
prec_i, prec_o, prec_sx, prec_sy, SaveMeanVar, xbias, fused_add, fused_quant}; prec_i, prec_o, prec_sm, prec_sy, SaveMeanVar, xbias, fused_add, fused_quant};
layernorm2d_fwd_args args{x_buf.GetDeviceBuffer(), layernorm2d_fwd_args args{x_buf.GetDeviceBuffer(),
fused_add != 0 ? x_residual_buf.GetDeviceBuffer() : nullptr, fused_add != 0 ? x_residual_buf.GetDeviceBuffer() : nullptr,
fused_quant == 1 ? x_scale_buf.GetDeviceBuffer() : nullptr, fused_quant == 1 ? sm_scale_buf.GetDeviceBuffer() : nullptr,
x_bias_buf.GetDeviceBuffer(), x_bias_buf.GetDeviceBuffer(),
gamma_buf.GetDeviceBuffer(), gamma_buf.GetDeviceBuffer(),
beta_buf.GetDeviceBuffer(), beta_buf.GetDeviceBuffer(),
...@@ -279,8 +280,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -279,8 +280,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
for(int n_ = 0; n_ < N_; n_++) for(int n_ = 0; n_ < N_; n_++)
{ {
// input smooth outlier // input smooth outlier
acc_(m_, n_) = acc_(m_, n_) = acc_(m_, n_) *
acc_(m_, n_) * ck_tile::type_convert<ComputeDataType>(x_scale_host(n_)); ck_tile::type_convert<ComputeDataType>(sm_scale_host(n_));
} }
} }
ComputeDataType absmax = static_cast<ComputeDataType>(0); ComputeDataType absmax = static_cast<ComputeDataType>(0);
...@@ -402,16 +403,16 @@ int main(int argc, char* argv[]) ...@@ -402,16 +403,16 @@ int main(int argc, char* argv[])
std::string prec_i = arg_parser.get_str("prec_i"); std::string prec_i = arg_parser.get_str("prec_i");
std::string prec_o = arg_parser.get_str("prec_o"); std::string prec_o = arg_parser.get_str("prec_o");
std::string prec_sx = arg_parser.get_str("prec_sx"); std::string prec_sm = arg_parser.get_str("prec_sm");
std::string prec_sy = arg_parser.get_str("prec_sy"); std::string prec_sy = arg_parser.get_str("prec_sy");
if(prec_o == "auto") if(prec_o == "auto")
{ {
prec_o = prec_i; prec_o = prec_i;
} }
if(prec_sx == "auto") if(prec_sm == "auto")
{ {
prec_sx = "fp32"; prec_sm = "fp32";
} }
if(prec_sy == "auto") if(prec_sy == "auto")
{ {
...@@ -420,33 +421,33 @@ int main(int argc, char* argv[]) ...@@ -420,33 +421,33 @@ int main(int argc, char* argv[])
int save_mv = arg_parser.get_int("save_mv"); int save_mv = arg_parser.get_int("save_mv");
// no dynamic quant case // no dynamic quant case
if(prec_i == "fp16" && prec_o == "fp16" && prec_sx == "fp32" && prec_sy == "fp32" && save_mv) if(prec_i == "fp16" && prec_o == "fp16" && prec_sm == "fp32" && prec_sy == "fp32" && save_mv)
{ {
return run<ck_tile::half_t, ck_tile::half_t, float, float, true>(arg_parser) ? 0 : -2; return run<ck_tile::half_t, ck_tile::half_t, float, float, true>(arg_parser) ? 0 : -2;
} }
else if(prec_i == "fp16" && prec_o == "fp16" && prec_sx == "fp32" && prec_sy == "fp32" && else if(prec_i == "fp16" && prec_o == "fp16" && prec_sm == "fp32" && prec_sy == "fp32" &&
!save_mv) !save_mv)
{ {
return run<ck_tile::half_t, ck_tile::half_t, float, float, false>(arg_parser) ? 0 : -2; return run<ck_tile::half_t, ck_tile::half_t, float, float, false>(arg_parser) ? 0 : -2;
} }
else if(prec_i == "bf16" && prec_o == "bf16" && prec_sx == "fp32" && prec_sy == "fp32" && else if(prec_i == "bf16" && prec_o == "bf16" && prec_sm == "fp32" && prec_sy == "fp32" &&
save_mv) save_mv)
{ {
return run<ck_tile::bf16_t, ck_tile::bf16_t, float, float, true>(arg_parser) ? 0 : -2; return run<ck_tile::bf16_t, ck_tile::bf16_t, float, float, true>(arg_parser) ? 0 : -2;
} }
else if(prec_i == "bf16" && prec_o == "bf16" && prec_sx == "fp32" && prec_sy == "fp32" && else if(prec_i == "bf16" && prec_o == "bf16" && prec_sm == "fp32" && prec_sy == "fp32" &&
!save_mv) !save_mv)
{ {
return run<ck_tile::bf16_t, ck_tile::bf16_t, float, float, true>(arg_parser) ? 0 : -2; return run<ck_tile::bf16_t, ck_tile::bf16_t, float, float, true>(arg_parser) ? 0 : -2;
} }
// dynamic quant case, only in inference // dynamic quant case, only in inference
else if(prec_i == "fp16" && prec_o == "int8" && prec_sx == "fp32" && prec_sy == "fp32" && else if(prec_i == "fp16" && prec_o == "int8" && prec_sm == "fp32" && prec_sy == "fp32" &&
!save_mv) !save_mv)
{ {
return run<ck_tile::half_t, ck_tile::int8_t, float, float, false>(arg_parser) ? 0 : -2; return run<ck_tile::half_t, ck_tile::int8_t, float, float, false>(arg_parser) ? 0 : -2;
} }
else if(prec_i == "bf16" && prec_o == "int8" && prec_sx == "fp32" && prec_sy == "fp32" && else if(prec_i == "bf16" && prec_o == "int8" && prec_sm == "fp32" && prec_sy == "fp32" &&
!save_mv) !save_mv)
{ {
return run<ck_tile::bf16_t, ck_tile::int8_t, float, float, false>(arg_parser) ? 0 : -2; return run<ck_tile::bf16_t, ck_tile::int8_t, float, float, false>(arg_parser) ? 0 : -2;
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -8,37 +8,40 @@ ...@@ -8,37 +8,40 @@
#include "ck_tile/ops/layernorm2d.hpp" #include "ck_tile/ops/layernorm2d.hpp"
#include <string> #include <string>
template <typename InType, typename OutType, typename XScaleDataType_, typename YScaleDataType_> template <typename InType,
typename OutType,
typename SmoothScaleDataType_,
typename YScaleDataType_>
struct LayerNormTypeConfig; struct LayerNormTypeConfig;
template <typename OutType, typename XScaleDataType_, typename YScaleDataType_> template <typename OutType, typename SmoothScaleDataType_, typename YScaleDataType_>
struct LayerNormTypeConfig<ck_tile::half_t, OutType, XScaleDataType_, YScaleDataType_> struct LayerNormTypeConfig<ck_tile::half_t, OutType, SmoothScaleDataType_, YScaleDataType_>
{ {
using XDataType = ck_tile::half_t; using XDataType = ck_tile::half_t;
using YDataType = OutType; using YDataType = OutType;
using XBiasDataType = ck_tile::half_t; using XBiasDataType = ck_tile::half_t;
using GammaDataType = ck_tile::half_t; using GammaDataType = ck_tile::half_t;
using BetaDataType = ck_tile::half_t; using BetaDataType = ck_tile::half_t;
using MeanDataType = ck_tile::half_t; using MeanDataType = ck_tile::half_t;
using InvStdDataType = ck_tile::half_t; using InvStdDataType = ck_tile::half_t;
using ComputeDataType = float; using ComputeDataType = float;
using XScaleDataType = XScaleDataType_; using SmoothScaleDataType = SmoothScaleDataType_;
using YScaleDataType = YScaleDataType_; using YScaleDataType = YScaleDataType_;
}; };
template <typename OutType, typename XScaleDataType_, typename YScaleDataType_> template <typename OutType, typename SmoothScaleDataType_, typename YScaleDataType_>
struct LayerNormTypeConfig<ck_tile::bf16_t, OutType, XScaleDataType_, YScaleDataType_> struct LayerNormTypeConfig<ck_tile::bf16_t, OutType, SmoothScaleDataType_, YScaleDataType_>
{ {
using XDataType = ck_tile::bf16_t; using XDataType = ck_tile::bf16_t;
using YDataType = OutType; using YDataType = OutType;
using XBiasDataType = ck_tile::bf16_t; using XBiasDataType = ck_tile::bf16_t;
using GammaDataType = ck_tile::bf16_t; using GammaDataType = ck_tile::bf16_t;
using BetaDataType = ck_tile::bf16_t; using BetaDataType = ck_tile::bf16_t;
using MeanDataType = ck_tile::bf16_t; using MeanDataType = ck_tile::bf16_t;
using InvStdDataType = ck_tile::bf16_t; using InvStdDataType = ck_tile::bf16_t;
using ComputeDataType = float; using ComputeDataType = float;
using XScaleDataType = XScaleDataType_; using SmoothScaleDataType = SmoothScaleDataType_;
using YScaleDataType = YScaleDataType_; using YScaleDataType = YScaleDataType_;
}; };
// runtime args // runtime args
...@@ -52,10 +55,10 @@ struct layernorm2d_fwd_traits ...@@ -52,10 +55,10 @@ struct layernorm2d_fwd_traits
std::string prec_i; // input precision std::string prec_i; // input precision
std::string prec_o; // output precision std::string prec_o; // output precision
// if fused_quant == 1, need set prec_sx/prec_sy to proper string, otherwise can set // if fused_quant == 1, need set prec_sm/prec_sy to proper string, otherwise can set
// arbitrary(will skip check) if fused_quant == 2, need set prec_sy to proper string, otherwise // arbitrary(will skip check) if fused_quant == 2, need set prec_sy to proper string, otherwise
// can set arbitrary(will skip check) // can set arbitrary(will skip check)
std::string prec_sx; // x-scale, used for [1*N] input smooth quant std::string prec_sm; // x-scale, used for [1*N] input smooth quant
std::string prec_sy; // y-scale, used for [M*1] output for next layer std::string prec_sy; // y-scale, used for [M*1] output for next layer
bool save_mean_var; // bool save_mean_var; //
......
set(RMSNORM2D_FWD_KNOWN_APIS "fwd;bwd")
set(RMSNORM2D_FWD_ENABLE_APIS "fwd" CACHE STRING
"semicolon-separated list of APIs to generate (${RMSNORM2D_FWD_KNOWN_APIS}) & link, or \"all\".")
if(RMSNORM2D_FWD_ENABLE_APIS STREQUAL "all")
set(RMSNORM2D_FWD_ENABLE_APIS ${RMSNORM2D_FWD_KNOWN_APIS})
endif()
# generate a list of kernels, but do not actually emit files at config stage
execute_process(
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--api ${RMSNORM2D_FWD_ENABLE_APIS} --working_path ${CMAKE_CURRENT_BINARY_DIR} --list_blobs
RESULT_VARIABLE ret
)
if(ret AND NOT ret EQUAL 0)
message(FATAL_ERROR "Failed to generate kernels via Python. ${ret}")
endif()
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/rmsnorm2d_fwd_blobs.txt RMSNORM2D_FWD_GEN_BLOBS)
add_custom_command(
OUTPUT ${RMSNORM2D_FWD_GEN_BLOBS}
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--api ${RMSNORM2D_FWD_ENABLE_APIS} --working_path ${CMAKE_CURRENT_BINARY_DIR} --gen_blobs
)
set(TILE_RMSNORM2D_FWD "tile_rmsnorm2d_fwd") set(TILE_RMSNORM2D_FWD "tile_rmsnorm2d_fwd")
# not using add_example_executable() to add this target, since we don't want it to be
# included in "make all/install/check"
message("adding ${TILE_RMSNORM2D_FWD}") message("adding ${TILE_RMSNORM2D_FWD}")
file(GLOB INSTANCE_SRCS instances/*.cpp)
add_executable(${TILE_RMSNORM2D_FWD} EXCLUDE_FROM_ALL rmsnorm2d_fwd.cpp) add_executable(${TILE_RMSNORM2D_FWD} EXCLUDE_FROM_ALL rmsnorm2d_fwd.cpp)
target_include_directories(${TILE_RMSNORM2D_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) target_include_directories(${TILE_RMSNORM2D_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${TILE_RMSNORM2D_FWD} PRIVATE ${INSTANCE_SRCS}) target_sources(${TILE_RMSNORM2D_FWD} PRIVATE ${RMSNORM2D_FWD_GEN_BLOBS})
set(TILE_RMSNORM2D_FWD_COMPILE_OPTIONS) set(TILE_RMSNORM2D_FWD_COMPILE_OPTIONS)
......
#include "ck_tile/host.hpp" #include "ck_tile/host.hpp"
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/rmsnorm2d.hpp" #include "ck_tile/ops/rmsnorm2d.hpp"
#include <cstring> #include <cstring>
...@@ -36,10 +37,12 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -36,10 +37,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
assert(stride >= n); assert(stride >= n);
using XDataType = DataType; using XDataType = DataType;
using YDataType = DataType; using YDataType = DataType;
using GammaDataType = DataType; using GammaDataType = DataType;
using InvRmsDataType = ck_tile::null_type; using InvRmsDataType = ck_tile::null_type;
using SmoothScaleDataType = ck_tile::null_type;
using YScaleDataType = ck_tile::null_type;
using ComputeDataType = float; using ComputeDataType = float;
...@@ -68,30 +71,49 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -68,30 +71,49 @@ bool run(const ck_tile::ArgParser& arg_parser)
using BlockTile = ck_tile::sequence<2, 128>; using BlockTile = ck_tile::sequence<2, 128>;
using WarpTile = ck_tile::sequence<1, 64>; using WarpTile = ck_tile::sequence<1, 64>;
using Vector = ck_tile::sequence<1, 1>; using Vector = ck_tile::sequence<1, 1>;
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
using PipelineTraits =
ck_tile::Rmsnorm2dFwdTraits<true, // kPadN
false, // kSaveInvRms
kTwoPass,
ck_tile::Rmsnorm2dFusedAddEnum::NO_ADD, // fuse add
ck_tile::Rmsnorm2dFusedQuantEnum::NO_SWEEP>; // fuse quant
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
using Problem = ck_tile::Rmsnorm2dFwdPipelineProblem<XDataType, using Problem = ck_tile::Rmsnorm2dFwdPipelineProblem<XDataType,
GammaDataType, GammaDataType,
ComputeDataType, ComputeDataType,
YDataType, YDataType,
InvRmsDataType, InvRmsDataType,
SmoothScaleDataType,
YScaleDataType,
Shape, Shape,
true, // kPadN PipelineTraits>;
false, // kSaveInvRms
kTwoPass>;
using OnePassPipeline = ck_tile::Rmsnorm2dFwdPipelineOnePass<Problem>; using OnePassPipeline = ck_tile::Rmsnorm2dFwdPipelineOnePass<Problem>;
using TwoPassPipeline = ck_tile::Rmsnorm2dFwdPipelineTwoPass<Problem>; using TwoPassPipeline = ck_tile::Rmsnorm2dFwdPipelineTwoPass<Problem>;
using Pipeline = std::conditional_t<kTwoPass, TwoPassPipeline, OnePassPipeline>; using Pipeline = std::conditional_t<kTwoPass, TwoPassPipeline, OnePassPipeline>;
using Kernel = ck_tile::Rmsnorm2dFwd<Pipeline>;
using Default2DEpilogueProblem = ck_tile::
Default2DEpilogueProblem<ComputeDataType, YDataType, false, PipelineTraits::kPadN, false>;
using Default2DEpilogue = ck_tile::Default2DEpilogue<Default2DEpilogueProblem>;
using Kernel = ck_tile::Rmsnorm2dFwd<Pipeline, Default2DEpilogue>;
ck_tile::Rmsnorm2dFwdHostArgs args{x_buf.GetDeviceBuffer(), ck_tile::Rmsnorm2dFwdHostArgs args{x_buf.GetDeviceBuffer(),
nullptr,
nullptr,
gamma_buf.GetDeviceBuffer(), gamma_buf.GetDeviceBuffer(),
y_buf.GetDeviceBuffer(), y_buf.GetDeviceBuffer(),
nullptr, nullptr,
nullptr,
nullptr,
epsilon, epsilon,
m, m,
n, n,
stride,
stride,
stride,
stride}; stride};
auto kargs = Kernel::MakeKargs(args); auto kargs = Kernel::MakeKargs(args);
......
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation
import argparse
from enum import IntEnum
from pathlib import Path
import sys
from typing import List, Optional, Any
import functools
import itertools
import copy
from dataclasses import dataclass
def get_if_str(idx, total, last_else = True):
if idx == 0:
return 'if'
elif idx < total - 1:
return 'else if'
else:
if last_else:
return 'else'
else:
return 'else if'
FUSED_ADD_ENUM_STR_MAP = [
'no',
'pras', # pre-add + store residual
'pra' ] # pre-add only
FUSED_FUSED_SWEEP_STR_MAP = [
'no',
'sdquant', # smooth dynamic quant
'dquant' ] # dynamic quant (without sm_scale)
DATA_TYPE_MAP = {'fp32' : 'float',
'fp16' : 'ck_tile::fp16_t',
'bf16' : 'ck_tile::bf16_t',
'int8' : 'ck_tile::int8_t'}
def BOOL_MAP(b_) -> str:
if b_:
return 'true'
else:
return 'false'
class rmsnorm_fwd_codegen:
API_TRAITS_DEFINE = """
// this is used to pattern-match internal kernel implementation, not to instantiate kernel
template <typename XDataType_,
typename YDataType_,
typename SmoothScaleDataType_,
typename YScaleDataType_,
ck_tile::index_t Repeat_M_, // each thread repeat along M
ck_tile::index_t Repeat_N_, // each thread repeat along N
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
ck_tile::index_t Vector_N_, // vector size along N
bool kPadN_,
bool kSaveInvRms_,
bool kTwoPass_,
ck_tile::index_t kFusedAdd_ = 0,
ck_tile::index_t kFusedQuant_ = 0>
struct rmsnorm2d_fwd_traits_
{
using XDataType = ck_tile::remove_cvref_t<XDataType_>;
using YDataType = ck_tile::remove_cvref_t<YDataType_>;
using SmoothScaleDataType = ck_tile::remove_cvref_t<SmoothScaleDataType_>;
using YScaleDataType = ck_tile::remove_cvref_t<YScaleDataType_>;
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize;
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0);
static constexpr ck_tile::index_t total_warps =
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize;
// num of warps along m
static constexpr ck_tile::index_t BlockWarps_M = []() {
if constexpr(is_warp_per_row)
{
static_assert(warpSize % ThreadPerBlock_N_ == 0);
return total_warps * (warpSize / ThreadPerBlock_N_);
}
else
{
// static_assert(warpSize % ThreadPerBlock_M_ == 0);
return total_warps / (ThreadPerBlock_N_ / warpSize);
}
}();
// num of warps along n
static constexpr ck_tile::index_t BlockWarps_N = []() {
if constexpr(is_warp_per_row)
{
static_assert(warpSize % ThreadPerBlock_N_ == 0);
return 1;
}
else
{
static_assert(ThreadPerBlock_N_ % warpSize == 0);
return ThreadPerBlock_N_ / warpSize;
}
}();
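// Illustrative values (assuming warpSize == 64; not taken from the source):
//   ThreadPerBlock 2x128 -> is_warp_per_row=false, total_warps=4, BlockWarps_M=2, BlockWarps_N=2
//   ThreadPerBlock 8x8   -> is_warp_per_row=true,  total_warps=1, BlockWarps_M=8, BlockWarps_N=1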
static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
static constexpr ck_tile::index_t Repeat_N = Repeat_N_;
static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_;
static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_;
static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M;
static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_;
using BlockTile = ck_tile::sequence<Block_M, Block_N>;
using BlockWarps = ck_tile::sequence<BlockWarps_M, BlockWarps_N>;
using WarpTile = ck_tile::sequence<Warp_M, Warp_N>;
using Vector = ck_tile::sequence<1, Vector_N_>;
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
static constexpr bool kPadN = kPadN_;
static constexpr bool kSaveInvRms = kSaveInvRms_;
static constexpr bool kTwoPass = kTwoPass_;
static constexpr ck_tile::index_t kFusedAdd = kFusedAdd_;
static constexpr ck_tile::index_t kFusedQuant = kFusedQuant_;
};
template <typename XDataType_,
typename YDataType_,
typename SmoothScaleDataType_,
typename YScaleDataType_,
ck_tile::index_t Repeat_M_, // each thread repeat along M
ck_tile::index_t Repeat_N_, // each thread repeat along N
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
ck_tile::index_t Vector_N_, // vector size along N
bool kPadN_,
bool kSaveInvRms_,
bool kTwoPass_,
int kFusedAdd_,
int kFusedQuant_>
using traits_ = rmsnorm2d_fwd_traits_<XDataType_,
YDataType_,
SmoothScaleDataType_,
YScaleDataType_,
Repeat_M_,
Repeat_N_,
ThreadPerBlock_M_,
ThreadPerBlock_N_,
Vector_N_,
kPadN_,
kSaveInvRms_,
kTwoPass_,
kFusedAdd_,
kFusedQuant_>;
"""
API_COMMON_HEADER = """
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "rmsnorm2d_fwd.hpp"
#include <ck_tile/ops/epilogue.hpp>
#include <iostream>
#pragma once
using S = ck_tile::stream_config;
using A = rmsnorm2d_fwd_args;
{F_traits_define}
template <typename Traits_>
float rmsnorm2d_fwd_(const S& s, A a)
{{
using XDataType = typename Traits_::XDataType;
using YDataType = typename Traits_::YDataType;
using SmoothScaleDataType = typename Traits_::SmoothScaleDataType;
using YScaleDataType = typename Traits_::YScaleDataType;
using ComputeDataType = typename RmsnormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::ComputeDataType;
using PipelineTraits =
ck_tile::Rmsnorm2dFwdTraits<Traits_::kPadN,
Traits_::kSaveInvRms,
Traits_::kTwoPass,
static_cast<ck_tile::Rmsnorm2dFusedAddEnum>(Traits_::kFusedAdd),
static_cast<ck_tile::Rmsnorm2dFusedQuantEnum>(Traits_::kFusedQuant)>;
using PipelineProblem =
ck_tile::Rmsnorm2dFwdPipelineProblem<typename RmsnormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::XDataType,
typename RmsnormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::GammaDataType,
typename RmsnormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::ComputeDataType,
typename RmsnormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::YDataType,
typename RmsnormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::InvRmsDataType,
typename RmsnormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::SmoothScaleDataType,
typename RmsnormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::YScaleDataType,
typename Traits_::Shape,
PipelineTraits>;
using OnePassPipeline = ck_tile::Rmsnorm2dFwdPipelineOnePass<PipelineProblem>;
using TwoPassPipeline = ck_tile::Rmsnorm2dFwdPipelineTwoPass<PipelineProblem>;
using Pipeline = std::conditional_t<Traits_::kTwoPass, TwoPassPipeline, OnePassPipeline>;
using Default2DEpilogueProblem = ck_tile::Default2DEpilogueProblem<ComputeDataType, YDataType, false, Traits_::kPadN, false>;
using Default2DEpilogue = ck_tile::Default2DEpilogue<Default2DEpilogueProblem>;
static constexpr bool UseSmoothInputScale = Traits_::kFusedQuant == 1;
using DynamicQuantEpilogueProblem = ck_tile::DynamicQuantEpilogueProblem<ComputeDataType, SmoothScaleDataType, YScaleDataType, YDataType, typename Traits_::Shape,
ck_tile::DynamicQuantEpilogueTraits<false, Traits_::kPadN, UseSmoothInputScale, false, true/*max3*/>>;
using DynamicQuantEpilogue = ck_tile::DynamicQuantEpilogue<DynamicQuantEpilogueProblem>;
using Epilogue = std::conditional_t<Traits_::kFusedQuant != 0, DynamicQuantEpilogue, Default2DEpilogue>;
using Kernel = ck_tile::Rmsnorm2dFwd<Pipeline, Epilogue>;
const dim3 grids = Kernel::GridSize(a);
constexpr dim3 blocks = Kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 1;
auto kargs = Kernel::MakeKargs(a);
if(s.log_level_ > 0)
std::cout << ", " << Kernel::GetName() << std::flush;
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{{}}, grids, blocks, 0, kargs));
}}
"""
API_BASE = """
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "rmsnorm2d_fwd.hpp"
{F_traits_define}
// Note: this internal API is only declared here, not defined; otherwise it would block `make -j`
template <typename Traits_>
float rmsnorm2d_fwd_(const ck_tile::stream_config& s, rmsnorm2d_fwd_args a);
float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
rmsnorm2d_fwd_args a,
const ck_tile::stream_config& s)
{{
float r = -1;
{F_dispatch}
return r;
}}
"""
INSTANCE_BASE = """
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_api_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
{F_instance_def}
// clang-format on
"""
API_PER_DTYPE = """
{F_if}(t.prec_i == \"{F_i_type}\" && t.prec_o == \"{F_o_type}\"){{
{F_per_n_case}
}}
"""
API_PER_N_CASE = """
{F_if} {F_N_COND} {{
{F_inner_dispatch}
}}
"""
API_INNER_CASE = """
{F_if} {F_VEC_COND}
r={F_instance_func}(s, a);
"""
def __init__(self, working_path, kernel_filter):
self.working_path = working_path
self.kernel_filter = kernel_filter
class k_fused_add_enum(IntEnum):
F_NO_ADD = 0
F_PRE_ADD = 1
F_PRE_ADD_STORE_RESIDUAL = 2
class k_fused_sweep_enum(IntEnum):
F_NO_SWEEP = 0
F_RENORM = 1
F_DYNAMIC_QUANT = 2
@dataclass
class k_traits:
F_kPadN : bool
F_kSaveMeanInvStd : bool
F_kTwoPass : bool
F_kFusedAdd : Any
F_kFusedQuant : Any
@dataclass
class k_shape:
F_BlockTile : List[int]
F_WarpPerBlock : List[int]
F_WarpTile : List[int]
F_Vector_ : List[int]
@property
def F_BlockSize(self) -> int:
return functools.reduce(lambda a, b: a*b, self.F_WarpTile)
@dataclass
class k_problem:
F_XDataType : str
F_GammaDataType : str
F_ComputeDataType : str
F_YDataType : str
F_InvRmsDataType : str
F_BlockShape : str
F_Traits : Any #k_traits
@dataclass
class k_pipeline_one_pass:
F_Problem : Any #k_problem
@dataclass
class k_pipeline_two_pass:
F_Problem : Any #k_problem
@dataclass
class default_2d_epilogue_problem:
F_AccDataType : str
F_ODataType : str
F_kPadM : bool
F_kPadN : bool
@dataclass
class default_2d_epilogue:
F_problem : Any
@dataclass
class k_kernel:
F_pipeline : Any
F_epilogue : Any
@dataclass
class h_traits:
F_XDataType : str
F_YDataType : str
F_SmoothScaleDataType : str
F_YScaleDataType : str
F_Repeat_M : int
F_Repeat_N : int
F_ThreadPerBlock_M : int
F_ThreadPerBlock_N : int
F_Vector_N : int
F_kPadN : bool
F_kSaveInvRms : bool
F_kTwoPass : bool
F_kFusedAdd : int
F_kFusedQuant : int
@property
def trait_name(self) ->str:
t_ = f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {DATA_TYPE_MAP[self.F_SmoothScaleDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}'
t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveInvRms):5}'
t_ += f', {BOOL_MAP(self.F_kTwoPass):5}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}'
return t_
# string when calling this kernel
@property
def call_name(self) -> str:
return f'rmsnorm2d_fwd_<traits_<{self.trait_name}>>'
# string when define this kernel
@property
def def_name(self) -> str:
return f'template float rmsnorm2d_fwd_<traits_<{self.trait_name}>>(const S&, A);'
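# Illustrative only (hypothetical values): a trait such as
#   h_traits('fp16', 'fp16', 'fp32', 'fp32', 1, 2, 2, 128, 4, True, False, False, 1, 1)
# produces a call_name roughly like
#   rmsnorm2d_fwd_<traits_<ck_tile::fp16_t, ck_tile::fp16_t, float, float,  1,  2,  2,  128,  4, true , false, false,    1,    1>>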
# this class hold kernel under same source file
@dataclass
class h_instance:
F_DataTypePair : str
F_N : str
F_add : int
F_sweep : int
instance_list : List[Any] # List[h_traits]
@property
def name(self) -> str:
prec_i, prec_o = self.F_DataTypePair.split(',')
dtype_str = f'{prec_i}' if prec_i == prec_o else f'{prec_i}_{prec_o}'
nnn = f'rmsnorm2d_fwd_{dtype_str}_n{self.F_N}'
if self.F_add != 0:
nnn = nnn + '_' + FUSED_ADD_ENUM_STR_MAP[self.F_add]
if self.F_sweep != 0:
nnn = nnn + '_' + FUSED_FUSED_SWEEP_STR_MAP[self.F_sweep]
return nnn
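# Illustrative only (hypothetical values): F_DataTypePair='fp16,fp16', F_N=1024, F_add=1, F_sweep=1
# yields the blob name 'rmsnorm2d_fwd_fp16_n1024_pras_sdquant'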
@property
def instance_name(self) ->str:
return self.name
@property
def content(self) ->str:
instance_defs = ''
for ins in self.instance_list:
instance_defs += ins.def_name + '\n'
return rmsnorm_fwd_codegen.INSTANCE_BASE.format(F_instance_def=instance_defs)
@property
def name_api(self) -> str:
return 'rmsnorm2d_fwd_api'
@property
def name_common_header(self) -> str:
return 'rmsnorm2d_fwd_api_common'
@property
def content_api(self) -> str:
# 1 sort based on dtype
t_dtype_dict = dict()
blobs = self.get_blobs()
for blob in blobs:
if blob.F_DataTypePair not in t_dtype_dict:
t_dtype_dict[blob.F_DataTypePair] = {}
if blob.F_N not in t_dtype_dict[blob.F_DataTypePair]:
t_dtype_dict[blob.F_DataTypePair][blob.F_N] = []
t_dtype_dict[blob.F_DataTypePair][blob.F_N].append(blob)
d_str = ''
for i_d, dtype_ in enumerate(t_dtype_dict):
blob_per_t = t_dtype_dict[dtype_]
n_str = ''
for i_n, n_ in enumerate(blob_per_t):
blob_per_n = blob_per_t[n_]
inner_str = ""
for i_b, b_ in enumerate(blob_per_n):
# generate single kernel instance file
#vec_str = ""
for i_ins, ins in enumerate(b_.instance_list):
idx_in_n = i_b * len(b_.instance_list) + i_ins
len_in_n = len(blob_per_n) * len(b_.instance_list)
# _if = 'if' if i_ins == 0 else 'else if'
if ins.F_kFusedQuant == 0:
_sweep_cond = 't.fused_quant == {f_fused_sweep}'.format(f_fused_sweep = ins.F_kFusedQuant)
elif ins.F_kFusedQuant == 1:
_sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sm == \"{f_sx_type}\" && t.prec_sy == \"{f_sy_type}\")'.format(
f_fused_sweep = ins.F_kFusedQuant, f_sx_type=ins.F_SmoothScaleDataType, f_sy_type=ins.F_YScaleDataType)
elif ins.F_kFusedQuant == 2:
_sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sy == \"{f_sy_type}\")'.format(
f_fused_sweep = ins.F_kFusedQuant, f_sy_type=ins.F_YScaleDataType)
_cond = '((a.n % {f_vec_n} == 0) && (t.fused_add == {f_fused_add}) && ({f_sweep_cond}))'.format(
f_vec_n = ins.F_Vector_N, f_fused_add = ins.F_kFusedAdd,
f_sweep_cond = _sweep_cond)
inner_str += self.API_INNER_CASE.format(F_if = get_if_str(idx_in_n, len_in_n, False),
F_VEC_COND = _cond, F_instance_func=ins.call_name)
#inner_str = inner_str + vec_str
n_cnd = f'(a.n <= {n_})' if (i_n < len(blob_per_t) - 1) else ''
n_str += self.API_PER_N_CASE.format(F_if = get_if_str(i_n, len(blob_per_t)), F_N_COND=n_cnd, F_inner_dispatch=inner_str)
prec_i, prec_o = dtype_.split(',')
d_str += self.API_PER_DTYPE.format(F_if = get_if_str(i_d, len(t_dtype_dict), False), F_i_type=prec_i, F_o_type=prec_o, F_per_n_case=n_str)
api_base = self.API_BASE.format(F_traits_define=self.API_TRAITS_DEFINE, F_dispatch=d_str)
return api_base
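# Roughly what the generated dispatch looks like (illustrative, heavily trimmed):
#   if(t.prec_i == "fp16" && t.prec_o == "fp16"){
#       if (a.n <= 64) {
#           if ((a.n % 8 == 0) && (t.fused_add == 0) && (t.fused_quant == 0))
#               r=rmsnorm2d_fwd_<traits_<...>>(s, a);
#           else if ...
#       }
#       else if (a.n <= 128) { ... }
#       ...
#   }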
@property
def content_common_header(self) -> str:
return self.API_COMMON_HEADER.format(F_traits_define=self.API_TRAITS_DEFINE)
def get_blobs(self):
h_traits = rmsnorm_fwd_codegen.h_traits
h_instance = rmsnorm_fwd_codegen.h_instance
dynamic_quant_out_dtype = ['int8']
# some predefined support range
# (prec_i,prec_o) for simplicity this string will be used as key for dict
scale_list = [('fp32,fp32')]
dtype_list = [('fp16,fp16'), ('bf16,bf16'),
('fp16,int8'), ('bf16,int8')] # NOTE: only fused-dynamic-quant use int8 out
#fused_add_list = [0, 1, 2]
#fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused (smooth) dynamic quant
fused_add_list = [0, 1]
fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused (smooth) dynamic quant
# rm rn tm tn vn pd mv 2p add sweep
h_trait_dict = {'64' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 8, 8, 8, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 4, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 1, True, False, False, 0, 0)],
'128' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 8, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 2, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 1, True, False, False, 0, 0)],
'256' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 4, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 2, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 1, True, False, False, 0, 0)],
'512' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 8, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 4, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 2, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 8, 4, 64, 1, True, False, False, 0, 0)],
'768' : [ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 4, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 6, 4, 64, 2, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 12, 4, 64, 1, True, False, False, 0, 0)],
'1024' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 2, 128, 8, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 2, 128, 4, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 2, 128, 2, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 1, True, False, False, 0, 0)],
'1536' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 8, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 2, 128, 4, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 2, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 1, True, False, False, 0, 0)],
'2048' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1, 256, 8, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 4, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 2, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 8, 1, 256, 1, True, False, False, 0, 0)],
'3072' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 128, 8, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 4, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 2, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 1, True, False, False, 0, 0)],
'4096' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, False, 0, 0)],
'6144' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 8, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 512, 4, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 2, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1,1024, 1, True, False, False, 0, 0)],
'8192' :[ h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 8, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 512, 4, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 2, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 8, 1,1024, 1, True, False, False, 0, 0)],
'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, 0, 0)]}
total_blob = list()
for hs_key in h_trait_dict:
hs = h_trait_dict[hs_key]
current_n = hs[0].F_Repeat_N * hs[0].F_ThreadPerBlock_N * hs[0].F_Vector_N
for dtype, scale_type, fused_add, fused_quant in itertools.product(dtype_list, scale_list, fused_add_list, fused_sweep_list):
prec_i, prec_o = dtype.split(',')
scale_sm, scale_y = scale_type.split(',')
if prec_o in dynamic_quant_out_dtype and fused_quant != 1 and fused_quant != 2:
continue # skip non dynamic quant case
if (fused_quant == 1 or fused_quant == 2) and hs_key == 'big':
continue
current_hs = list()
for chs_ in hs:
h_ = copy.copy(chs_) # copy the base instance out
h_.F_XDataType = prec_i
h_.F_YDataType = prec_o
h_.F_SmoothScaleDataType = scale_sm
h_.F_YScaleDataType = scale_y
h_.F_kFusedAdd = fused_add
h_.F_kFusedQuant = fused_quant
current_hs.append(h_) # + "\n"
#f.write(str(f.parent / GEN_DIR / (blobs.api_common_header_
current_n_str = 'big' if hs_key == 'big' else current_n
total_blob.append(h_instance(dtype, current_n_str, fused_add, fused_quant, current_hs))
return total_blob
def list_blobs(self) -> None:
w_p = Path(self.working_path)
list_p = w_p / 'rmsnorm2d_fwd_blobs.txt'
blobs = self.get_blobs()
with list_p.open('w') as list_f:
# api related file
list_f.write(str(w_p / (self.name_api + ".cpp")) + "\n")
list_f.write(str(w_p / (self.name_common_header + ".hpp")) + "\n")
# kernel instance file
for b in blobs:
list_f.write(str(w_p / (b.name + ".cpp")) + "\n")
def gen_blobs(self) -> None:
w_p = Path(self.working_path)
(w_p / (self.name_api + ".cpp")).write_text(self.content_api)
(w_p / (self.name_common_header + ".hpp")).write_text(self.content_common_header)
blobs = self.get_blobs()
for b in blobs:
(w_p / (b.name + ".cpp")).write_text(b.content)
def list_blobs(args):
api_list = args.api.split(',')
for api in api_list:
if api == 'fwd':
rmsnorm_fwd_codegen(args.working_path, args.filter).list_blobs()
def gen_blobs(args):
api_list = args.api.split(',')
for api in api_list:
if api == 'fwd':
rmsnorm_fwd_codegen(args.working_path, args.filter).gen_blobs()
if __name__ == "__main__":
parser = argparse.ArgumentParser(
prog="generate",
description="gen API for CK rmsnorm kernel",
)
parser.add_argument(
"-a",
"--api",
default='fwd[all]',
required=False,
help="supply API(s) to generate (default: fwd). separated by comma."
)
# the directory for list_blobs/gen_blobs to write files into
parser.add_argument(
"-w",
"--working_path",
default="./",
required=False,
help="the path where all the blobs are going to be generated"
)
    # this script has 2 modes:
    # 1) list_blobs mode: writes a txt file listing every source file that will be generated.
    #    this is useful for build systems like cmake to construct source dependencies by
    #    reading the content of this file.
    # 2) gen_blobs mode: generates the actual kernel instances and the API source. Frameworks
    #    like FA only need to use this mode.
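    # typical invocations (hypothetical build directory shown for illustration, assuming this
    # file is run as generate.py):
    #   python3 generate.py -a fwd -w build/rmsnorm_kernels --list_blobs
    #       -> writes build/rmsnorm_kernels/rmsnorm2d_fwd_blobs.txt listing every file to generate
    #   python3 generate.py -a fwd -w build/rmsnorm_kernels --gen_blobs
    #       -> writes the api .cpp, the common .hpp, and one .cpp per kernel instance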
parser.add_argument(
"-l",
"--list_blobs",
action='store_true',
help="list all the kernels to a file, "
)
parser.add_argument(
"-g",
"--gen_blobs",
action='store_true',
help="generate all kernels into different tile"
)
# TODO: if using filter, must apply same value to output_dir and list_blobs
parser.add_argument(
"-f",
"--filter",
required=False,
help="filter out kernels that need to generate, using fnmatch module"
)
parser.add_argument(
"-t",
"--traits",
default="all",
required=False,
help="enable/disable some feature. default generate all"
)
parser.add_argument(
"-r",
"--receipt",
default=0,
required=False,
help="codegen receipt."
)
args = parser.parse_args()
# print(f'{args.list_blobs}-{args.gen_blobs}')
if (args.gen_blobs and args.list_blobs) or ((not args.gen_blobs) and (not args.list_blobs)):
        print('exactly one of --gen_blobs or --list_blobs must be specified')
sys.exit()
p = Path(args.working_path)
if not p.exists():
p.mkdir()
if args.list_blobs:
list_blobs(args)
else:
gen_blobs(args)
\ No newline at end of file
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "rmsnorm2d_fwd.hpp"
template <typename DataType_,
ck_tile::index_t Repeat_M_, // each thread repeat along M
ck_tile::index_t Repeat_N_, // each thread repeat along N
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
ck_tile::index_t Vector_N_, // vector size along N
bool kPadN_,
bool kSaveInvRms_,
bool kTwoPass_>
using trait_ = rmsnorm2d_fwd_traits_<DataType_,
Repeat_M_,
Repeat_N_,
ThreadPerBlock_M_,
ThreadPerBlock_N_,
Vector_N_,
kPadN_,
kSaveInvRms_,
kTwoPass_>;
template <typename data_type>
float rmsnorm2d_fwd_b16_(rmsnorm2d_fwd_traits /*t*/,
rmsnorm2d_fwd_args a,
const ck_tile::stream_config& s)
{
float r = -1;
// clang-format off
// rm rn tm tn vn pd rms 2p
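    // key for the trait_ arguments below: rm = Repeat_M, rn = Repeat_N,
    // tm = ThreadPerBlock_M, tn = ThreadPerBlock_N, vn = Vector_N (vector width along N),
    // pd = kPadN, rms = kSaveInvRms, 2p = kTwoPass.
    // for each N bucket the dispatch picks the widest vn that divides a.n; n > 4096 falls
    // back to the two-pass pipeline (last argument set to true).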
if(a.n <= 64) {
r = rmsnorm2d_fwd_<trait_<data_type, 1, 1, 4, 64, 1, true, false, false>>(s, a);
}
else if(a.n <= 128) {
if (a.n % 2 == 0)
r = rmsnorm2d_fwd_<trait_<data_type, 1, 1, 4, 64, 2, true, false, false>>(s, a);
else
r = rmsnorm2d_fwd_<trait_<data_type, 1, 2, 4, 64, 1, true, false, false>>(s, a);
}
else if(a.n <= 256) {
if (a.n % 4 == 0)
r = rmsnorm2d_fwd_<trait_<data_type, 1, 1, 4, 64, 4, true, false, false>>(s, a);
else if (a.n % 2 == 0)
r = rmsnorm2d_fwd_<trait_<data_type, 1, 2, 4, 64, 2, true, false, false>>(s, a);
else
r = rmsnorm2d_fwd_<trait_<data_type, 1, 4, 4, 64, 1, true, false, false>>(s, a);
}
else if(a.n <= 512) {
if (a.n % 8 == 0)
r = rmsnorm2d_fwd_<trait_<data_type, 1, 1, 4, 64, 8, true, false, false>>(s, a);
else if (a.n % 4 == 0)
r = rmsnorm2d_fwd_<trait_<data_type, 1, 2, 4, 64, 4, true, false, false>>(s, a);
else if (a.n % 2 == 0)
r = rmsnorm2d_fwd_<trait_<data_type, 1, 4, 4, 64, 2, true, false, false>>(s, a);
else
r = rmsnorm2d_fwd_<trait_<data_type, 1, 8, 4, 64, 1, true, false, false>>(s, a);
}
else if(a.n <= 768) {
if (a.n % 4 == 0)
r = rmsnorm2d_fwd_<trait_<data_type, 1, 3, 4, 64, 4, true, false, false>>(s, a);
else if (a.n % 2 == 0)
r = rmsnorm2d_fwd_<trait_<data_type, 1, 6, 4, 64, 2, true, false, false>>(s, a);
else
r = rmsnorm2d_fwd_<trait_<data_type, 1,12, 4, 64, 1, true, false, false>>(s, a);
}
else if(a.n <= 1024) {
if (a.n % 8 == 0)
r = rmsnorm2d_fwd_<trait_<data_type, 1, 1, 2, 128, 8, true, false, false>>(s, a);
else if (a.n % 4 == 0)
r = rmsnorm2d_fwd_<trait_<data_type, 1, 2, 2, 128, 4, true, false, false>>(s, a);
else if (a.n % 2 == 0)
r = rmsnorm2d_fwd_<trait_<data_type, 1, 4, 2, 128, 2, true, false, false>>(s, a);
else
r = rmsnorm2d_fwd_<trait_<data_type, 1, 4, 1, 256, 1, true, false, false>>(s, a);
}
else if(a.n <= 1536) {
if (a.n % 8 == 0)
r = rmsnorm2d_fwd_<trait_<data_type, 1, 3, 4, 64, 8, true, false, false>>(s, a);
else if (a.n % 4 == 0)
r = rmsnorm2d_fwd_<trait_<data_type, 1, 3, 2, 128, 4, true, false, false>>(s, a);
else if (a.n % 2 == 0)
r = rmsnorm2d_fwd_<trait_<data_type, 1, 3, 1, 256, 2, true, false, false>>(s, a);
else
r = rmsnorm2d_fwd_<trait_<data_type, 1, 6, 1, 256, 1, true, false, false>>(s, a);
}
else if(a.n <= 2048) {
if (a.n % 8 == 0)
r = rmsnorm2d_fwd_<trait_<data_type, 1, 1, 1, 256, 8, true, false, false>>(s, a);
else if (a.n % 4 == 0)
r = rmsnorm2d_fwd_<trait_<data_type, 1, 2, 1, 256, 4, true, false, false>>(s, a);
else if (a.n % 2 == 0)
r = rmsnorm2d_fwd_<trait_<data_type, 1, 4, 1, 256, 2, true, false, false>>(s, a);
else
r = rmsnorm2d_fwd_<trait_<data_type, 1, 8, 1, 256, 1, true, false, false>>(s, a);
}
else if(a.n <= 3072) {
if (a.n % 8 == 0)
r = rmsnorm2d_fwd_<trait_<data_type, 1, 3, 1, 128, 8, true, false, false>>(s, a);
else if (a.n % 4 == 0)
r = rmsnorm2d_fwd_<trait_<data_type, 1, 3, 1, 256, 4, true, false, false>>(s, a);
else if (a.n % 2 == 0)
r = rmsnorm2d_fwd_<trait_<data_type, 1, 6, 1, 256, 2, true, false, false>>(s, a);
else
r = rmsnorm2d_fwd_<trait_<data_type, 1, 3, 1, 1024, 1, true, false, false>>(s, a);
}
else if(a.n <= 4096) {
if (a.n % 8 == 0)
r = rmsnorm2d_fwd_<trait_<data_type, 1, 2, 1, 256, 8, true, false, false>>(s, a);
else if (a.n % 4 == 0)
r = rmsnorm2d_fwd_<trait_<data_type, 1, 4, 1, 256, 4, true, false, false>>(s, a);
else if (a.n % 2 == 0)
r = rmsnorm2d_fwd_<trait_<data_type, 1, 2, 1, 1024, 2, true, false, false>>(s, a);
else
r = rmsnorm2d_fwd_<trait_<data_type, 1, 4, 1, 1024, 1, true, false, false>>(s, a);
}
else if(a.n > 4096) {
if (a.n % 8 == 0)
r = rmsnorm2d_fwd_<trait_<data_type, 1, 2, 1, 256, 8, true, false, true>>(s, a);
else if (a.n % 4 == 0)
r = rmsnorm2d_fwd_<trait_<data_type, 1, 4, 1, 256, 4, true, false, true>>(s, a);
else if (a.n % 2 == 0)
r = rmsnorm2d_fwd_<trait_<data_type, 1, 2, 1, 1024, 2, true, false, true>>(s, a);
else
r = rmsnorm2d_fwd_<trait_<data_type, 1, 4, 1, 1024, 1, true, false, true>>(s, a);
}
return r;
// clang-format on
}
float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, rmsnorm2d_fwd_args a, const ck_tile::stream_config& s)
{
if(t.data_type.compare("fp16") == 0)
{
return rmsnorm2d_fwd_b16_<ck_tile::fp16_t>(t, a, s);
}
else if(t.data_type.compare("bf16") == 0)
{
return rmsnorm2d_fwd_b16_<ck_tile::bf16_t>(t, a, s);
}
else
throw std::runtime_error("Without supported instances!");
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
#if 0
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 8, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 4, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 8, 4, 64, 2, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 16, 4, 64, 1, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 1, 256, 4, true , false, false>>(const S&, A);
#endif
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 2, 128, 8, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 2, 128, 4, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 2, 128, 2, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 1, true, false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 3, 4, 64, 8, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 3, 2, 128, 4, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 3, 1, 256, 2, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 6, 1, 256, 1, true, false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 1, 256, 8, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 1, 256, 4, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 2, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 8, 1, 256, 1, true, false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 4, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 2, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 1, true , false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 3, 1, 128, 8, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 3, 1, 256, 4, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 6, 1, 256, 2, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 3, 1, 1024, 1, true, false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 1, 256, 8, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 4, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 1, 1024, 2, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 1024, 1, true, false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 1, 256, 8, true, false, true>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 4, true, false, true>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 1, 1024, 2, true, false, true>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 1024, 1, true, false, true>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 8, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 4, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 2, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 8, 4, 64, 1, true , false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 1, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 2, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 1, true , false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 3, 4, 64, 4, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 6, 4, 64, 2, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 12, 4, 64, 1, true , false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
#if 0
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 8, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 4, 64, 4, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 8, 4, 64, 2, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 16, 4, 64, 1, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 1, 256, 4, true , false, false>>(const S&, A);
#endif
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 2, 128, 8, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 2, 128, 4, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 2, 128, 2, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 256, 1, true, false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 3, 4, 64, 8, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 3, 2, 128, 4, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 3, 1, 256, 2, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 6, 1, 256, 1, true, false, false>>(const S&, A);
// clang-format on