Commit b222b257 authored by Paul's avatar Paul
Browse files

Interpolate tuning param

parent 51c11552
......@@ -137,6 +137,8 @@ struct instance
std::string str() const { return join_strings(params, ","); }
};
static bool transposed_matrix(const shape& s) { return s.strides().back() != 1; }
template <class F, class Action>
auto action_decorate(F f, Action action)
{
......@@ -154,6 +156,18 @@ static std::vector<tuning_entry> read_tuning(const std::string& s)
return from_value<std::vector<tuning_entry>>(from_json_string(read_string(s)));
}
static float matrix_distance(const shape& x, const shape& y)
{
if (x.type() != y.type())
return std::numeric_limits<float>::max();
if (transposed_matrix(x) != transposed_matrix(y))
return std::numeric_limits<float>::max();
auto sum_squared = std::inner_product(x.lens().rbegin(), x.lens().rbegin()+2, y.lens().rbegin(), 0, std::plus<>{}, [](auto a, auto b) {
return (a - b) * (a - b);
});
return std::sqrt(sum_squared);
}
static std::size_t get_tuning_for(const std::vector<shape>& inputs)
{
static auto tuning = read_tuning(string_value_of(MIGRAPHX_CK_TUNING{}, ""));
......@@ -164,14 +178,26 @@ static std::size_t get_tuning_for(const std::vector<shape>& inputs)
if(it == tuning.end())
{
std::cout << "*********** Warning: CK tuning missing for config!" << std::endl;
return value_of(MIGRAPHX_CK_TUNING_VALUE{}, 4);
std::vector<std::pair<float, std::size_t>> w;
std::transform(tuning.begin(), tuning.end(), std::back_inserter(w), [&](const auto& p) {
if (inputs.size() < 3 or p.first.size() < 3)
MIGRAPHX_THROW("Invalid CK config");
auto avg_distance = std::inner_product(p.first.begin(), p.first.begin()+3, inputs.begin(), 0.0f, std::plus<>{}, [](const auto& x, const auto& y) {
return matrix_distance(x, y) / 3.0f;
});
return std::make_pair(avg_distance, p.second);
});
std::sort(w.begin(), w.end());
std::size_t default_value = 4;
if (not w.empty())
default_value = w.front().second;
return value_of(MIGRAPHX_CK_TUNING_VALUE{}, default_value);
}
return it->second;
}
struct ck_gemm_compiler : compiler<ck_gemm_compiler>
{
static bool transposed_matrix(const shape& s) { return s.strides().back() != 1; }
static std::string get_layout(const shape& s)
{
return transposed_matrix(s) ? "ck::tensor_layout::gemm::ColumnMajor"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment