Commit 47bde8f7 authored by Paul's avatar Paul
Browse files

Merge branch 'ck-integration-tuning' of...

Merge branch 'ck-integration-tuning' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into ck-integration-tuning
parents 7099811c 3d8c71fa
#ifndef MIGRAPHX_GUARD_MIGRAPHX_HASH_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_HASH_HPP
#include <migraphx/config.hpp>
#include <functional>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Default hashing entry point: forwards to std::hash for any type that
// specializes it. Overload/specialize hash_value for project types instead
// of specializing std::hash directly.
template <class T>
std::size_t hash_value(const T& v)
{
    std::hash<T> hasher{};
    return hasher(v);
}
// Folds the hash of v into seed (boost-style combine): the golden-ratio
// constant plus the seed shifts spread bits so that combining order matters.
template <class T>
void hash_combine(std::size_t& seed, const T& v)
{
    const std::size_t mixed = hash_value(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
    seed ^= mixed;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_HASH_HPP
......@@ -392,8 +392,8 @@ struct value
return; \
}
MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_CASE_VALUE)
MIGRAPHX_VALUE_GENERATE_CASE(array, )
MIGRAPHX_VALUE_GENERATE_CASE(object, )
MIGRAPHX_VALUE_GENERATE_CASE_VALUE(array, )
MIGRAPHX_VALUE_GENERATE_CASE_VALUE(object, )
}
MIGRAPHX_THROW("Unknown type");
}
......@@ -461,6 +461,8 @@ struct value
friend std::ostream& operator<<(std::ostream& os, const value& d);
std::size_t hash() const;
void debug_print(bool show_type = false) const;
type_t get_type() const;
......@@ -481,4 +483,15 @@ struct value
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
namespace std {
// Enables migraphx::value as a key in std::unordered_map/set by delegating
// to the value's own content-aware hash() member.
template <>
struct hash<migraphx::value>
{
    using argument_type = migraphx::value;
    using result_type   = std::size_t;
    result_type operator()(const argument_type& v) const noexcept
    {
        return v.hash();
    }
};
} // namespace std
#endif
......@@ -77,6 +77,30 @@ struct compiled_result
instruction_ref ins;
};
struct problem_cache
{
bool has(const std::string& name, const value& problem) const
{
return contains(cache, create_key(name, problem));
}
void insert(const std::string& name, const value& problem, const value& solution)
{
cache[create_key(name, problem)] = solution;
}
optional<value> get(const std::string& name, const value& problem) const
{
auto it = cache.find(create_key(name, problem));
if(it == cache.end())
return nullopt;
return it->second;
}
static value create_key(const std::string& name, const value& problem)
{
return {{"name", name}, {"problem", problem}};
}
std::unordered_map<value, value> cache;
};
struct compile_plan
{
context* ctx;
......@@ -86,9 +110,22 @@ struct compile_plan
std::vector<compiled_result> results = {};
void update_config() { config = get_tuning_config(*ctx, ins, preop); }
template <class Vector>
void add_compiles(Vector& compiles)
void add_compiles(Vector& compiles, problem_cache& pc)
{
if(config.has_value())
{
const auto& problem = config.value().problem;
if(auto sol = pc.get(preop.name(), problem))
{
auto solution = sol.value();
// No solution yet until benchmarked so skip for now
if(solution.empty())
return;
compiles.emplace_back([=] {
results[0] = compiled_result{compile(*ctx, ins, preop, solution), ins};
});
}
else
{
const auto& solutions = config.value().solutions;
results.resize(solutions.size());
......@@ -100,6 +137,7 @@ struct compile_plan
});
}
}
}
else
{
results.resize(1);
......@@ -108,7 +146,7 @@ struct compile_plan
});
}
}
const compiled_result& benchmark() const
const compiled_result& benchmark(problem_cache& pc) const
{
if(results.empty())
MIGRAPHX_THROW("No configs to tune");
......@@ -123,11 +161,12 @@ struct compile_plan
time_op(*ctx, cr.replace.code_object, to_shapes(cr.ins->inputs()), 20).first);
}
auto i = std::distance(times.begin(), std::min_element(times.begin(), times.end()));
pc.insert(preop.name(), config.value().problem, config.value().solutions[i]);
return results[i];
}
void replace(module& m) const
void replace(module& m, problem_cache& pc) const
{
const auto& cr = benchmark();
const auto& cr = benchmark(pc);
cr.replace.replace(m, cr.ins);
}
};
......@@ -140,32 +179,64 @@ void par_compile(std::size_t n, F f)
par_for(n, n / value_of(MIGRAPHX_GPU_COMPILE_PARALLEL{}, n), f);
}
void compile_ops::apply(module& m) const
struct compile_manager
{
problem_cache pc;
std::vector<compile_plan> cps;
// Find all precompile opes
for(auto ins : iterator_for(m))
template <class... Ts>
void add_plan(Ts&&... xs)
{
if(ins->name() != "gpu::precompile_op")
continue;
operation preop = any_cast<precompile_op>(ins->get_operator()).op;
cps.push_back({ctx, preop, ins});
cps.push_back({std::forward<Ts>(xs)...});
}
// Get the tuning configs for all ops
void update_configs()
{
par_compile(cps.size(), [&](auto i) { cps[i].update_config(); });
// Compile everything in parallel
}
void compile(module& m)
{
std::vector<std::function<void()>> compiles;
for(auto& cp : cps)
{
cp.add_compiles(compiles);
cp.add_compiles(compiles, pc);
}
par_compile(compiles.size(), [&](auto i) { compiles[i](); });
// Replace and/or benchmark
for(const auto& cp : cps)
{
cp.replace(m);
if(cp.results.empty())
continue;
cp.replace(m, pc);
}
// Remove compile_plan already executed
cps.erase(std::remove_if(cps.begin(),
cps.end(),
[](const auto& cp) { return not cp.results.empty(); }),
cps.end());
}
};
// Compiles every gpu::precompile_op in the module, tuning where configs are
// available. Fix: removed the unused locals `pc` and `cps` — dead leftovers
// from the refactor that moved this state into compile_manager.
void compile_ops::apply(module& m) const
{
    compile_manager cm;
    // Collect a compile plan for each precompile op
    for(auto ins : iterator_for(m))
    {
        if(ins->name() != "gpu::precompile_op")
            continue;
        operation preop = any_cast<precompile_op>(ins->get_operator()).op;
        cm.add_plan(ctx, preop, ins);
    }
    cm.update_configs();
    // First pass compiles/benchmarks and populates the problem cache
    cm.compile(m);
    // Compile already tuned configs
    cm.compile(m);
}
} // namespace gpu
......
......@@ -53,6 +53,9 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
{
if(ins->name() != "dot" and ins->name() != "quant_dot")
return false;
if(not contains({shape::half_type, shape::int8_type, shape::int32_type},
ins->get_shape().type()))
return false;
auto a = ins->inputs().front()->get_shape();
auto b = ins->inputs().back()->get_shape();
if(a.lens().back() > 2048)
......@@ -82,9 +85,6 @@ struct find_ck_gemm_pointwise
auto gemm_it = std::find(inputs.begin(), inputs.end(), x_ins);
auto gemm_idx = gemm_it - inputs.begin();
assert(gemm_it != inputs.end());
if(not contains({shape::half_type, shape::int8_type, shape::int32_type},
ins->get_shape().type()))
return;
if(gemm_idx != 0)
{
auto first_param = pm->get_parameter(names[0]);
......
......@@ -426,6 +426,8 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
auto solutions = problem.GetSolutions(ctx.get_current_device().get_gfx_name());
tc.solutions.resize(solutions.size());
std::iota(tc.solutions.begin(), tc.solutions.end(), 0);
std::vector<shape> gemm_shapes{shapes[0], shapes[1], shapes.back()};
tc.problem = to_value(shapes);
return tc;
}
};
......
......@@ -28,6 +28,7 @@
#include <migraphx/stringutils.hpp>
#include <migraphx/value.hpp>
#include <migraphx/optional.hpp>
#include <migraphx/hash.hpp>
#include <unordered_map>
#include <utility>
......@@ -519,6 +520,35 @@ std::ostream& operator<<(std::ostream& os, const value& d)
return os;
}
// Hashes a (key, payload) pair: seed with the key's hash, then fold in the
// payload. Used by value::hash for scalar-like contents.
template <class T>
std::size_t value_hash(const std::string& key, const T& x)
{
    std::size_t seed = hash_value(key);
    hash_combine(seed, x);
    return seed;
}
// Hashes a (key, array payload) pair: seed with the key's hash, then fold in
// each element in sequence (hash_combine makes the result order-sensitive).
std::size_t value_hash(const std::string& key, const std::vector<value>& x)
{
    std::size_t seed = hash_value(key);
    for(const auto& element : x)
        hash_combine(seed, element);
    return seed;
}
// Hashes a (key, binary payload) pair: seed with the key's hash, then fold in
// each byte of the blob in sequence.
std::size_t value_hash(const std::string& key, const value::binary& x)
{
    std::size_t seed = hash_value(key);
    for(const auto& byte : x)
        hash_combine(seed, byte);
    return seed;
}
// Content hash for a value: dispatches through visit_value so the
// matching value_hash overload handles the stored payload together with
// this value's key.
std::size_t value::hash() const
{
    std::size_t result = 0;
    this->visit_value([&](const auto& contents) { result = value_hash(this->get_key(), contents); });
    return result;
}
void value::debug_print(bool show_type) const
{
if(show_type)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment