Commit 779e6525 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 51452c03
...@@ -544,13 +544,14 @@ auto get_flops_funcs() ...@@ -544,13 +544,14 @@ auto get_flops_funcs()
static std::unordered_map<std::string, op_flops> op_funcs; static std::unordered_map<std::string, op_flops> op_funcs;
op_funcs.emplace("gemm", [&](const std::vector<shape>& vec_ss) { op_funcs.emplace("gemm", [&](const std::vector<shape>& vec_ss) {
assert(vec_ss.size() >= 2); assert(vec_ss.size() >= 2);
auto sa = vec_ss.front(); auto sa = vec_ss.front();
auto sb = vec_ss.at(1); auto sb = vec_ss.at(1);
auto batch = 1; auto batch = 1;
auto lens_a = sa.lens(); auto lens_a = sa.lens();
batch = std::accumulate(lens_a.rbegin() + 2, lens_a.rend(), 1, std::multiplies<std::size_t>{}); batch =
auto m = lens_a[lens_a.size() - 2]; std::accumulate(lens_a.rbegin() + 2, lens_a.rend(), 1, std::multiplies<std::size_t>{});
auto k = lens_a.back(); auto m = lens_a[lens_a.size() - 2];
auto k = lens_a.back();
auto lens_b = sb.lens(); auto lens_b = sb.lens();
assert(k == lens_b[lens_b.size() - 2]); assert(k == lens_b[lens_b.size() - 2]);
auto n = lens_b.back(); auto n = lens_b.back();
...@@ -632,7 +633,7 @@ void program::perf_report(std::ostream& os, ...@@ -632,7 +633,7 @@ void program::perf_report(std::ostream& os,
this->print(names, [&](auto ins, auto ins_names) { this->print(names, [&](auto ins, auto ins_names) {
std::stringstream ss; std::stringstream ss;
instruction::print(ss, ins, ins_names); instruction::print(ss, ins, ins_names);
if (max_ins_len < ss.str().length()) if(max_ins_len < ss.str().length())
{ {
max_ins_len = ss.str().length(); max_ins_len = ss.str().length();
} }
...@@ -643,12 +644,12 @@ void program::perf_report(std::ostream& os, ...@@ -643,12 +644,12 @@ void program::perf_report(std::ostream& os,
}); });
int print_ins_len = max_ins_len + 4; int print_ins_len = max_ins_len + 4;
std::string str = "Instructions"; std::string str = "Instructions";
os << str; os << str;
print_space(os, print_ins_len - str.length()); print_space(os, print_ins_len - str.length());
std::string time_str = "Time(ms) "; std::string time_str = "Time(ms) ";
std::string percent_str = "Percentage "; std::string percent_str = "Percentage ";
std::string flops_str = "Flops(GFlops/s)"; std::string flops_str = "Flops(GFlops/s)";
os << time_str << percent_str << flops_str << std::endl; os << time_str << percent_str << flops_str << std::endl;
auto flops_funcs = get_flops_funcs(); auto flops_funcs = get_flops_funcs();
...@@ -663,13 +664,13 @@ void program::perf_report(std::ostream& os, ...@@ -663,13 +664,13 @@ void program::perf_report(std::ostream& os,
print_space(os, print_ins_len - ss.str().length()); print_space(os, print_ins_len - ss.str().length());
double avg = common_average(ins_vec[ins]); double avg = common_average(ins_vec[ins]);
std::string tms = std::to_string(avg); std::string tms = std::to_string(avg);
tms.append(time_str.length() - tms.length(), ' '); tms.append(time_str.length() - tms.length(), ' ');
double percent = 100.0 * avg / total_instruction_time; double percent = 100.0 * avg / total_instruction_time;
std::string pers = std::to_string(percent); std::string pers = std::to_string(percent);
auto loc = pers.find('.'); auto loc = pers.find('.');
if (loc != std::string::npos) if(loc != std::string::npos)
{ {
pers.erase(pers.begin() + loc + 6, pers.end()); pers.erase(pers.begin() + loc + 6, pers.end());
} }
...@@ -678,19 +679,19 @@ void program::perf_report(std::ostream& os, ...@@ -678,19 +679,19 @@ void program::perf_report(std::ostream& os,
// calculate flops // calculate flops
std::string flps; std::string flps;
std::string op_name = ins->name(); std::string op_name = ins->name();
auto nloc = op_name.find("::"); auto nloc = op_name.find("::");
op_name.erase(op_name.begin(), op_name.begin() + nloc + 2); op_name.erase(op_name.begin(), op_name.begin() + nloc + 2);
if (contains(flops_funcs, op_name)) if(contains(flops_funcs, op_name))
{ {
auto op_flop_func = flops_funcs.at(op_name); auto op_flop_func = flops_funcs.at(op_name);
auto inss = to_shapes(ins->inputs()); auto inss = to_shapes(ins->inputs());
double flops = op_flop_func(inss); double flops = op_flop_func(inss);
flops /= avg; flops /= avg;
// convert to GFlops // convert to GFlops
flops /= 1.0e6; flops /= 1.0e6;
flps = std::to_string(flops); flps = std::to_string(flops);
auto floc = flps.find('.'); auto floc = flps.find('.');
if (floc != std::string::npos) if(floc != std::string::npos)
{ {
flps.erase(flps.begin() + floc + 4, flps.end()); flps.erase(flps.begin() + floc + 4, flps.end());
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment