Unverified Commit 64d5c4d6 authored by ruanjm's avatar ruanjm Committed by GitHub
Browse files

Implement fp8 quant for layernorm and rmsnorm (#1814)

parent 5b9b083d
...@@ -33,7 +33,7 @@ target_sources(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${LAYERNORM2D_FWD_GEN_BLOBS}) ...@@ -33,7 +33,7 @@ target_sources(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${LAYERNORM2D_FWD_GEN_BLOBS})
set(EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS) set(EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS)
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations # NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
list(APPEND EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) list(APPEND EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal --offload-compress)
target_compile_options(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS}) target_compile_options(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS})
......
...@@ -39,7 +39,8 @@ FUSED_FUSED_SWEEP_STR_MAP = [ ...@@ -39,7 +39,8 @@ FUSED_FUSED_SWEEP_STR_MAP = [
DATA_TYPE_MAP = {'fp32' : 'float', DATA_TYPE_MAP = {'fp32' : 'float',
'fp16' : 'ck_tile::fp16_t', 'fp16' : 'ck_tile::fp16_t',
'bf16' : 'ck_tile::bf16_t', 'bf16' : 'ck_tile::bf16_t',
'int8' : 'ck_tile::int8_t'} 'int8' : 'ck_tile::int8_t',
'fp8' : 'ck_tile::fp8_t'}
def BOOL_MAP(b_) -> str: def BOOL_MAP(b_) -> str:
if b_: if b_:
...@@ -504,12 +505,13 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, ...@@ -504,12 +505,13 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
h_traits = layernorm_fwd_codegen.h_traits h_traits = layernorm_fwd_codegen.h_traits
h_instance = layernorm_fwd_codegen.h_instance h_instance = layernorm_fwd_codegen.h_instance
dynamic_quant_out_dtype = ['int8'] dynamic_quant_out_dtype = ['int8', 'fp8']
# some predefined support range # some predefined support range
# (prec_i,prec_o) for simplicity this string will be used as key for dict # (prec_i,prec_o) for simplicity this string will be used as key for dict
scale_list = [('fp32,fp32')] scale_list = [('fp32,fp32')]
dtype_list = [('fp16,fp16'), ('bf16,bf16'), dtype_list = [('fp16,fp16'), ('bf16,bf16'),
('fp16,int8'), ('bf16,int8')] # NOTE: only fused-dynamic-quant use int8 out ('fp16,int8'), ('bf16,int8'),
('fp16,fp8'), ('bf16,fp8')] # NOTE: only fused-dynamic-quant use int8 or fp8 out
types_8bit = ('int8', 'fp8') types_8bit = ('int8', 'fp8')
types_16bit = ('int16', 'fp16', 'bf16') types_16bit = ('int16', 'fp16', 'bf16')
#fused_add_list = [0, 1, 2] #fused_add_list = [0, 1, 2]
......
...@@ -20,6 +20,14 @@ auto get_elimit<ck_tile::bf16_t>() ...@@ -20,6 +20,14 @@ auto get_elimit<ck_tile::bf16_t>()
return ck_tile::make_tuple(rtol, atol); return ck_tile::make_tuple(rtol, atol);
} }
// Specialization of get_elimit for ck_tile::int8_t.
// Returns the (rtol, atol) pair as a ck_tile tuple: relative tolerance 1e-2,
// absolute tolerance 1.0.
// NOTE(review): the absolute tolerance of 1.0 presumably allows a difference of
// one quantization step between host and device results -- confirm.
template <>
auto get_elimit<ck_tile::int8_t>()
{
    constexpr double kRelTolerance = 1e-2;
    constexpr double kAbsTolerance = 1.0;
    return ck_tile::make_tuple(kRelTolerance, kAbsTolerance);
}
auto create_args(int argc, char* argv[]) auto create_args(int argc, char* argv[])
{ {
ck_tile::ArgParser arg_parser; ck_tile::ArgParser arg_parser;
...@@ -97,9 +105,11 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -97,9 +105,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
int xbias = arg_parser.get_int("xbias"); int xbias = arg_parser.get_int("xbias");
int fused_add = arg_parser.get_int("fadd"); int fused_add = arg_parser.get_int("fadd");
int fused_quant = arg_parser.get_int("fquant"); int fused_quant = arg_parser.get_int("fquant");
if(fused_quant == 1 && prec_o != "int8") if(fused_quant == 1 && prec_o != "int8" && prec_o != "fp8")
{ {
std::cout << "if fused_quant is 1, only support \"-prec_o=int8\" case" << std::endl; std::cout
<< "if fused_quant is 1 or 2, only support \"-prec_o=int8\" or \"-prec_o=fp8\" cases."
<< std::endl;
return false; return false;
} }
...@@ -291,7 +301,11 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -291,7 +301,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
absmax = a > absmax ? a : absmax; absmax = a > absmax ? a : absmax;
} }
// printf("cpu:absmax:%f\n", absmax); // printf("cpu:absmax:%f\n", absmax);
ComputeDataType y_scale = absmax / static_cast<ComputeDataType>(127.0); constexpr ComputeDataType kMaxY =
std::is_same<YDataType, ck_tile::fp8_t>::value ? 240.0
: std::is_same<YDataType, ck_tile::int8_t>::value ? 127.0
: 0.0;
ComputeDataType y_scale = absmax / kMaxY;
y_scale_host_ref(m_) = ck_tile::type_convert<YScaleDataType>(y_scale); y_scale_host_ref(m_) = ck_tile::type_convert<YScaleDataType>(y_scale);
for(int n_ = 0; n_ < N_; n_++) for(int n_ = 0; n_ < N_; n_++)
{ {
...@@ -334,7 +348,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -334,7 +348,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
y_residual_buf.FromDevice(y_residual_host_dev.data()); y_residual_buf.FromDevice(y_residual_host_dev.data());
} }
auto [rtol, atol] = get_elimit<InDataType>(); auto [rtol, atol] = get_elimit<OutDataType>();
if(x_stride == n) if(x_stride == n)
{ {
...@@ -452,6 +466,16 @@ int main(int argc, char* argv[]) ...@@ -452,6 +466,16 @@ int main(int argc, char* argv[])
{ {
return run<ck_tile::bf16_t, ck_tile::int8_t, float, float, false>(arg_parser) ? 0 : -2; return run<ck_tile::bf16_t, ck_tile::int8_t, float, float, false>(arg_parser) ? 0 : -2;
} }
else if(prec_i == "fp16" && prec_o == "fp8" && prec_sm == "fp32" && prec_sy == "fp32" &&
!save_mv)
{
return run<ck_tile::half_t, ck_tile::fp8_t, float, float, false>(arg_parser) ? 0 : -2;
}
else if(prec_i == "bf16" && prec_o == "fp8" && prec_sm == "fp32" && prec_sy == "fp32" &&
!save_mv)
{
return run<ck_tile::bf16_t, ck_tile::fp8_t, float, float, false>(arg_parser) ? 0 : -2;
}
return -3; return -3;
} }
#!/bin/sh #!/bin/sh
EXE="$(find . -name tile_example_layernorm2d_fwd -type f | head -n 1)" EXE="$(find . -name tile_example_layernorm2d_fwd -type f | head -n 1)"
for fquant in "" "-fquant=1 -prec_o=int8"; do for fquant in "" "-fquant=1 -prec_o=int8" "-fquant=1 -prec_o=fp8"; do
for pr_i in "fp16" "bf16" ; do for pr_i in "fp16" "bf16" ; do
for fadd in "0" "1"; do for fadd in "0" "1"; do
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=99 -n=13 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=99 -n=13
......
...@@ -33,7 +33,7 @@ target_sources(${TILE_RMSNORM2D_FWD} PRIVATE ${RMSNORM2D_FWD_GEN_BLOBS}) ...@@ -33,7 +33,7 @@ target_sources(${TILE_RMSNORM2D_FWD} PRIVATE ${RMSNORM2D_FWD_GEN_BLOBS})
set(TILE_RMSNORM2D_FWD_COMPILE_OPTIONS) set(TILE_RMSNORM2D_FWD_COMPILE_OPTIONS)
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations # NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
list(APPEND TILE_RMSNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) list(APPEND TILE_RMSNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal --offload-compress)
target_compile_options(${TILE_RMSNORM2D_FWD} PRIVATE ${TILE_RMSNORM2D_FWD_COMPILE_OPTIONS}) target_compile_options(${TILE_RMSNORM2D_FWD} PRIVATE ${TILE_RMSNORM2D_FWD_COMPILE_OPTIONS})
......
...@@ -37,7 +37,8 @@ FUSED_FUSED_SWEEP_STR_MAP = [ ...@@ -37,7 +37,8 @@ FUSED_FUSED_SWEEP_STR_MAP = [
DATA_TYPE_MAP = {'fp32' : 'float', DATA_TYPE_MAP = {'fp32' : 'float',
'fp16' : 'ck_tile::fp16_t', 'fp16' : 'ck_tile::fp16_t',
'bf16' : 'ck_tile::bf16_t', 'bf16' : 'ck_tile::bf16_t',
'int8' : 'ck_tile::int8_t'} 'int8' : 'ck_tile::int8_t',
'fp8' : 'ck_tile::fp8_t'}
def BOOL_MAP(b_) -> str: def BOOL_MAP(b_) -> str:
if b_: if b_:
...@@ -477,12 +478,13 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, ...@@ -477,12 +478,13 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
h_traits = rmsnorm_fwd_codegen.h_traits h_traits = rmsnorm_fwd_codegen.h_traits
h_instance = rmsnorm_fwd_codegen.h_instance h_instance = rmsnorm_fwd_codegen.h_instance
dynamic_quant_out_dtype = ['int8'] dynamic_quant_out_dtype = ['int8', 'fp8']
# some predefined support range # some predefined support range
# (prec_i,prec_o) for simplicity this string will be used as key for dict # (prec_i,prec_o) for simplicity this string will be used as key for dict
scale_list = [('fp32,fp32')] scale_list = [('fp32,fp32')]
dtype_list = [('fp16,fp16'), ('bf16,bf16'), dtype_list = [('fp16,fp16'), ('bf16,bf16'),
('fp16,int8'), ('bf16,int8')] # NOTE: only fused-dynamic-quant use int8 out ('fp16,int8'), ('bf16,int8'),
('fp16,fp8'), ('bf16,fp8')] # NOTE: only fused-dynamic-quant use int8 or fp8 out
#fused_add_list = [0, 1, 2] #fused_add_list = [0, 1, 2]
#fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused (smooth) dynamic quant #fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused (smooth) dynamic quant
fused_add_list = [0, 1] fused_add_list = [0, 1]
......
...@@ -105,9 +105,11 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -105,9 +105,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
prec_sy = "fp32"; prec_sy = "fp32";
} }
if((fused_quant == 1 || fused_quant == 2) && prec_o != "int8") if((fused_quant == 1 || fused_quant == 2) && prec_o != "int8" && prec_o != "fp8")
{ {
std::cout << "if fused_quant is 1, only support \"-prec_o=int8\" case" << std::endl; std::cout
<< "if fused_quant is 1 or 2, only support \"-prec_o=int8\" or \"-prec_o=fp8\" cases."
<< std::endl;
return false; return false;
} }
...@@ -248,7 +250,11 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -248,7 +250,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
absmax = a > absmax ? a : absmax; absmax = a > absmax ? a : absmax;
} }
// printf("cpu:absmax:%f\n", absmax); // printf("cpu:absmax:%f\n", absmax);
ComputeDataType y_scale = absmax / static_cast<ComputeDataType>(127.0); constexpr ComputeDataType kMaxY =
std::is_same<YDataType, ck_tile::fp8_t>::value ? 240.0
: std::is_same<YDataType, ck_tile::int8_t>::value ? 127.0
: 0.0;
ComputeDataType y_scale = absmax / kMaxY;
y_scale_host_ref(m_) = ck_tile::type_convert<YScaleDataType>(y_scale); y_scale_host_ref(m_) = ck_tile::type_convert<YScaleDataType>(y_scale);
for(int n_ = 0; n_ < N_; n_++) for(int n_ = 0; n_ < N_; n_++)
{ {
...@@ -400,6 +406,16 @@ int main(int argc, char* argv[]) ...@@ -400,6 +406,16 @@ int main(int argc, char* argv[])
{ {
return run<ck_tile::bf16_t, ck_tile::int8_t, float, float, true>(arg_parser) ? 0 : -2; return run<ck_tile::bf16_t, ck_tile::int8_t, float, float, true>(arg_parser) ? 0 : -2;
} }
else if(prec_i == "fp16" && prec_o == "fp8" && prec_sm == "fp32" && prec_sy == "fp32" &&
!save_rms)
{
return run<ck_tile::half_t, ck_tile::fp8_t, float, float, false>(arg_parser) ? 0 : -2;
}
else if(prec_i == "bf16" && prec_o == "fp8" && prec_sm == "fp32" && prec_sy == "fp32" &&
!save_rms)
{
return run<ck_tile::bf16_t, ck_tile::fp8_t, float, float, false>(arg_parser) ? 0 : -2;
}
return -3; return -3;
} }
#!/bin/sh #!/bin/sh
EXE="$(find . -name tile_rmsnorm2d_fwd -type f | head -n 1)" EXE="$(find . -name tile_rmsnorm2d_fwd -type f | head -n 1)"
for fquant in "" "-fquant=1 -prec_o=int8" "-fquant=2 -prec_o=int8"; do for fquant in "" "-fquant=1 -prec_o=int8" "-fquant=2 -prec_o=int8" "-fquant=1 -prec_o=fp8" "-fquant=2 -prec_o=fp8"; do
for pr_i in "fp16" "bf16" ; do for pr_i in "fp16" "bf16" ; do
for fadd in "0" "1"; do for fadd in "0" "1"; do
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=99 -n=13 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=99 -n=13
...@@ -27,7 +27,7 @@ $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=2734 ...@@ -27,7 +27,7 @@ $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=2734
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=3182 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=3182
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=9 -n=4096 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=9 -n=4096
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=8192 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=8192
#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=10547 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=10547
#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134 #$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134
done done
done done
......
...@@ -443,7 +443,11 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val ...@@ -443,7 +443,11 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
} }
if(!res) if(!res)
{ {
std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl; const float error_percent =
static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
std::cerr << "max err: " << max_err;
std::cerr << ", number of errors: " << err_count;
std::cerr << ", " << error_percent << "% wrong values" << std::endl;
} }
return res; return res;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment