Commit 74f21ca6 authored by Paul
Browse files

Refactor

parent 674b3bac
......@@ -112,11 +112,13 @@ struct compile_plan
{
if(results.size() == 1)
return results.front();
std::cout << "Benchmarking " << preop.name() << ": " << results.size() << " configs" << std::endl;
std::cout << "Benchmarking " << preop.name() << ": " << results.size() << " configs"
<< std::endl;
std::vector<double> times;
for(const auto& cr:results)
for(const auto& cr : results)
{
times.push_back(time_op(*ctx, cr.replace.code_object, to_shapes(cr.ins->inputs()), 20).first);
times.push_back(
time_op(*ctx, cr.replace.code_object, to_shapes(cr.ins->inputs()), 20).first);
}
auto i = std::distance(times.begin(), std::min_element(times.begin(), times.end()));
return results[i];
......
......@@ -267,20 +267,28 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
std::vector<std::string> names() const { return {"ck_gemm", "gpu::ck_gemm"}; }
operation compile_op(context& /* ctx */, const std::vector<shape>& inputs, const value& v) const
// Decides whether the batch dimensions of this GEMM can be folded away.
//
// inputs[0] is treated as the A operand and inputs[1] as the B operand.
// Returns true only when A has rank >= 3 and B's stride at index (rank - 3)
// is zero; returns false otherwise.
// NOTE(review): a zero stride presumably means B is broadcast along that
// batch dimension - confirm against how the shapes are constructed.
bool can_fold_batch(const std::vector<shape>& inputs) const
{
    // Only A's rank and B's strides are read, so bind by reference instead of
    // copying whole shape objects and the strides container.
    const auto rank       = inputs[0].lens().size();
    const auto& b_strides = inputs[1].strides();
    if(rank < 3)
        return false;
    // The index is derived from A's rank; refuse to fold rather than read past
    // the end of B's strides if B has fewer dimensions than expected.
    if(rank - 3 >= b_strides.size())
        return false;
    return b_strides[rank - 3] == 0;
}
ck::host::device_gemm_multiple_d::Problem create_problem(const std::vector<shape>& inputs, const value& v) const
{
auto a_shape = inputs[0];
auto b_shape = inputs[1];
auto c_shape = inputs.back();
auto tuning_value = get_tuning_for({a_shape, b_shape, c_shape});
auto rank = a_shape.lens().size();
auto b_strides = b_shape.strides();
bool can_fold_batch = rank >= 3 and b_strides[rank - 3] == 0;
auto batch_count = get_batch_count(c_shape);
auto m = c_shape.lens()[rank - 2];
m = can_fold_batch ? m * batch_count : m;
m = can_fold_batch(inputs) ? m * batch_count : m;
auto n = c_shape.lens().back();
auto k = a_shape.lens().back();
......@@ -309,7 +317,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
cde_op = v.at("post").to<std::string>();
}
auto problem = ck::host::device_gemm_multiple_d::Problem{m,
return ck::host::device_gemm_multiple_d::Problem{m,
n,
k,
transA,
......@@ -323,6 +331,16 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
ck_passthrough,
ck_passthrough,
cde_op};
}
operation compile_op(context& /* ctx */, const std::vector<shape>& inputs, const value& v) const
{
auto a_shape = inputs[0];
auto b_shape = inputs[1];
auto c_shape = inputs.back();
auto tuning_value = get_tuning_for({a_shape, b_shape, c_shape});
auto batch_count = get_batch_count(c_shape);
auto problem = create_problem(inputs, v);
const auto include_header = problem.GetIncludeHeader();
const auto solutions = problem.GetSolutions("gfx90a");
......@@ -333,13 +351,13 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
hip_compile_options options;
options.additional_src_files = ck_headers();
auto grid_size = can_fold_batch ? blocks_per_batch : batch_count * blocks_per_batch;
auto grid_size = can_fold_batch(inputs) ? blocks_per_batch : batch_count * blocks_per_batch;
options.set_launch_params(v, grid_size * block_size, block_size);
options.inputs = inputs;
options.output = c_shape;
options.kernel_name = v.get("kernel", "ck_gemm_kernel");
options.virtual_inputs = inputs;
if(can_fold_batch)
if(can_fold_batch(inputs))
{
auto vinputs = inputs;
fold_batch_dims(vinputs[0]);
......@@ -363,7 +381,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
return compile_hip_code_object(src, options);
}
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
value create_settings(instruction_ref ins, const operation& op) const
{
auto v = op.to_value();
v["kernel"] = "ck_gemm_kernel";
......@@ -375,9 +393,13 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
v["post"] = "ck_function_adaptor<post_ck_gemm>";
v["kernel"] = "ck_gemm_" + generate_name_from_ops(*pm) + "_kernel";
}
return v;
}
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
auto shapes = to_shapes(ins->inputs());
return {compile_op(ctx, shapes, v),
return {compile_op(ctx, shapes, create_settings(ins, op)),
[=](module& m, instruction_ref ins2, const operation& code_object) {
if(enabled(MIGRAPHX_LOG_CK_GEMM{}))
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment