"...git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "de86193cab071866cfbd715e26bbdd74784d5056"
Commit 74f21ca6 authored by Paul's avatar Paul
Browse files

Refactor

parent 674b3bac
...@@ -112,11 +112,13 @@ struct compile_plan ...@@ -112,11 +112,13 @@ struct compile_plan
{ {
if(results.size() == 1) if(results.size() == 1)
return results.front(); return results.front();
std::cout << "Benchmarking " << preop.name() << ": " << results.size() << " configs" << std::endl; std::cout << "Benchmarking " << preop.name() << ": " << results.size() << " configs"
<< std::endl;
std::vector<double> times; std::vector<double> times;
for(const auto& cr:results) for(const auto& cr : results)
{ {
times.push_back(time_op(*ctx, cr.replace.code_object, to_shapes(cr.ins->inputs()), 20).first); times.push_back(
time_op(*ctx, cr.replace.code_object, to_shapes(cr.ins->inputs()), 20).first);
} }
auto i = std::distance(times.begin(), std::min_element(times.begin(), times.end())); auto i = std::distance(times.begin(), std::min_element(times.begin(), times.end()));
return results[i]; return results[i];
......
...@@ -267,20 +267,28 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -267,20 +267,28 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
std::vector<std::string> names() const { return {"ck_gemm", "gpu::ck_gemm"}; } std::vector<std::string> names() const { return {"ck_gemm", "gpu::ck_gemm"}; }
operation compile_op(context& /* ctx */, const std::vector<shape>& inputs, const value& v) const bool can_fold_batch(const std::vector<shape>& inputs) const
{
auto a_shape = inputs[0];
auto b_shape = inputs[1];
auto rank = a_shape.lens().size();
auto b_strides = b_shape.strides();
return rank >= 3 and b_strides[rank - 3] == 0;
}
ck::host::device_gemm_multiple_d::Problem create_problem(const std::vector<shape>& inputs, const value& v) const
{ {
auto a_shape = inputs[0]; auto a_shape = inputs[0];
auto b_shape = inputs[1]; auto b_shape = inputs[1];
auto c_shape = inputs.back(); auto c_shape = inputs.back();
auto tuning_value = get_tuning_for({a_shape, b_shape, c_shape});
auto rank = a_shape.lens().size(); auto rank = a_shape.lens().size();
auto b_strides = b_shape.strides(); auto b_strides = b_shape.strides();
bool can_fold_batch = rank >= 3 and b_strides[rank - 3] == 0;
auto batch_count = get_batch_count(c_shape); auto batch_count = get_batch_count(c_shape);
auto m = c_shape.lens()[rank - 2]; auto m = c_shape.lens()[rank - 2];
m = can_fold_batch ? m * batch_count : m; m = can_fold_batch(inputs) ? m * batch_count : m;
auto n = c_shape.lens().back(); auto n = c_shape.lens().back();
auto k = a_shape.lens().back(); auto k = a_shape.lens().back();
...@@ -309,7 +317,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -309,7 +317,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
cde_op = v.at("post").to<std::string>(); cde_op = v.at("post").to<std::string>();
} }
auto problem = ck::host::device_gemm_multiple_d::Problem{m, return ck::host::device_gemm_multiple_d::Problem{m,
n, n,
k, k,
transA, transA,
...@@ -323,6 +331,16 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -323,6 +331,16 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
ck_passthrough, ck_passthrough,
ck_passthrough, ck_passthrough,
cde_op}; cde_op};
}
operation compile_op(context& /* ctx */, const std::vector<shape>& inputs, const value& v) const
{
auto a_shape = inputs[0];
auto b_shape = inputs[1];
auto c_shape = inputs.back();
auto tuning_value = get_tuning_for({a_shape, b_shape, c_shape});
auto batch_count = get_batch_count(c_shape);
auto problem = create_problem(inputs, v);
const auto include_header = problem.GetIncludeHeader(); const auto include_header = problem.GetIncludeHeader();
const auto solutions = problem.GetSolutions("gfx90a"); const auto solutions = problem.GetSolutions("gfx90a");
...@@ -333,13 +351,13 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -333,13 +351,13 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
hip_compile_options options; hip_compile_options options;
options.additional_src_files = ck_headers(); options.additional_src_files = ck_headers();
auto grid_size = can_fold_batch ? blocks_per_batch : batch_count * blocks_per_batch; auto grid_size = can_fold_batch(inputs) ? blocks_per_batch : batch_count * blocks_per_batch;
options.set_launch_params(v, grid_size * block_size, block_size); options.set_launch_params(v, grid_size * block_size, block_size);
options.inputs = inputs; options.inputs = inputs;
options.output = c_shape; options.output = c_shape;
options.kernel_name = v.get("kernel", "ck_gemm_kernel"); options.kernel_name = v.get("kernel", "ck_gemm_kernel");
options.virtual_inputs = inputs; options.virtual_inputs = inputs;
if(can_fold_batch) if(can_fold_batch(inputs))
{ {
auto vinputs = inputs; auto vinputs = inputs;
fold_batch_dims(vinputs[0]); fold_batch_dims(vinputs[0]);
...@@ -363,7 +381,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -363,7 +381,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
return compile_hip_code_object(src, options); return compile_hip_code_object(src, options);
} }
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const value create_settings(instruction_ref ins, const operation& op) const
{ {
auto v = op.to_value(); auto v = op.to_value();
v["kernel"] = "ck_gemm_kernel"; v["kernel"] = "ck_gemm_kernel";
...@@ -375,9 +393,13 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -375,9 +393,13 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
v["post"] = "ck_function_adaptor<post_ck_gemm>"; v["post"] = "ck_function_adaptor<post_ck_gemm>";
v["kernel"] = "ck_gemm_" + generate_name_from_ops(*pm) + "_kernel"; v["kernel"] = "ck_gemm_" + generate_name_from_ops(*pm) + "_kernel";
} }
return v;
}
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
auto shapes = to_shapes(ins->inputs()); auto shapes = to_shapes(ins->inputs());
return {compile_op(ctx, shapes, v), return {compile_op(ctx, shapes, create_settings(ins, op)),
[=](module& m, instruction_ref ins2, const operation& code_object) { [=](module& m, instruction_ref ins2, const operation& code_object) {
if(enabled(MIGRAPHX_LOG_CK_GEMM{})) if(enabled(MIGRAPHX_LOG_CK_GEMM{}))
{ {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment