Commit 47bde8f7 authored by Paul
Browse files

Merge branch 'ck-integration-tuning' of...

Merge branch 'ck-integration-tuning' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into ck-integration-tuning
parents 7099811c 3d8c71fa
#ifndef MIGRAPHX_GUARD_MIGRAPHX_HASH_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_HASH_HPP
#include <migraphx/config.hpp>
#include <functional>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
/// Compute the hash of `v` using the `std::hash` specialization for `T`.
template <class T>
std::size_t hash_value(const T& v)
{
    const std::hash<T> hasher{};
    return hasher(v);
}
/// Fold the hash of `v` into the running hash `seed` (updated in place).
/// The result depends on the previous seed, so the order of calls matters.
template <class T>
void hash_combine(std::size_t& seed, const T& v)
{
    // Mix the element hash with a fixed constant and shifted copies of the seed.
    const std::size_t mixed = hash_value(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
    seed ^= mixed;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_HASH_HPP
...@@ -392,8 +392,8 @@ struct value ...@@ -392,8 +392,8 @@ struct value
return; \ return; \
} }
MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_CASE_VALUE) MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_CASE_VALUE)
MIGRAPHX_VALUE_GENERATE_CASE(array, ) MIGRAPHX_VALUE_GENERATE_CASE_VALUE(array, )
MIGRAPHX_VALUE_GENERATE_CASE(object, ) MIGRAPHX_VALUE_GENERATE_CASE_VALUE(object, )
} }
MIGRAPHX_THROW("Unknown type"); MIGRAPHX_THROW("Unknown type");
} }
...@@ -461,6 +461,8 @@ struct value ...@@ -461,6 +461,8 @@ struct value
friend std::ostream& operator<<(std::ostream& os, const value& d); friend std::ostream& operator<<(std::ostream& os, const value& d);
std::size_t hash() const;
void debug_print(bool show_type = false) const; void debug_print(bool show_type = false) const;
type_t get_type() const; type_t get_type() const;
...@@ -481,4 +483,15 @@ struct value ...@@ -481,4 +483,15 @@ struct value
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
namespace std {
/// `std::hash` specialization for `migraphx::value`; forwards to `value::hash()`.
template <>
struct hash<migraphx::value>
{
    using argument_type = migraphx::value;
    using result_type   = std::size_t;

    result_type operator()(const argument_type& x) const noexcept
    {
        return x.hash();
    }
};
} // namespace std
#endif #endif
...@@ -77,6 +77,30 @@ struct compiled_result ...@@ -77,6 +77,30 @@ struct compiled_result
instruction_ref ins; instruction_ref ins;
}; };
struct problem_cache
{
bool has(const std::string& name, const value& problem) const
{
return contains(cache, create_key(name, problem));
}
void insert(const std::string& name, const value& problem, const value& solution)
{
cache[create_key(name, problem)] = solution;
}
optional<value> get(const std::string& name, const value& problem) const
{
auto it = cache.find(create_key(name, problem));
if(it == cache.end())
return nullopt;
return it->second;
}
static value create_key(const std::string& name, const value& problem)
{
return {{"name", name}, {"problem", problem}};
}
std::unordered_map<value, value> cache;
};
struct compile_plan struct compile_plan
{ {
context* ctx; context* ctx;
...@@ -86,9 +110,22 @@ struct compile_plan ...@@ -86,9 +110,22 @@ struct compile_plan
std::vector<compiled_result> results = {}; std::vector<compiled_result> results = {};
void update_config() { config = get_tuning_config(*ctx, ins, preop); } void update_config() { config = get_tuning_config(*ctx, ins, preop); }
template <class Vector> template <class Vector>
void add_compiles(Vector& compiles) void add_compiles(Vector& compiles, problem_cache& pc)
{ {
if(config.has_value()) if(config.has_value())
{
const auto& problem = config.value().problem;
if(auto sol = pc.get(preop.name(), problem))
{
auto solution = sol.value();
// No solution yet until benchmarked so skip for now
if(solution.empty())
return;
compiles.emplace_back([=] {
results[0] = compiled_result{compile(*ctx, ins, preop, solution), ins};
});
}
else
{ {
const auto& solutions = config.value().solutions; const auto& solutions = config.value().solutions;
results.resize(solutions.size()); results.resize(solutions.size());
...@@ -100,6 +137,7 @@ struct compile_plan ...@@ -100,6 +137,7 @@ struct compile_plan
}); });
} }
} }
}
else else
{ {
results.resize(1); results.resize(1);
...@@ -108,7 +146,7 @@ struct compile_plan ...@@ -108,7 +146,7 @@ struct compile_plan
}); });
} }
} }
const compiled_result& benchmark() const const compiled_result& benchmark(problem_cache& pc) const
{ {
if(results.empty()) if(results.empty())
MIGRAPHX_THROW("No configs to tune"); MIGRAPHX_THROW("No configs to tune");
...@@ -123,11 +161,12 @@ struct compile_plan ...@@ -123,11 +161,12 @@ struct compile_plan
time_op(*ctx, cr.replace.code_object, to_shapes(cr.ins->inputs()), 20).first); time_op(*ctx, cr.replace.code_object, to_shapes(cr.ins->inputs()), 20).first);
} }
auto i = std::distance(times.begin(), std::min_element(times.begin(), times.end())); auto i = std::distance(times.begin(), std::min_element(times.begin(), times.end()));
pc.insert(preop.name(), config.value().problem, config.value().solutions[i]);
return results[i]; return results[i];
} }
void replace(module& m) const void replace(module& m, problem_cache& pc) const
{ {
const auto& cr = benchmark(); const auto& cr = benchmark(pc);
cr.replace.replace(m, cr.ins); cr.replace.replace(m, cr.ins);
} }
}; };
...@@ -140,32 +179,64 @@ void par_compile(std::size_t n, F f) ...@@ -140,32 +179,64 @@ void par_compile(std::size_t n, F f)
par_for(n, n / value_of(MIGRAPHX_GPU_COMPILE_PARALLEL{}, n), f); par_for(n, n / value_of(MIGRAPHX_GPU_COMPILE_PARALLEL{}, n), f);
} }
void compile_ops::apply(module& m) const struct compile_manager
{ {
problem_cache pc;
std::vector<compile_plan> cps; std::vector<compile_plan> cps;
// Find all precompile opes
for(auto ins : iterator_for(m)) template <class... Ts>
void add_plan(Ts&&... xs)
{ {
if(ins->name() != "gpu::precompile_op") cps.push_back({std::forward<Ts>(xs)...});
continue;
operation preop = any_cast<precompile_op>(ins->get_operator()).op;
cps.push_back({ctx, preop, ins});
} }
// Get the tuning configs for all ops
void update_configs()
{
par_compile(cps.size(), [&](auto i) { cps[i].update_config(); }); par_compile(cps.size(), [&](auto i) { cps[i].update_config(); });
// Compile everything in parallel }
void compile(module& m)
{
std::vector<std::function<void()>> compiles; std::vector<std::function<void()>> compiles;
for(auto& cp : cps) for(auto& cp : cps)
{ {
cp.add_compiles(compiles); cp.add_compiles(compiles, pc);
} }
par_compile(compiles.size(), [&](auto i) { compiles[i](); }); par_compile(compiles.size(), [&](auto i) { compiles[i](); });
// Replace and/or benchmark // Replace and/or benchmark
for(const auto& cp : cps) for(const auto& cp : cps)
{ {
cp.replace(m); if(cp.results.empty())
continue;
cp.replace(m, pc);
}
// Remove compile_plan already executed
cps.erase(std::remove_if(cps.begin(),
cps.end(),
[](const auto& cp) { return not cp.results.empty(); }),
cps.end());
}
};
// Compile every "gpu::precompile_op" instruction found in module `m`.
// A plan is registered with the compile_manager for each such instruction,
// the tuning configs are refreshed, and the manager's compile step is then
// run twice over the module.
void compile_ops::apply(module& m) const
{
    compile_manager cm;
    // Find all precompile ops
    for(auto ins : iterator_for(m))
    {
        if(ins->name() != "gpu::precompile_op")
            continue;
        operation preop = any_cast<precompile_op>(ins->get_operator()).op;
        cm.add_plan(ctx, preop, ins);
    }
    cm.update_configs();
    // First pass over all registered plans
    cm.compile(m);
    // Compile already tuned configs
    cm.compile(m);
}
} // namespace gpu } // namespace gpu
......
...@@ -53,6 +53,9 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins) ...@@ -53,6 +53,9 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
{ {
if(ins->name() != "dot" and ins->name() != "quant_dot") if(ins->name() != "dot" and ins->name() != "quant_dot")
return false; return false;
if(not contains({shape::half_type, shape::int8_type, shape::int32_type},
ins->get_shape().type()))
return false;
auto a = ins->inputs().front()->get_shape(); auto a = ins->inputs().front()->get_shape();
auto b = ins->inputs().back()->get_shape(); auto b = ins->inputs().back()->get_shape();
if(a.lens().back() > 2048) if(a.lens().back() > 2048)
...@@ -82,9 +85,6 @@ struct find_ck_gemm_pointwise ...@@ -82,9 +85,6 @@ struct find_ck_gemm_pointwise
auto gemm_it = std::find(inputs.begin(), inputs.end(), x_ins); auto gemm_it = std::find(inputs.begin(), inputs.end(), x_ins);
auto gemm_idx = gemm_it - inputs.begin(); auto gemm_idx = gemm_it - inputs.begin();
assert(gemm_it != inputs.end()); assert(gemm_it != inputs.end());
if(not contains({shape::half_type, shape::int8_type, shape::int32_type},
ins->get_shape().type()))
return;
if(gemm_idx != 0) if(gemm_idx != 0)
{ {
auto first_param = pm->get_parameter(names[0]); auto first_param = pm->get_parameter(names[0]);
......
...@@ -426,6 +426,8 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -426,6 +426,8 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
auto solutions = problem.GetSolutions(ctx.get_current_device().get_gfx_name()); auto solutions = problem.GetSolutions(ctx.get_current_device().get_gfx_name());
tc.solutions.resize(solutions.size()); tc.solutions.resize(solutions.size());
std::iota(tc.solutions.begin(), tc.solutions.end(), 0); std::iota(tc.solutions.begin(), tc.solutions.end(), 0);
std::vector<shape> gemm_shapes{shapes[0], shapes[1], shapes.back()};
tc.problem = to_value(shapes);
return tc; return tc;
} }
}; };
......
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/optional.hpp> #include <migraphx/optional.hpp>
#include <migraphx/hash.hpp>
#include <unordered_map> #include <unordered_map>
#include <utility> #include <utility>
...@@ -519,6 +520,35 @@ std::ostream& operator<<(std::ostream& os, const value& d) ...@@ -519,6 +520,35 @@ std::ostream& operator<<(std::ostream& os, const value& d)
return os; return os;
} }
// Hash a keyed payload: start from the hash of `key`, then mix in `x`.
template <class T>
std::size_t value_hash(const std::string& key, const T& x)
{
    std::size_t result = hash_value(key);
    hash_combine(result, x);
    return result;
}
std::size_t value_hash(const std::string& key, const std::vector<value>& x)
{
std::size_t h = hash_value(key);
for(const auto& v : x)
hash_combine(h, v);
return h;
}
// Hash a keyed binary payload: start from the hash of `key`, then mix in
// each element of `x` in iteration order.
std::size_t value_hash(const std::string& key, const value::binary& x)
{
    std::size_t result = hash_value(key);
    for(const auto& elem : x)
    {
        hash_combine(result, elem);
    }
    return result;
}
// Hash this value: visit the stored payload and combine it with this
// value's key via the matching value_hash overload. Stays 0 if the visitor
// is never invoked.
std::size_t value::hash() const
{
    std::size_t result = 0;
    this->visit_value([&](const auto& payload) {
        result = value_hash(this->get_key(), payload);
    });
    return result;
}
void value::debug_print(bool show_type) const void value::debug_print(bool show_type) const
{ {
if(show_type) if(show_type)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment