Commit c393f233 authored by Alan Turner's avatar Alan Turner
Browse files

Formatting

parent 37939805
...@@ -218,7 +218,7 @@ static bool is_mul_module(const module& m) ...@@ -218,7 +218,7 @@ static bool is_mul_module(const module& m)
} }
else if(ins.name() == "mul") else if(ins.name() == "mul")
{ {
return true; return true;
} }
} }
return false; return false;
...@@ -230,7 +230,7 @@ struct find_ck_gemm_softmax_gemm ...@@ -230,7 +230,7 @@ struct find_ck_gemm_softmax_gemm
{ {
auto gemm1 = auto gemm1 =
match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1"))); match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1")));
auto mul = match::name("pointwise")(match::any_of[match::inputs()](gemm1)).bind("scale"); auto mul = match::name("pointwise")(match::any_of[match::inputs()](gemm1)).bind("scale");
auto softmax = match::name("softmax")(match::any_of[match::inputs()](mul)).bind("softmax"); auto softmax = match::name("softmax")(match::any_of[match::inputs()](mul)).bind("softmax");
return match::name("dot")(is_ck_gemm().bind("gemm2"))( return match::name("dot")(is_ck_gemm().bind("gemm2"))(
match::any_of[match::inputs()](softmax)); match::any_of[match::inputs()](softmax));
...@@ -243,21 +243,23 @@ struct find_ck_gemm_softmax_gemm ...@@ -243,21 +243,23 @@ struct find_ck_gemm_softmax_gemm
auto gemm1_ins = r.instructions["gemm1"]; auto gemm1_ins = r.instructions["gemm1"];
auto scale_ins = r.instructions["scale"]; auto scale_ins = r.instructions["scale"];
if (scale_ins->module_inputs().size() != 1 or not is_mul_module(*scale_ins->module_inputs().front())) if(scale_ins->module_inputs().size() != 1 or
not is_mul_module(*scale_ins->module_inputs().front()))
return; return;
if (not ck_gemm_softmax_gemm::is_ck_supported_type(gemm1_ins->get_shape().type())) if(not ck_gemm_softmax_gemm::is_ck_supported_type(gemm1_ins->get_shape().type()))
return; return;
double scale = 1.0; double scale = 1.0;
for (auto& in: scale_ins->inputs()) for(auto& in : scale_ins->inputs())
{ {
if (in->can_eval()) if(in->can_eval())
{ {
in->get_literal().visit([&](const auto s) { in->get_literal().visit([&](const auto s) {
if (std::all_of( if(std::all_of(s.begin() + 1, s.end(), [&](auto v) {
s.begin() + 1, s.end(), [&](auto v) { return float_equal(v, s.front()); })) return float_equal(v, s.front());
}))
scale = s.front(); scale = s.front();
else else
return; return;
}); });
} }
......
...@@ -352,7 +352,7 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler> ...@@ -352,7 +352,7 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
const auto& a_shape = inputs[0]; const auto& a_shape = inputs[0];
const auto& b_shape = inputs[1]; const auto& b_shape = inputs[1];
const auto& b1_shape = inputs[2]; const auto& b1_shape = inputs[2];
const auto& c_shape = inputs.back(); const auto& c_shape = inputs.back();
auto tuning_value = v.get("tuning_value", 4); auto tuning_value = v.get("tuning_value", 4);
if(not v.contains("tuning_value")) if(not v.contains("tuning_value"))
tuning_value = get_tuning_for({a_shape, b_shape, b1_shape, c_shape}); tuning_value = get_tuning_for({a_shape, b_shape, b1_shape, c_shape});
...@@ -399,7 +399,7 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler> ...@@ -399,7 +399,7 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
{"blocks_per_batch", to_string(blocks_per_batch)}, {"blocks_per_batch", to_string(blocks_per_batch)},
{"preamble", v.get("preamble", std::string{})}, {"preamble", v.get("preamble", std::string{})},
{"kernel", options.kernel_name}}); {"kernel", options.kernel_name}});
return compile_hip_code_object(src, options); return compile_hip_code_object(src, options);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment