Commit 00439b95 authored by Paul
Browse files

Format

parent 3800d2b7
...@@ -163,7 +163,7 @@ struct compile_plan ...@@ -163,7 +163,7 @@ struct compile_plan
<< std::endl; << std::endl;
std::vector<double> times; std::vector<double> times;
times.reserve(results.size()); times.reserve(results.size());
for(const auto& cr : results) for(const auto& cr : results)
{ {
times.push_back( times.push_back(
time_op(*ctx, cr.replace.code_object, to_shapes(cr.ins->inputs()), 20).first); time_op(*ctx, cr.replace.code_object, to_shapes(cr.ins->inputs()), 20).first);
......
...@@ -268,10 +268,10 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -268,10 +268,10 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
// Decides whether the batch dimension may be folded, based on the first two
// entries of `inputs` (bound below as a_shape and b_shape).
// Returns true only when a_shape has at least three dimensions and b_shape's
// stride at index (rank - 3) is zero; otherwise returns false.
// NOTE(review): `rank` is taken from a_shape but is used to index b_shape's
// strides — this assumes b_shape has at least as many dimensions as a_shape;
// confirm against callers.
bool can_fold_batch(const std::vector<shape>& inputs) const bool can_fold_batch(const std::vector<shape>& inputs) const
{ {
const auto& a_shape = inputs[0]; const auto& a_shape = inputs[0];
const auto& b_shape = inputs[1]; const auto& b_shape = inputs[1];
auto rank = a_shape.lens().size(); auto rank = a_shape.lens().size();
auto b_strides = b_shape.strides(); auto b_strides = b_shape.strides();
// The `rank >= 3` test short-circuits, so `rank - 3` is only evaluated when it cannot underflow.
return rank >= 3 and b_strides[rank - 3] == 0; return rank >= 3 and b_strides[rank - 3] == 0;
} }
...@@ -282,8 +282,8 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -282,8 +282,8 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
const auto& b_shape = inputs[1]; const auto& b_shape = inputs[1];
const auto& c_shape = inputs.back(); const auto& c_shape = inputs.back();
auto rank = a_shape.lens().size(); auto rank = a_shape.lens().size();
auto batch_count = get_batch_count(c_shape); auto batch_count = get_batch_count(c_shape);
auto m = c_shape.lens()[rank - 2]; auto m = c_shape.lens()[rank - 2];
m = can_fold_batch(inputs) ? m * batch_count : m; m = can_fold_batch(inputs) ? m * batch_count : m;
...@@ -293,9 +293,9 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -293,9 +293,9 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
const bool trans_a = transposed_matrix(a_shape); const bool trans_a = transposed_matrix(a_shape);
const bool trans_b = transposed_matrix(b_shape); const bool trans_b = transposed_matrix(b_shape);
const bool trans_e = transposed_matrix(c_shape); const bool trans_e = transposed_matrix(c_shape);
const auto a_type = get_type(a_shape); const auto a_type = get_type(a_shape);
const auto b_type = get_type(b_shape); const auto b_type = get_type(b_shape);
const auto e_type = get_type(c_shape); const auto e_type = get_type(c_shape);
std::vector<bool> ds_layout; std::vector<bool> ds_layout;
std::transform(inputs.begin() + 2, std::transform(inputs.begin() + 2,
inputs.end() - 1, inputs.end() - 1,
...@@ -333,10 +333,10 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -333,10 +333,10 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{ {
const auto& a_shape = inputs[0]; const auto& a_shape = inputs[0];
const auto& b_shape = inputs[1]; const auto& b_shape = inputs[1];
const auto& c_shape = inputs.back(); const auto& c_shape = inputs.back();
auto tuning_value = v.get("tuning_value", 4); auto tuning_value = v.get("tuning_value", 4);
if(not v.contains("tuning_value")) if(not v.contains("tuning_value"))
tuning_value = get_tuning_for({a_shape, b_shape, c_shape}); tuning_value = get_tuning_for({a_shape, b_shape, c_shape});
auto batch_count = get_batch_count(c_shape); auto batch_count = get_batch_count(c_shape);
...@@ -344,7 +344,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -344,7 +344,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
const auto include_header = problem.GetIncludeHeader(); const auto include_header = problem.GetIncludeHeader();
const auto solutions = problem.GetSolutions(ctx.get_current_device().get_gfx_name()); const auto solutions = problem.GetSolutions(ctx.get_current_device().get_gfx_name());
const auto& solution = solutions.at(tuning_value); const auto& solution = solutions.at(tuning_value);
const auto template_str = solution.template_str; const auto template_str = solution.template_str;
const auto blocks_per_batch = solution.grid_size; const auto blocks_per_batch = solution.grid_size;
const auto block_size = solution.block_size; const auto block_size = solution.block_size;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment