Commit 309705c1 authored by coderfeli
Browse files

add debug data

parent a6b761c3
...@@ -155,7 +155,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -155,7 +155,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
int fused_quant = arg_parser.get_int("fquant"); int fused_quant = arg_parser.get_int("fquant");
int gate_only = arg_parser.get_int("gate_only"); int gate_only = arg_parser.get_int("gate_only");
int api = arg_parser.get_int("api"); int api = arg_parser.get_int("api");
int balance = arg_parser.get_int("balance"); // int balance = arg_parser.get_int("balance");
int tp = arg_parser.get_int("tp"); int tp = arg_parser.get_int("tp");
int init = arg_parser.get_int("init"); int init = arg_parser.get_int("init");
uint32_t seed = arg_parser.get_uint32("seed"); uint32_t seed = arg_parser.get_uint32("seed");
...@@ -257,14 +257,14 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -257,14 +257,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
} }
else if(init == 2) else if(init == 2)
{ {
ck_tile::FillNormalDistribution<ADataType>{0.f, 1.f, seed, true}(a_host); ck_tile::FillNormalDistribution<ADataType>{1.f, 1.f, seed, true}(a_host);
ck_tile::FillNormalDistribution<GDataType>{0.f, 1.f, seed, true}(g_host); ck_tile::FillNormalDistribution<GDataType>{1.f, 1.f, seed, true}(g_host);
ck_tile::FillNormalDistribution<DDataType>{0.f, 1.f, seed, true}(d_host); ck_tile::FillNormalDistribution<DDataType>{1.f, 1.f, seed, true}(d_host);
ck_tile::FillNormalDistribution<AScaleDataType>{0.f, 1.f, seed, true}(sa_host); ck_tile::FillNormalDistribution<AScaleDataType>{1.f, 1.f, seed, true}(sa_host);
ck_tile::FillNormalDistribution<GScaleDataType>{0.f, 1.f, seed, true}(sg_host); ck_tile::FillNormalDistribution<GScaleDataType>{1.f, 1.f, seed, true}(sg_host);
ck_tile::FillNormalDistribution<DScaleDataType>{0.f, 1.f, seed, true}(sd_host); ck_tile::FillNormalDistribution<DScaleDataType>{1.f, 1.f, seed, true}(sd_host);
ck_tile::FillNormalDistribution<YSmoothScaleDataType>{0.f, 1.f, seed, true}(sy_host); ck_tile::FillNormalDistribution<YSmoothScaleDataType>{1.f, 1.f, seed, true}(sy_host);
ck_tile::FillNormalDistribution<TopkWeightDataType>{0.f, 1.f, seed, true}(topk_weight_host); ck_tile::FillNormalDistribution<TopkWeightDataType>{0.125f, 0.125f, seed, true}(topk_weight_host);
} }
// permute weight // permute weight
...@@ -272,15 +272,11 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -272,15 +272,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<DDataType> d_perm_host = shuffle_moe_weight(d_host, prec_w, 1); ck_tile::HostTensor<DDataType> d_perm_host = shuffle_moe_weight(d_host, prec_w, 1);
// do moe sorting // do moe sorting
if(balance) if(1)
{ {
int e_cnt = 0; for(int i=0; i < topk; i++) {
for(int i = 0; i < static_cast<int>(topk_ids_host.mData.size()); i++) topk_ids_host.mData[i] = i;
{ topk_weight_host.mData[i] = 0.1;
topk_ids_host.mData[i] = e_cnt;
e_cnt++;
if(e_cnt >= experts)
e_cnt = 0;
} }
} }
else else
...@@ -420,7 +416,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -420,7 +416,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
experts, experts,
block_m); block_m);
ck_tile::reference_fused_moe<AccDataType, ck_tile::element_wise::Gelu>( ck_tile::reference_fused_moe<AccDataType, ck_tile::element_wise::Silu>(
a_host, a_host,
g_host, g_host,
d_host, d_host,
...@@ -529,7 +525,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -529,7 +525,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(do_validation) if(do_validation)
{ {
ck_tile::reference_fused_moe<AccDataType, ck_tile::element_wise::Gelu>( ck_tile::reference_fused_moe<AccDataType, ck_tile::element_wise::Silu>(
a_host, a_host,
g_host, g_host,
d_host, d_host,
......
...@@ -62,6 +62,7 @@ check_err(const Range& out, ...@@ -62,6 +62,7 @@ check_err(const Range& out,
return either_not_finite && !(allow_infinity_ref && both_infinite_and_same); return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
}; };
printf("1090111\n");
bool res{true}; bool res{true};
int err_count = 0; int err_count = 0;
...@@ -76,7 +77,7 @@ check_err(const Range& out, ...@@ -76,7 +77,7 @@ check_err(const Range& out,
{ {
max_err = err > max_err ? err : max_err; max_err = err > max_err ? err : max_err;
err_count++; err_count++;
if(err_count < 5) if(err_count < 500)
{ {
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl; << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
...@@ -107,6 +108,7 @@ check_err(const Range& out, ...@@ -107,6 +108,7 @@ check_err(const Range& out,
double atol = 1e-3, double atol = 1e-3,
bool allow_infinity_ref = false) bool allow_infinity_ref = false)
{ {
printf("1111\n");
if(out.size() != ref.size()) if(out.size() != ref.size())
{ {
std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size() std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
...@@ -127,6 +129,7 @@ check_err(const Range& out, ...@@ -127,6 +129,7 @@ check_err(const Range& out,
double err = 0; double err = 0;
// TODO: This is a hack. We should have proper specialization for bf16_t data type. // TODO: This is a hack. We should have proper specialization for bf16_t data type.
double max_err = std::numeric_limits<float>::min(); double max_err = std::numeric_limits<float>::min();
int print_cnt = 0;
for(std::size_t i = 0; i < ref.size(); ++i) for(std::size_t i = 0; i < ref.size(); ++i)
{ {
const double o = type_convert<float>(*std::next(std::begin(out), i)); const double o = type_convert<float>(*std::next(std::begin(out), i));
...@@ -136,12 +139,16 @@ check_err(const Range& out, ...@@ -136,12 +139,16 @@ check_err(const Range& out,
{ {
max_err = err > max_err ? err : max_err; max_err = err > max_err ? err : max_err;
err_count++; err_count++;
if(err_count < 5) if(err_count < 500)
{ {
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl; << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
} }
res = false; res = false;
} else if (print_cnt < 10) {
print_cnt++;
std::cout << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
} }
} }
if(!res) if(!res)
...@@ -195,7 +202,7 @@ check_err(const Range& out, ...@@ -195,7 +202,7 @@ check_err(const Range& out,
{ {
max_err = err > max_err ? err : max_err; max_err = err > max_err ? err : max_err;
err_count++; err_count++;
if(err_count < 5) if(err_count < 500)
{ {
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl; << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
...@@ -235,6 +242,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val ...@@ -235,6 +242,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
<< std::endl; << std::endl;
return false; return false;
} }
printf("222\n");
bool res{true}; bool res{true};
int err_count = 0; int err_count = 0;
...@@ -250,7 +258,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val ...@@ -250,7 +258,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
{ {
max_err = err > max_err ? err : max_err; max_err = err > max_err ? err : max_err;
err_count++; err_count++;
if(err_count < 5) if(err_count < 500)
{ {
std::cerr << msg << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r std::cerr << msg << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r
<< std::endl; << std::endl;
...@@ -313,6 +321,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val ...@@ -313,6 +321,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
bool res{true}; bool res{true};
int err_count = 0; int err_count = 0;
double err = 0; double err = 0;
printf("11113\n");
double max_err = std::numeric_limits<float>::min(); double max_err = std::numeric_limits<float>::min();
for(std::size_t i = 0; i < ref.size(); ++i) for(std::size_t i = 0; i < ref.size(); ++i)
{ {
...@@ -327,7 +336,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val ...@@ -327,7 +336,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
{ {
max_err = err > max_err ? err : max_err; max_err = err > max_err ? err : max_err;
err_count++; err_count++;
if(err_count < 5) if(err_count < 500)
{ {
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o_fp64 << " != " << r_fp64 << std::endl; << "] != ref[" << i << "]: " << o_fp64 << " != " << r_fp64 << std::endl;
...@@ -381,7 +390,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val ...@@ -381,7 +390,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
{ {
max_err = err > max_err ? err : max_err; max_err = err > max_err ? err : max_err;
err_count++; err_count++;
if(err_count < 5) if(err_count < 500)
{ {
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl; << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment