Commit 97456425 authored by Paul's avatar Paul
Browse files

Format

parent b222b257
...@@ -158,13 +158,16 @@ static std::vector<tuning_entry> read_tuning(const std::string& s) ...@@ -158,13 +158,16 @@ static std::vector<tuning_entry> read_tuning(const std::string& s)
static float matrix_distance(const shape& x, const shape& y) static float matrix_distance(const shape& x, const shape& y)
{ {
if (x.type() != y.type()) if(x.type() != y.type())
return std::numeric_limits<float>::max(); return std::numeric_limits<float>::max();
if (transposed_matrix(x) != transposed_matrix(y)) if(transposed_matrix(x) != transposed_matrix(y))
return std::numeric_limits<float>::max(); return std::numeric_limits<float>::max();
auto sum_squared = std::inner_product(x.lens().rbegin(), x.lens().rbegin()+2, y.lens().rbegin(), 0, std::plus<>{}, [](auto a, auto b) { auto sum_squared = std::inner_product(x.lens().rbegin(),
return (a - b) * (a - b); x.lens().rbegin() + 2,
}); y.lens().rbegin(),
0,
std::plus<>{},
[](auto a, auto b) { return (a - b) * (a - b); });
return std::sqrt(sum_squared); return std::sqrt(sum_squared);
} }
...@@ -180,16 +183,20 @@ static std::size_t get_tuning_for(const std::vector<shape>& inputs) ...@@ -180,16 +183,20 @@ static std::size_t get_tuning_for(const std::vector<shape>& inputs)
std::cout << "*********** Warning: CK tuning missing for config!" << std::endl; std::cout << "*********** Warning: CK tuning missing for config!" << std::endl;
std::vector<std::pair<float, std::size_t>> w; std::vector<std::pair<float, std::size_t>> w;
std::transform(tuning.begin(), tuning.end(), std::back_inserter(w), [&](const auto& p) { std::transform(tuning.begin(), tuning.end(), std::back_inserter(w), [&](const auto& p) {
if (inputs.size() < 3 or p.first.size() < 3) if(inputs.size() < 3 or p.first.size() < 3)
MIGRAPHX_THROW("Invalid CK config"); MIGRAPHX_THROW("Invalid CK config");
auto avg_distance = std::inner_product(p.first.begin(), p.first.begin()+3, inputs.begin(), 0.0f, std::plus<>{}, [](const auto& x, const auto& y) { auto avg_distance = std::inner_product(
return matrix_distance(x, y) / 3.0f; p.first.begin(),
}); p.first.begin() + 3,
inputs.begin(),
0.0f,
std::plus<>{},
[](const auto& x, const auto& y) { return matrix_distance(x, y) / 3.0f; });
return std::make_pair(avg_distance, p.second); return std::make_pair(avg_distance, p.second);
}); });
std::sort(w.begin(), w.end()); std::sort(w.begin(), w.end());
std::size_t default_value = 4; std::size_t default_value = 4;
if (not w.empty()) if(not w.empty())
default_value = w.front().second; default_value = w.front().second;
return value_of(MIGRAPHX_CK_TUNING_VALUE{}, default_value); return value_of(MIGRAPHX_CK_TUNING_VALUE{}, default_value);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment