Commit 3800d2b7 authored by Paul
Browse files

More tidy fixes

parent 8fa4cd1b
...@@ -162,7 +162,8 @@ struct compile_plan ...@@ -162,7 +162,8 @@ struct compile_plan
std::cout << "Benchmarking " << preop.name() << ": " << results.size() << " configs" std::cout << "Benchmarking " << preop.name() << ": " << results.size() << " configs"
<< std::endl; << std::endl;
std::vector<double> times; std::vector<double> times;
for(const auto& cr : results) times.reserve(results.size());
for(const auto& cr : results)
{ {
times.push_back( times.push_back(
time_op(*ctx, cr.replace.code_object, to_shapes(cr.ins->inputs()), 20).first); time_op(*ctx, cr.replace.code_object, to_shapes(cr.ins->inputs()), 20).first);
......
...@@ -58,9 +58,7 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins) ...@@ -58,9 +58,7 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
return false; return false;
auto a = ins->inputs().front()->get_shape(); auto a = ins->inputs().front()->get_shape();
auto b = ins->inputs().back()->get_shape(); auto b = ins->inputs().back()->get_shape();
if(a.lens().back() > 2048) return a.lens().back() <= 2048;
return false;
return true;
} }
struct find_ck_gemm_pointwise struct find_ck_gemm_pointwise
......
...@@ -27,16 +27,15 @@ ...@@ -27,16 +27,15 @@
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp> #include <migraphx/env.hpp>
#include <migraphx/gpu/compile_hip.hpp> #include <migraphx/file_buffer.hpp>
#include <migraphx/gpu/compile_gen.hpp> #include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/module.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/env.hpp>
#include <migraphx/reduce_dims.hpp> #include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/module.hpp>
#include <migraphx/env.hpp>
#include <migraphx/file_buffer.hpp>
#include "ck/host/device_gemm_multiple_d.hpp" #include "ck/host/device_gemm_multiple_d.hpp"
...@@ -269,8 +268,8 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -269,8 +268,8 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
bool can_fold_batch(const std::vector<shape>& inputs) const bool can_fold_batch(const std::vector<shape>& inputs) const
{ {
auto a_shape = inputs[0]; const auto& a_shape = inputs[0];
auto b_shape = inputs[1]; const auto& b_shape = inputs[1];
auto rank = a_shape.lens().size(); auto rank = a_shape.lens().size();
auto b_strides = b_shape.strides(); auto b_strides = b_shape.strides();
return rank >= 3 and b_strides[rank - 3] == 0; return rank >= 3 and b_strides[rank - 3] == 0;
...@@ -279,22 +278,21 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -279,22 +278,21 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
ck::host::device_gemm_multiple_d::Problem create_problem(const std::vector<shape>& inputs, ck::host::device_gemm_multiple_d::Problem create_problem(const std::vector<shape>& inputs,
const value& v) const const value& v) const
{ {
auto a_shape = inputs[0]; const auto& a_shape = inputs[0];
auto b_shape = inputs[1]; const auto& b_shape = inputs[1];
auto c_shape = inputs.back(); const auto& c_shape = inputs.back();
auto rank = a_shape.lens().size(); auto rank = a_shape.lens().size();
auto b_strides = b_shape.strides();
auto batch_count = get_batch_count(c_shape); auto batch_count = get_batch_count(c_shape);
auto m = c_shape.lens()[rank - 2]; auto m = c_shape.lens()[rank - 2];
m = can_fold_batch(inputs) ? m * batch_count : m; m = can_fold_batch(inputs) ? m * batch_count : m;
auto n = c_shape.lens().back(); auto n = c_shape.lens().back();
auto k = a_shape.lens().back(); auto k = a_shape.lens().back();
const bool transA = transposed_matrix(a_shape); const bool trans_a = transposed_matrix(a_shape);
const bool transB = transposed_matrix(b_shape); const bool trans_b = transposed_matrix(b_shape);
const bool transE = transposed_matrix(c_shape); const bool trans_e = transposed_matrix(c_shape);
const auto a_type = get_type(a_shape); const auto a_type = get_type(a_shape);
const auto b_type = get_type(b_shape); const auto b_type = get_type(b_shape);
const auto e_type = get_type(c_shape); const auto e_type = get_type(c_shape);
...@@ -320,9 +318,9 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -320,9 +318,9 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
return ck::host::device_gemm_multiple_d::Problem{m, return ck::host::device_gemm_multiple_d::Problem{m,
n, n,
k, k,
transA, trans_a,
transB, trans_b,
transE, trans_e,
ds_layout, ds_layout,
a_type, a_type,
b_type, b_type,
...@@ -335,9 +333,9 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -335,9 +333,9 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{ {
auto a_shape = inputs[0]; const auto& a_shape = inputs[0];
auto b_shape = inputs[1]; const auto& b_shape = inputs[1];
auto c_shape = inputs.back(); const auto& c_shape = inputs.back();
auto tuning_value = v.get("tuning_value", 4); auto tuning_value = v.get("tuning_value", 4);
if(not v.contains("tuning_value")) if(not v.contains("tuning_value"))
tuning_value = get_tuning_for({a_shape, b_shape, c_shape}); tuning_value = get_tuning_for({a_shape, b_shape, c_shape});
...@@ -346,7 +344,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -346,7 +344,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
const auto include_header = problem.GetIncludeHeader(); const auto include_header = problem.GetIncludeHeader();
const auto solutions = problem.GetSolutions(ctx.get_current_device().get_gfx_name()); const auto solutions = problem.GetSolutions(ctx.get_current_device().get_gfx_name());
const auto solution = solutions.at(tuning_value); const auto& solution = solutions.at(tuning_value);
const auto template_str = solution.template_str; const auto template_str = solution.template_str;
const auto blocks_per_batch = solution.grid_size; const auto blocks_per_batch = solution.grid_size;
const auto block_size = solution.block_size; const auto block_size = solution.block_size;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment