Commit 309705c1 authored by coderfeli
Browse files

add debug data

parent a6b761c3
......@@ -155,7 +155,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
int fused_quant = arg_parser.get_int("fquant");
int gate_only = arg_parser.get_int("gate_only");
int api = arg_parser.get_int("api");
int balance = arg_parser.get_int("balance");
// int balance = arg_parser.get_int("balance");
int tp = arg_parser.get_int("tp");
int init = arg_parser.get_int("init");
uint32_t seed = arg_parser.get_uint32("seed");
......@@ -257,14 +257,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
else if(init == 2)
{
ck_tile::FillNormalDistribution<ADataType>{0.f, 1.f, seed, true}(a_host);
ck_tile::FillNormalDistribution<GDataType>{0.f, 1.f, seed, true}(g_host);
ck_tile::FillNormalDistribution<DDataType>{0.f, 1.f, seed, true}(d_host);
ck_tile::FillNormalDistribution<AScaleDataType>{0.f, 1.f, seed, true}(sa_host);
ck_tile::FillNormalDistribution<GScaleDataType>{0.f, 1.f, seed, true}(sg_host);
ck_tile::FillNormalDistribution<DScaleDataType>{0.f, 1.f, seed, true}(sd_host);
ck_tile::FillNormalDistribution<YSmoothScaleDataType>{0.f, 1.f, seed, true}(sy_host);
ck_tile::FillNormalDistribution<TopkWeightDataType>{0.f, 1.f, seed, true}(topk_weight_host);
ck_tile::FillNormalDistribution<ADataType>{1.f, 1.f, seed, true}(a_host);
ck_tile::FillNormalDistribution<GDataType>{1.f, 1.f, seed, true}(g_host);
ck_tile::FillNormalDistribution<DDataType>{1.f, 1.f, seed, true}(d_host);
ck_tile::FillNormalDistribution<AScaleDataType>{1.f, 1.f, seed, true}(sa_host);
ck_tile::FillNormalDistribution<GScaleDataType>{1.f, 1.f, seed, true}(sg_host);
ck_tile::FillNormalDistribution<DScaleDataType>{1.f, 1.f, seed, true}(sd_host);
ck_tile::FillNormalDistribution<YSmoothScaleDataType>{1.f, 1.f, seed, true}(sy_host);
ck_tile::FillNormalDistribution<TopkWeightDataType>{0.125f, 0.125f, seed, true}(topk_weight_host);
}
// permute weight
......@@ -272,15 +272,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<DDataType> d_perm_host = shuffle_moe_weight(d_host, prec_w, 1);
// do moe sorting
if(balance)
if(1)
{
int e_cnt = 0;
for(int i = 0; i < static_cast<int>(topk_ids_host.mData.size()); i++)
{
topk_ids_host.mData[i] = e_cnt;
e_cnt++;
if(e_cnt >= experts)
e_cnt = 0;
for(int i=0; i < topk; i++) {
topk_ids_host.mData[i] = i;
topk_weight_host.mData[i] = 0.1;
}
}
else
......@@ -420,7 +416,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
experts,
block_m);
ck_tile::reference_fused_moe<AccDataType, ck_tile::element_wise::Gelu>(
ck_tile::reference_fused_moe<AccDataType, ck_tile::element_wise::Silu>(
a_host,
g_host,
d_host,
......@@ -529,7 +525,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(do_validation)
{
ck_tile::reference_fused_moe<AccDataType, ck_tile::element_wise::Gelu>(
ck_tile::reference_fused_moe<AccDataType, ck_tile::element_wise::Silu>(
a_host,
g_host,
d_host,
......
......@@ -62,6 +62,7 @@ check_err(const Range& out,
return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
};
printf("1090111\n");
bool res{true};
int err_count = 0;
......@@ -76,7 +77,7 @@ check_err(const Range& out,
{
max_err = err > max_err ? err : max_err;
err_count++;
if(err_count < 5)
if(err_count < 500)
{
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
......@@ -107,6 +108,7 @@ check_err(const Range& out,
double atol = 1e-3,
bool allow_infinity_ref = false)
{
printf("1111\n");
if(out.size() != ref.size())
{
std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
......@@ -127,6 +129,7 @@ check_err(const Range& out,
double err = 0;
// TODO: This is a hack. We should have proper specialization for bf16_t data type.
double max_err = std::numeric_limits<float>::min();
int print_cnt = 0;
for(std::size_t i = 0; i < ref.size(); ++i)
{
const double o = type_convert<float>(*std::next(std::begin(out), i));
......@@ -136,12 +139,16 @@ check_err(const Range& out,
{
max_err = err > max_err ? err : max_err;
err_count++;
if(err_count < 5)
if(err_count < 500)
{
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
}
res = false;
} else if (print_cnt < 10) {
print_cnt++;
std::cout << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
}
}
if(!res)
......@@ -195,7 +202,7 @@ check_err(const Range& out,
{
max_err = err > max_err ? err : max_err;
err_count++;
if(err_count < 5)
if(err_count < 500)
{
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
......@@ -235,6 +242,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
<< std::endl;
return false;
}
printf("222\n");
bool res{true};
int err_count = 0;
......@@ -250,7 +258,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
{
max_err = err > max_err ? err : max_err;
err_count++;
if(err_count < 5)
if(err_count < 500)
{
std::cerr << msg << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r
<< std::endl;
......@@ -313,6 +321,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
bool res{true};
int err_count = 0;
double err = 0;
printf("11113\n");
double max_err = std::numeric_limits<float>::min();
for(std::size_t i = 0; i < ref.size(); ++i)
{
......@@ -327,7 +336,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
{
max_err = err > max_err ? err : max_err;
err_count++;
if(err_count < 5)
if(err_count < 500)
{
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o_fp64 << " != " << r_fp64 << std::endl;
......@@ -381,7 +390,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
{
max_err = err > max_err ? err : max_err;
err_count++;
if(err_count < 5)
if(err_count < 500)
{
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment