"src/vscode:/vscode.git/clone" did not exist on "041c3f0e2d1f0020f8c30e422094a07bb88b9b2d"
Commit 36154263 authored by Paul's avatar Paul
Browse files

Format

parent af0d45a5
...@@ -149,8 +149,8 @@ struct find_ck_gemm_pointwise_int8 ...@@ -149,8 +149,8 @@ struct find_ck_gemm_pointwise_int8
// Find a gemm followed by a pointwise operation. // Find a gemm followed by a pointwise operation.
auto matcher() const auto matcher() const
{ {
auto gemm = auto gemm = match::skip(match::name("contiguous"))(
match::skip(match::name("contiguous"))(match::name("quant_dot")(is_ck_gemm().bind("gemm"))); match::name("quant_dot")(is_ck_gemm().bind("gemm")));
return match::name("pointwise")(match::any_of[match::inputs()](gemm.bind("x"))); return match::name("pointwise")(match::any_of[match::inputs()](gemm.bind("x")));
} }
...@@ -162,7 +162,7 @@ struct find_ck_gemm_pointwise_int8 ...@@ -162,7 +162,7 @@ struct find_ck_gemm_pointwise_int8
auto next_ins = std::next(ins); auto next_ins = std::next(ins);
auto* pm = ins->module_inputs().front(); auto* pm = ins->module_inputs().front();
auto names = pm->get_parameter_names(); auto names = pm->get_parameter_names();
std::sort(names.begin(), names.end()); std::sort(names.begin(), names.end());
auto inputs = ins->inputs(); auto inputs = ins->inputs();
auto gemm_it = std::find(inputs.begin(), inputs.end(), x_ins); auto gemm_it = std::find(inputs.begin(), inputs.end(), x_ins);
......
...@@ -229,7 +229,10 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -229,7 +229,10 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
s = shape{s.type(), {m1, m2}}; s = shape{s.type(), {m1, m2}};
} }
std::vector<std::string> names() const { return {"ck_gemm", "gpu::ck_gemm", "ck_gemm_int8", "gpu::ck_gemm_int8"}; } std::vector<std::string> names() const
{
return {"ck_gemm", "gpu::ck_gemm", "ck_gemm_int8", "gpu::ck_gemm_int8"};
}
operation compile_op(context& /* ctx */, const std::vector<shape>& inputs, const value& v) const operation compile_op(context& /* ctx */, const std::vector<shape>& inputs, const value& v) const
{ {
...@@ -272,24 +275,21 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -272,24 +275,21 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
{ {
cde_op = v.at("post").to<std::string>(); cde_op = v.at("post").to<std::string>();
} }
auto problem = ck::host::device_gemm_multiple_d::Problem{m,
n,
auto problem = ck::host::device_gemm_multiple_d::Problem{ k,
m, transA,
n, transB,
k, transE,
transA, ds_layout,
transB, a_type,
transE, b_type,
ds_layout, e_type,
a_type, ds_type,
b_type, ck_passthrough,
e_type, ck_passthrough,
ds_type, cde_op};
ck_passthrough,
ck_passthrough,
cde_op};
const auto include_header = problem.GetIncludeHeader(); const auto include_header = problem.GetIncludeHeader();
const auto ck_headers = ck::host::GetHeaders(); const auto ck_headers = ck::host::GetHeaders();
...@@ -345,14 +345,17 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -345,14 +345,17 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
} }
auto shapes = to_shapes(ins->inputs()); auto shapes = to_shapes(ins->inputs());
return {compile_op(ctx, shapes, v), [=](module& m, instruction_ref ins2, const operation& code_object) { return {compile_op(ctx, shapes, v),
if(enabled(MIGRAPHX_LOG_CK_GEMM{})) [=](module& m, instruction_ref ins2, const operation& code_object) {
{ if(enabled(MIGRAPHX_LOG_CK_GEMM{}))
std::vector<shape> gemm_shapes{shapes[0], shapes[1], shapes.back().with_type(shapes[0].type())}; {
std::cout << "ck_gemm: " << to_json_string(to_value(gemm_shapes)) << std::endl; std::vector<shape> gemm_shapes{
} shapes[0], shapes[1], shapes.back().with_type(shapes[0].type())};
m.replace_instruction(ins2, code_object, ins2->inputs()); std::cout << "ck_gemm: " << to_json_string(to_value(gemm_shapes))
}}; << std::endl;
}
m.replace_instruction(ins2, code_object, ins2->inputs());
}};
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment