Commit 07b8f71c authored by Paul's avatar Paul
Browse files

Format

parent bd42c933
......@@ -84,15 +84,9 @@ struct instance
std::vector<std::string> params;
static const std::size_t block_size_index = 15;
std::size_t int_at(std::size_t i) const
{
return std::stoull(params[i]);
}
std::size_t int_at(std::size_t i) const { return std::stoull(params[i]); }
std::size_t get_block_size() const
{
return int_at(block_size_index);
}
std::size_t get_block_size() const { return int_at(block_size_index); }
std::size_t get_pb(std::size_t i) const
{
......@@ -103,14 +97,13 @@ struct instance
std::array<std::size_t, 3> get_pad(const std::array<std::size_t, 3>& config) const
{
std::array<std::size_t, 3> result{};
for(auto i:range(config.size()))
for(auto i : range(config.size()))
{
result[i] = int_div_ceil(config[i], get_pb(i)) * get_pb(i) - config[i];
}
return result;
}
std::size_t get_grid_size(const std::array<std::size_t, 3>& config) const
{
return int_div_ceil(config[0], get_pb(1)) * int_div_ceil(config[1], get_pb(1));
......@@ -140,10 +133,7 @@ struct instance
params[13] = s;
}
std::string str() const
{
return join_strings(params, ",");
}
std::string str() const { return join_strings(params, ","); }
};
template <class F, class Action>
......@@ -210,10 +200,11 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
auto c_shape = inputs.back();
std::array<char, 3> keys{'M', 'N', 'K'};
std::array<std::size_t, 3> config{c_shape.lens().front(), c_shape.lens().back(), a_shape.lens().back()};
std::array<std::size_t, 3> config{
c_shape.lens().front(), c_shape.lens().back(), a_shape.lens().back()};
auto tuning_val = v.get("tuning_val", get_tuning_for({a_shape, b_shape, c_shape}));
auto ip = instance{get_instance(tuning_val, [&](const auto& x) -> bool {
auto tuning_val = v.get("tuning_val", get_tuning_for({a_shape, b_shape, c_shape}));
auto ip = instance{get_instance(tuning_val, [&](const auto& x) -> bool {
return get_layout(a_shape) == x[0] and get_layout(b_shape) == x[1] and
get_layout(c_shape) == x[3] and get_type(a_shape) == x[4] and
get_type(b_shape) == x[5] and get_type(c_shape) == x[9];
......@@ -228,18 +219,17 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
auto padding = ip.get_pad(config);
std::string gemm_type;
for(auto i:range(padding.size()))
for(auto i : range(padding.size()))
{
if (padding[i] != 0)
if(padding[i] != 0)
gemm_type += keys[i];
}
if (gemm_type.empty())
if(gemm_type.empty())
gemm_type = "Default";
else
gemm_type += "Padding";
ip.set_gemm("ck::tensor_operation::device::GemmSpecialization::" + gemm_type);
hip_compile_options options;
auto block_size = ip.get_block_size();
auto grid_size = ip.get_grid_size(config);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment