"git@developer.sourcefind.cn:gaoqiong/composable_kernel.git" did not exist on "690c75a7eb7012bf0fd6fb3f6e129e83fbcbdb53"
Unverified Commit fcba889e authored by carlushuang, committed by GitHub
Browse files

[CK_TILE] fix some rand number init (#1287)

* add random norm

* normalized default to 0/3

* change squant->auto
parent 8346af9c
...@@ -44,9 +44,9 @@ args: ...@@ -44,9 +44,9 @@ args:
-range_v per-tensor quantization range of v. used if squant=1. (default:16) -range_v per-tensor quantization range of v. used if squant=1. (default:16)
-range_p per-tensor quantization range of p [e^(s-m)]. used if squant=1. (default:1) -range_p per-tensor quantization range of p [e^(s-m)]. used if squant=1. (default:1)
-range_o per-tensor quantization range of o (p*v). used if squant=1. (default:16) -range_o per-tensor quantization range of o (p*v). used if squant=1. (default:16)
-squant if using static quantization fusion or not. 0: original flow(not prefered) (default:0) -squant if using static quantization fusion or not. auto: fp8 will default use squant, other will not (default:auto)
1: apply scale_p and scale_o with respect to P and O. calculate scale_s, scale_p, 0: no static quant(not implemented) 1: apply scale_p and scale_o with respect to P and O.
scale_o according to range_q, range_k, range_v, range_p, range_o calculate scale_s, scale_p, scale_o according to range_q, range_k, range_v, range_p, range_o
-iperm permute input (default:1) -iperm permute input (default:1)
if true, will be b*h*s*d, else b*s*h*d if true, will be b*h*s*d, else b*s*h*d
-operm permute output (default:1) -operm permute output (default:1)
...@@ -64,8 +64,11 @@ args: ...@@ -64,8 +64,11 @@ args:
-vlayout r for row-major(seqlen*hdim), c for col-major(hdim*seqlen) (default:r) -vlayout r for row-major(seqlen*hdim), c for col-major(hdim*seqlen) (default:r)
-lse 0 not store lse, 1 store lse (default:0) -lse 0 not store lse, 1 store lse (default:0)
-kname if set to 1 will print kernel name (default:0) -kname if set to 1 will print kernel name (default:0)
-init init method. 0:random int, 1:random float, 2:trig float, 3:quantization (default:1) -init init method. ui, uniform random int, ni, normalized random int (default:uf)
uf, uniform random float, nf, normalized random float, tf, trig float, uf:q, quantization
-seed random seed used for initializing input tensors. 0 for non-deterministic seed (default:11939) -seed random seed used for initializing input tensors. 0 for non-deterministic seed (default:11939)
-warmup number of iterations before benchmark the kernel (default:5)
-repeat number of iterations to benchmark the kernel (default:20)
``` ```
Example: `./bin/tile_example_fmha_fwd -b=1 -h=16 -s=16384 -d=128` will run a fmha case with batch=1, nhead=16, sequence length=16384, hdim=128, fp16 case. Example: `./bin/tile_example_fmha_fwd -b=1 -h=16 -s=16384 -d=128` will run a fmha case with batch=1, nhead=16, sequence length=16384, hdim=128, fp16 case.
......
...@@ -60,12 +60,14 @@ auto create_args(int argc, char* argv[]) ...@@ -60,12 +60,14 @@ auto create_args(int argc, char* argv[])
.insert("range_v", "16", "per-tensor quantization range of v. used if squant=1.") .insert("range_v", "16", "per-tensor quantization range of v. used if squant=1.")
.insert("range_p", "1", "per-tensor quantization range of p [e^(s-m)]. used if squant=1.") .insert("range_p", "1", "per-tensor quantization range of p [e^(s-m)]. used if squant=1.")
.insert("range_o", "16", "per-tensor quantization range of o (p*v). used if squant=1.") .insert("range_o", "16", "per-tensor quantization range of o (p*v). used if squant=1.")
.insert( .insert("squant",
"squant", "auto",
"0", "if using static quantization fusion or not. auto: fp8 will default use squant, "
"if using static quantization fusion or not. 0: original flow(not prefered)\n" "other will not\n"
"1: apply scale_p and scale_o with respect to P and O. calculate scale_s, scale_p,\n" "0: no static quant(not implemented) 1: apply scale_p and scale_o with respect to "
"scale_o according to range_q, range_k, range_v, range_p, range_o") "P and O.\n"
"calculate scale_s, scale_p, scale_o according to range_q, range_k, range_v, "
"range_p, range_o")
.insert("iperm", .insert("iperm",
"1", "1",
"permute input\n" "permute input\n"
...@@ -92,8 +94,11 @@ auto create_args(int argc, char* argv[]) ...@@ -92,8 +94,11 @@ auto create_args(int argc, char* argv[])
.insert("vlayout", "r", "r for row-major(seqlen*hdim), c for col-major(hdim*seqlen)") .insert("vlayout", "r", "r for row-major(seqlen*hdim), c for col-major(hdim*seqlen)")
.insert("lse", "0", "0 not store lse, 1 store lse") .insert("lse", "0", "0 not store lse, 1 store lse")
.insert("kname", "0", "if set to 1 will print kernel name") .insert("kname", "0", "if set to 1 will print kernel name")
.insert( .insert("init",
"init", "1", "init method. 0:random int, 1:random float, 2:trig float, 3:quantization") "uf",
"init method. ui, uniform random int, ni, normalized random int\n"
"uf, uniform random float, nf, normalized random float, tf, trig float, uf:q, "
"quantization")
.insert("seed", .insert("seed",
"11939", "11939",
"random seed used for initializing input tensors. 0 for " "random seed used for initializing input tensors. 0 for "
...@@ -107,7 +112,7 @@ auto create_args(int argc, char* argv[]) ...@@ -107,7 +112,7 @@ auto create_args(int argc, char* argv[])
// different threshold for different dtype // different threshold for different dtype
template <typename DataType> template <typename DataType>
auto get_elimit(int /*init_method*/) auto get_elimit(std::string /*init_method*/)
{ {
double rtol = 1e-3; double rtol = 1e-3;
double atol = 1e-3; double atol = 1e-3;
...@@ -115,9 +120,15 @@ auto get_elimit(int /*init_method*/) ...@@ -115,9 +120,15 @@ auto get_elimit(int /*init_method*/)
} }
template <> template <>
auto get_elimit<ck_tile::bf16_t>(int init_method) auto get_elimit<ck_tile::bf16_t>(std::string init_method)
{ {
if(init_method == 0) if(init_method == "ui" || init_method == "ni")
{
double rtol = 1e-2;
double atol = 1e-2;
return ck_tile::make_tuple(rtol, atol);
}
else if(init_method == "nf")
{ {
double rtol = 1e-2; double rtol = 1e-2;
double atol = 1e-2; double atol = 1e-2;
...@@ -132,9 +143,9 @@ auto get_elimit<ck_tile::bf16_t>(int init_method) ...@@ -132,9 +143,9 @@ auto get_elimit<ck_tile::bf16_t>(int init_method)
} }
template <> template <>
auto get_elimit<ck_tile::fp8_t>(int init_method) auto get_elimit<ck_tile::fp8_t>(std::string init_method)
{ {
if(init_method == 0) if(init_method == "ui" || init_method == "ni")
{ {
unsigned max_rounding_point_distance = 0; unsigned max_rounding_point_distance = 0;
double atol = 2e-3; double atol = 2e-3;
...@@ -182,15 +193,18 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -182,15 +193,18 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(scale_s == .0f) if(scale_s == .0f)
scale_s = 1.0 / ck_tile::sqrt(static_cast<float>(hdim_q)); // TODO: q ? v ? scale_s = 1.0 / ck_tile::sqrt(static_cast<float>(hdim_q)); // TODO: q ? v ?
bool squant = arg_parser.get_bool("squant"); std::string squant_str = arg_parser.get_str("squant");
if constexpr(!std::is_same_v<DataType, ck_tile::fp8_t>) bool squant = [&]() {
{ if(squant_str == "auto")
if(squant)
{ {
std::cerr << "static quantization only support fp8 for now" << std::endl; if(data_type == "fp8")
return false; return true;
else
return false;
} }
} else
return atoi(squant_str.c_str()) != 0 ? true : false;
}();
float range_q = arg_parser.get_float("range_q"); float range_q = arg_parser.get_float("range_q");
float range_k = arg_parser.get_float("range_k"); float range_k = arg_parser.get_float("range_k");
...@@ -217,7 +231,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -217,7 +231,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
bias_info bias = bias_info::decode(arg_parser.get_str("bias")); bias_info bias = bias_info::decode(arg_parser.get_str("bias"));
mask_info mask = mask_info::decode(arg_parser.get_str("mask"), seqlen_q, seqlen_k); mask_info mask = mask_info::decode(arg_parser.get_str("mask"), seqlen_q, seqlen_k);
int init_method = arg_parser.get_int("init"); std::string init_method = arg_parser.get_str("init");
std::optional<uint32_t> seed = arg_parser.get_uint32("seed"); std::optional<uint32_t> seed = arg_parser.get_uint32("seed");
if(*seed == 0) if(*seed == 0)
{ {
...@@ -319,28 +333,43 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -319,28 +333,43 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<ODataType> o_host( ck_tile::HostTensor<ODataType> o_host(
get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v)); get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v));
if(init_method == 0) if(init_method == "ui" || init_method == "0")
{ {
ck_tile::FillUniformDistributionIntegerValue<QDataType>{-2.f, 2.f, seed}(q_host); ck_tile::FillUniformDistributionIntegerValue<QDataType>{-3.f, 3.f, seed}(q_host);
ck_tile::FillUniformDistributionIntegerValue<KDataType>{-2.f, 2.f, seed}(k_host); ck_tile::FillUniformDistributionIntegerValue<KDataType>{-3.f, 3.f, seed}(k_host);
ck_tile::FillUniformDistributionIntegerValue<VDataType>{-2.f, 2.f, seed}(v_host); ck_tile::FillUniformDistributionIntegerValue<VDataType>{-3.f, 3.f, seed}(v_host);
ck_tile::FillUniformDistributionIntegerValue<BiasDataType>{-2.f, 2.f, seed}(bias_host); ck_tile::FillUniformDistributionIntegerValue<BiasDataType>{-3.f, 3.f, seed}(bias_host);
} }
else if(init_method == 1) else if(init_method == "ni")
{
ck_tile::FillNormalDistributionIntegerValue<QDataType>{-3.f, 3.f, seed}(q_host);
ck_tile::FillNormalDistributionIntegerValue<KDataType>{-3.f, 3.f, seed}(k_host);
ck_tile::FillNormalDistributionIntegerValue<VDataType>{-3.f, 3.f, seed}(v_host);
ck_tile::FillNormalDistributionIntegerValue<BiasDataType>{-3.f, 3.f, seed}(bias_host);
}
else if(init_method == "uf" || init_method == "1")
{ {
ck_tile::FillUniformDistribution<QDataType>{0.f, 1.f, seed}(q_host); ck_tile::FillUniformDistribution<QDataType>{0.f, 1.f, seed}(q_host);
ck_tile::FillUniformDistribution<KDataType>{0.f, 1.f, seed}(k_host); ck_tile::FillUniformDistribution<KDataType>{0.f, 1.f, seed}(k_host);
ck_tile::FillUniformDistribution<VDataType>{0.f, 1.f, seed}(v_host); ck_tile::FillUniformDistribution<VDataType>{0.f, 1.f, seed}(v_host);
ck_tile::FillUniformDistribution<BiasDataType>{0.f, 1.f, seed}(bias_host); ck_tile::FillUniformDistribution<BiasDataType>{0.f, 1.f, seed}(bias_host);
} }
else if(init_method == 2) else if(init_method == "nf")
{
ck_tile::FillNormalDistribution<QDataType>{0.f, 3.f, seed}(q_host);
ck_tile::FillNormalDistribution<KDataType>{0.f, 3.f, seed}(k_host);
ck_tile::FillNormalDistribution<VDataType>{0.f, 3.f, seed}(v_host);
ck_tile::FillNormalDistribution<BiasDataType>{0.f, 3.f, seed}(bias_host);
}
else if(init_method == "tf" || init_method == "2")
{ {
ck_tile::FillTrigValue<QDataType>{}(q_host); ck_tile::FillTrigValue<QDataType>{}(q_host);
ck_tile::FillTrigValue<KDataType>{}(k_host); ck_tile::FillTrigValue<KDataType>{}(k_host);
ck_tile::FillTrigValue<VDataType>{}(v_host); ck_tile::FillTrigValue<VDataType>{}(v_host);
ck_tile::FillTrigValue<BiasDataType>{}(bias_host); ck_tile::FillTrigValue<BiasDataType>{}(bias_host);
} }
else if(init_method == 3) // suitable for fp8 quantization else if(init_method == "ufq" || init_method == "uf:q" ||
init_method == "3") // suitable for fp8 quantization
{ {
ck_tile::FillUniformDistribution<QDataType>{-dtype_max, dtype_max, seed}(q_host); ck_tile::FillUniformDistribution<QDataType>{-dtype_max, dtype_max, seed}(q_host);
ck_tile::FillUniformDistribution<KDataType>{-dtype_max, dtype_max, seed}(k_host); ck_tile::FillUniformDistribution<KDataType>{-dtype_max, dtype_max, seed}(k_host);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment