Commit 00439b95 authored by Paul's avatar Paul
Browse files

Format

parent 3800d2b7
......@@ -163,7 +163,7 @@ struct compile_plan
<< std::endl;
std::vector<double> times;
times.reserve(results.size());
for(const auto& cr : results)
for(const auto& cr : results)
{
times.push_back(
time_op(*ctx, cr.replace.code_object, to_shapes(cr.ins->inputs()), 20).first);
......
......@@ -268,10 +268,10 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
bool can_fold_batch(const std::vector<shape>& inputs) const
{
const auto& a_shape = inputs[0];
const auto& b_shape = inputs[1];
auto rank = a_shape.lens().size();
auto b_strides = b_shape.strides();
const auto& a_shape = inputs[0];
const auto& b_shape = inputs[1];
auto rank = a_shape.lens().size();
auto b_strides = b_shape.strides();
return rank >= 3 and b_strides[rank - 3] == 0;
}
......@@ -282,8 +282,8 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
const auto& b_shape = inputs[1];
const auto& c_shape = inputs.back();
auto rank = a_shape.lens().size();
auto rank = a_shape.lens().size();
auto batch_count = get_batch_count(c_shape);
auto m = c_shape.lens()[rank - 2];
m = can_fold_batch(inputs) ? m * batch_count : m;
......@@ -293,9 +293,9 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
const bool trans_a = transposed_matrix(a_shape);
const bool trans_b = transposed_matrix(b_shape);
const bool trans_e = transposed_matrix(c_shape);
const auto a_type = get_type(a_shape);
const auto b_type = get_type(b_shape);
const auto e_type = get_type(c_shape);
const auto a_type = get_type(a_shape);
const auto b_type = get_type(b_shape);
const auto e_type = get_type(c_shape);
std::vector<bool> ds_layout;
std::transform(inputs.begin() + 2,
inputs.end() - 1,
......@@ -333,10 +333,10 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
const auto& a_shape = inputs[0];
const auto& b_shape = inputs[1];
const auto& c_shape = inputs.back();
auto tuning_value = v.get("tuning_value", 4);
const auto& a_shape = inputs[0];
const auto& b_shape = inputs[1];
const auto& c_shape = inputs.back();
auto tuning_value = v.get("tuning_value", 4);
if(not v.contains("tuning_value"))
tuning_value = get_tuning_for({a_shape, b_shape, c_shape});
auto batch_count = get_batch_count(c_shape);
......@@ -344,7 +344,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
const auto include_header = problem.GetIncludeHeader();
const auto solutions = problem.GetSolutions(ctx.get_current_device().get_gfx_name());
const auto& solution = solutions.at(tuning_value);
const auto& solution = solutions.at(tuning_value);
const auto template_str = solution.template_str;
const auto blocks_per_batch = solution.grid_size;
const auto block_size = solution.block_size;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment