Commit fd92d7eb authored by kentqian's avatar kentqian
Browse files

Trigged dynamic layernorm on

parent d83e2d25
...@@ -34,6 +34,7 @@ FUSED_ADD_ENUM_STR_MAP = [ ...@@ -34,6 +34,7 @@ FUSED_ADD_ENUM_STR_MAP = [
FUSED_FUSED_SWEEP_STR_MAP = [ FUSED_FUSED_SWEEP_STR_MAP = [
'no', 'no',
'smoothdquant',
'dquant' ] 'dquant' ]
DATA_TYPE_MAP = {'fp32' : 'float', DATA_TYPE_MAP = {'fp32' : 'float',
...@@ -223,7 +224,7 @@ float layernorm2d_fwd_(const S& s, A a) ...@@ -223,7 +224,7 @@ float layernorm2d_fwd_(const S& s, A a)
using DynamicQuantEpilogue = ck_tile::DynamicQuantEpilogue<DynamicQuantEpilogueProblem>; using DynamicQuantEpilogue = ck_tile::DynamicQuantEpilogue<DynamicQuantEpilogueProblem>;
using Epilogue = std::conditional_t<Traits_::kFusedQuant == 1, DynamicQuantEpilogue, Default2DEpilogue>; using Epilogue = std::conditional_t<(Traits_::kFusedQuant >= 1), DynamicQuantEpilogue, Default2DEpilogue>;
using Kernel = ck_tile::Layernorm2dFwd<Pipeline, Epilogue>; using Kernel = ck_tile::Layernorm2dFwd<Pipeline, Epilogue>;
...@@ -504,12 +505,11 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, ...@@ -504,12 +505,11 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
scale_list = [('fp32,fp32')] scale_list = [('fp32,fp32')]
dtype_list = [('fp16,fp16'), ('bf16,bf16'), dtype_list = [('fp16,fp16'), ('bf16,bf16'),
('fp16,int8'), ('bf16,int8')] # NOTE: only fused-dynamic-quant use int8 out ('fp16,int8'), ('bf16,int8')] # NOTE: only fused-dynamic-quant use int8 out
#bias_list = [0, 1]
#fused_add_list = [0, 1, 2] #fused_add_list = [0, 1, 2]
#fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused dynamic quant #fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused dynamic quant
bias_list = [0, 1] bias_list = [0, 1]
fused_add_list = [0, 1] fused_add_list = [0, 1]
fused_sweep_list = [0, 1] # NOTE: only single pass can use fused dynamic quant fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused dynamic quant
# rm rn tm tn vn pd mv fdiv 2p bias add sweep # rm rn tm tn vn pd mv fdiv 2p bias add sweep
h_trait_dict = {'64' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 8, 8, 8, True, False, True, False, 0, 0, 0), h_trait_dict = {'64' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 8, 8, 8, True, False, True, False, 0, 0, 0),
...@@ -567,9 +567,9 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, ...@@ -567,9 +567,9 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
for dtype, scale_type, bias, fused_add, fused_quant in itertools.product(dtype_list, scale_list, bias_list, fused_add_list, fused_sweep_list): for dtype, scale_type, bias, fused_add, fused_quant in itertools.product(dtype_list, scale_list, bias_list, fused_add_list, fused_sweep_list):
prec_i, prec_o = dtype.split(',') prec_i, prec_o = dtype.split(',')
scale_x, scale_y = scale_type.split(',') scale_x, scale_y = scale_type.split(',')
if prec_o in dynamic_quant_out_dtype and fused_quant != 1: if prec_o in dynamic_quant_out_dtype and fused_quant == 0:
continue # skip non dynamic quant case continue # skip non dynamic quant case
if fused_quant == 1 and hs_key == 'big': if (fused_quant == 1 or fused_quant == 2) and hs_key == 'big':
continue continue
current_hs = list() current_hs = list()
for chs_ in hs: for chs_ in hs:
......
...@@ -102,6 +102,11 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -102,6 +102,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
std::cout << "if fused_quant is 1, only support \"-prec_o=int8\" case" << std::endl; std::cout << "if fused_quant is 1, only support \"-prec_o=int8\" case" << std::endl;
return false; return false;
} }
if(fused_quant == 2 && prec_o != "int8")
{
std::cout << "if fused_quant is 2, only support \"-prec_o=int8\" case" << std::endl;
return false;
}
assert(x_stride >= n); assert(x_stride >= n);
......
...@@ -127,11 +127,13 @@ struct DynamicQuantEpilogue ...@@ -127,11 +127,13 @@ struct DynamicQuantEpilogue
auto o_acc_tmp = o_acc_tile; auto o_acc_tmp = o_acc_tile;
sweep_tile(o_acc_tmp, [&](auto idx) { if constexpr(!std::is_same_v<remove_cvref_t<decltype(x_scale)>, ck_tile::null_tensor>){
constexpr auto j_idx = make_tuple(idx[number<1>{}]); sweep_tile(o_acc_tmp, [&](auto idx) {
const auto xs_ = type_convert<AccDataType>(x_scale[j_idx]); constexpr auto j_idx = make_tuple(idx[number<1>{}]);
o_acc_tmp(idx) = o_acc_tmp(idx) * xs_; const auto xs_ = type_convert<AccDataType>(x_scale[j_idx]);
}); o_acc_tmp(idx) = o_acc_tmp(idx) * xs_;
});
}
const auto f_absmax = [](auto acc_, auto v_0_) { return max(acc_, abs(v_0_)); }; const auto f_absmax = [](auto acc_, auto v_0_) { return max(acc_, abs(v_0_)); };
......
...@@ -46,8 +46,8 @@ enum class Layernorm2dFusedQuantEnum ...@@ -46,8 +46,8 @@ enum class Layernorm2dFusedQuantEnum
// clang-format off // clang-format off
template<Layernorm2dFusedQuantEnum> struct Layernorm2dFusedQuantEnumName; template<Layernorm2dFusedQuantEnum> struct Layernorm2dFusedQuantEnumName;
template<> struct Layernorm2dFusedQuantEnumName<Layernorm2dFusedQuantEnum::NO_SWEEP> { static constexpr const char * name = "no"; }; template<> struct Layernorm2dFusedQuantEnumName<Layernorm2dFusedQuantEnum::NO_SWEEP> { static constexpr const char * name = "no"; };
template<> struct Layernorm2dFusedQuantEnumName<Layernorm2dFusedQuantEnum::DYNAMIC_QUANT> { static constexpr const char * name = "dqt"; };
template<> struct Layernorm2dFusedQuantEnumName<Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT> { static constexpr const char * name = "smdqt"; }; template<> struct Layernorm2dFusedQuantEnumName<Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT> { static constexpr const char * name = "smdqt"; };
template<> struct Layernorm2dFusedQuantEnumName<Layernorm2dFusedQuantEnum::DYNAMIC_QUANT> { static constexpr const char * name = "dqt"; };
// clang-format on // clang-format on
template <bool kPadN_, template <bool kPadN_,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment