"src/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "742e6a4be6b53f74dced599b943082e18330afb7"
Commit c393f233 authored by Alan Turner's avatar Alan Turner
Browse files

Formatting

parent 37939805
......@@ -218,7 +218,7 @@ static bool is_mul_module(const module& m)
}
else if(ins.name() == "mul")
{
return true;
return true;
}
}
return false;
......@@ -230,7 +230,7 @@ struct find_ck_gemm_softmax_gemm
{
auto gemm1 =
match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1")));
auto mul = match::name("pointwise")(match::any_of[match::inputs()](gemm1)).bind("scale");
auto mul = match::name("pointwise")(match::any_of[match::inputs()](gemm1)).bind("scale");
auto softmax = match::name("softmax")(match::any_of[match::inputs()](mul)).bind("softmax");
return match::name("dot")(is_ck_gemm().bind("gemm2"))(
match::any_of[match::inputs()](softmax));
......@@ -243,21 +243,23 @@ struct find_ck_gemm_softmax_gemm
auto gemm1_ins = r.instructions["gemm1"];
auto scale_ins = r.instructions["scale"];
if (scale_ins->module_inputs().size() != 1 or not is_mul_module(*scale_ins->module_inputs().front()))
if(scale_ins->module_inputs().size() != 1 or
not is_mul_module(*scale_ins->module_inputs().front()))
return;
if (not ck_gemm_softmax_gemm::is_ck_supported_type(gemm1_ins->get_shape().type()))
if(not ck_gemm_softmax_gemm::is_ck_supported_type(gemm1_ins->get_shape().type()))
return;
double scale = 1.0;
for (auto& in: scale_ins->inputs())
for(auto& in : scale_ins->inputs())
{
if (in->can_eval())
if(in->can_eval())
{
in->get_literal().visit([&](const auto s) {
if (std::all_of(
s.begin() + 1, s.end(), [&](auto v) { return float_equal(v, s.front()); }))
if(std::all_of(s.begin() + 1, s.end(), [&](auto v) {
return float_equal(v, s.front());
}))
scale = s.front();
else
else
return;
});
}
......
......@@ -352,7 +352,7 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
const auto& a_shape = inputs[0];
const auto& b_shape = inputs[1];
const auto& b1_shape = inputs[2];
const auto& c_shape = inputs.back();
const auto& c_shape = inputs.back();
auto tuning_value = v.get("tuning_value", 4);
if(not v.contains("tuning_value"))
tuning_value = get_tuning_for({a_shape, b_shape, b1_shape, c_shape});
......@@ -399,7 +399,7 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
{"blocks_per_batch", to_string(blocks_per_batch)},
{"preamble", v.get("preamble", std::string{})},
{"kernel", options.kernel_name}});
return compile_hip_code_object(src, options);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment