Commit 51452c03 authored by Shucai Xiao
Browse files

refine perf report and add flops for gemm op

parent d0543c96
#include <functional>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
...@@ -15,10 +16,12 @@ ...@@ -15,10 +16,12 @@
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/marker.hpp> #include <migraphx/marker.hpp>
#include <iostream> #include <iostream>
#include <numeric>
#include <sstream> #include <sstream>
#include <algorithm> #include <algorithm>
#include <set> #include <set>
#include <utility> #include <utility>
#include <iomanip>
#include <unordered_set> #include <unordered_set>
#include <map> #include <map>
...@@ -526,6 +529,38 @@ void program::mark(const parameter_map& params, marker&& m) ...@@ -526,6 +529,38 @@ void program::mark(const parameter_map& params, marker&& m)
m.mark_stop(*this); m.mark_stop(*this);
} }
// Emits `n` blank characters to `os` for column padding.
// Nothing is written when `n` is zero or negative.
static void print_space(std::ostream& os, int n)
{
    int remaining = n;
    while(remaining > 0)
    {
        os << ' ';
        --remaining;
    }
}
// Callable that estimates the floating-point operation count of one
// instruction from the shapes of its inputs.
using op_flops = std::function<double(const std::vector<shape>& vec_ss)>;

// Returns the table that maps an operator name to its FLOP estimator.
// Only "gemm" is registered. The table is constructed once, on first use,
// and is never modified afterwards; every call returns a copy of it.
auto get_flops_funcs()
{
    static const std::unordered_map<std::string, op_flops> op_funcs = {
        {"gemm", [](const std::vector<shape>& vec_ss) {
             // Inputs are expected to be A[..., m, k] and B[..., k, n].
             assert(vec_ss.size() >= 2);
             const auto& lens_a = vec_ss.front().lens();
             const auto& lens_b = vec_ss.at(1).lens();
             // Both operands need at least two dimensions, otherwise the
             // index arithmetic below would step outside the vectors.
             assert(lens_a.size() >= 2);
             assert(lens_b.size() >= 2);

             const std::size_t m = lens_a[lens_a.size() - 2];
             const std::size_t k = lens_a.back();
             assert(k == lens_b[lens_b.size() - 2]);
             const std::size_t n = lens_b.back();

             // Product of every dimension of A in front of the trailing two.
             // The seed is a std::size_t so the running product is kept in
             // std::size_t instead of being narrowed to int at each step.
             const std::size_t batch = std::accumulate(lens_a.begin(),
                                                       lens_a.end() - 2,
                                                       std::size_t{1},
                                                       std::multiplies<std::size_t>{});

             // Counted as 2 * m * n * k operations per batch entry.
             return 2.0 * m * n * k * batch;
         }}};
    return op_funcs;
}
void program::perf_report(std::ostream& os, void program::perf_report(std::ostream& os,
std::size_t n, std::size_t n,
parameter_map params, parameter_map params,
...@@ -591,17 +626,76 @@ void program::perf_report(std::ostream& os, ...@@ -591,17 +626,76 @@ void program::perf_report(std::ostream& os,
double calculate_overhead_percent = calculate_overhead_time * 100.0 / total_time; double calculate_overhead_percent = calculate_overhead_time * 100.0 / total_time;
std::unordered_map<instruction_ref, std::string> names; std::unordered_map<instruction_ref, std::string> names;
// count max instruction length
int max_ins_len = 0;
this->print(names, [&](auto ins, auto ins_names) {
std::stringstream ss;
instruction::print(ss, ins, ins_names);
if (max_ins_len < ss.str().length())
{
max_ins_len = ss.str().length();
}
// skip return instruction
if(ins->name() == "@return")
return;
});
int print_ins_len = max_ins_len + 4;
std::string str = "Instructions";
os << str;
print_space(os, print_ins_len - str.length());
std::string time_str = "Time(ms) ";
std::string percent_str = "Percentage ";
std::string flops_str = "Flops(GFlops/s)";
os << time_str << percent_str << flops_str << std::endl;
auto flops_funcs = get_flops_funcs();
this->print(names, [&](auto ins, auto ins_names) { this->print(names, [&](auto ins, auto ins_names) {
instruction::print(std::cout, ins, ins_names); std::stringstream ss;
instruction::print(ss, ins, ins_names);
os << ss.str();
// skip return instruction // skip return instruction
if(ins->name() == "@return") if(ins->name() == "@return")
return; return;
print_space(os, print_ins_len - ss.str().length());
double avg = common_average(ins_vec[ins]); double avg = common_average(ins_vec[ins]);
double percent = std::ceil(100.0 * avg / total_instruction_time); std::string tms = std::to_string(avg);
os << ": " << avg << "ms, " << percent << "%"; tms.append(time_str.length() - tms.length(), ' ');
os << std::endl; double percent = 100.0 * avg / total_instruction_time;
std::string pers = std::to_string(percent);
auto loc = pers.find('.');
if (loc != std::string::npos)
{
pers.erase(pers.begin() + loc + 6, pers.end());
}
pers.append(percent_str.length() - pers.length(), ' ');
// calculate flops
std::string flps;
std::string op_name = ins->name();
auto nloc = op_name.find("::");
op_name.erase(op_name.begin(), op_name.begin() + nloc + 2);
if (contains(flops_funcs, op_name))
{
auto op_flop_func = flops_funcs.at(op_name);
auto inss = to_shapes(ins->inputs());
double flops = op_flop_func(inss);
flops /= avg;
// convert to GFlops
flops /= 1.0e6;
flps = std::to_string(flops);
auto floc = flps.find('.');
if (floc != std::string::npos)
{
flps.erase(flps.begin() + floc + 4, flps.end());
}
}
os << tms << pers << flps << std::endl;
}); });
os << std::endl; os << std::endl;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment