Commit b621b28f authored by Paul
Browse files

Simplify fuse_ck

parent fee874e3
...@@ -7,8 +7,7 @@ ...@@ -7,8 +7,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_CK_GEMM); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK_GEMM);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_CK_GEMM_FUSION);
struct module; struct module;
...@@ -51,43 +50,6 @@ struct ck_gemm ...@@ -51,43 +50,6 @@ struct ck_gemm
}; };
MIGRAPHX_REGISTER_OP(ck_gemm); MIGRAPHX_REGISTER_OP(ck_gemm);
// Operator registered under the name "gpu::ck_gemm_int8". It wraps another
// operation (a "quant_dot" by default) and derives its output shape from it.
struct ck_gemm_int8
{
    // Wrapped operation; compute_shape() delegates the base shape computation to it.
    operation op = make_op("quant_dot");

    // Exposes the wrapped operation under the field name "op".
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.op, "op"));
    }

    std::string name() const { return "gpu::ck_gemm_int8"; }

    // Throws unless one of the innermost strides (at most the last three) equals 1.
    void check_gemm_shape(const shape& s) const
    {
        const auto& strides = s.strides();
        // Clamp the inspected window to the strides that actually exist: advancing
        // rbegin() by 3 on a shape with fewer than three dimensions would step past
        // rend() and read out of bounds.
        const auto last = strides.size() < 3 ? strides.rend() : strides.rbegin() + 3;
        if(not contains(range(strides.rbegin(), last), 1))
            MIGRAPHX_THROW("Invalid shape for ck_gemm");
    }

    // Computes the output shape.
    //   inputs: at least two shapes; the first two are passed to the wrapped op,
    //           and every input must pass check_gemm_shape().
    //   mods:   optional submodules. When empty the result uses int8_type;
    //           otherwise it uses the element type of the first output shape of
    //           the first submodule.
    // Throws when fewer than two inputs are given or an input fails the stride check.
    shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
    {
        check_shapes{inputs, *this}.same_ndims();
        // The number of submodules is intentionally not enforced here:
        // if(mods.size() != 1)
        //     MIGRAPHX_THROW("should have one submodule.");
        if(inputs.size() < 2)
            MIGRAPHX_THROW("should have at least two inputs.");
        auto a = inputs[0];
        auto b = inputs[1];
        for(const auto& input : inputs)
            check_gemm_shape(input);
        auto r = op.compute_shape({a, b});
        if(mods.empty())
            return r.with_type(migraphx::shape::int8_type);
        return r.with_type(mods.front()->get_output_shapes().front().type());
    }
};
MIGRAPHX_REGISTER_OP(ck_gemm_int8);
namespace { namespace {
MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins) MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
...@@ -107,7 +69,7 @@ struct find_ck_gemm_pointwise ...@@ -107,7 +69,7 @@ struct find_ck_gemm_pointwise
auto matcher() const auto matcher() const
{ {
auto gemm = auto gemm =
match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm"))); match::skip(match::name("contiguous"))(match::name("dot", "quant_dot")(is_ck_gemm().bind("gemm")));
return match::name("pointwise")(match::any_of[match::inputs()](gemm.bind("x"))); return match::name("pointwise")(match::any_of[match::inputs()](gemm.bind("x")));
} }
...@@ -123,7 +85,7 @@ struct find_ck_gemm_pointwise ...@@ -123,7 +85,7 @@ struct find_ck_gemm_pointwise
auto gemm_it = std::find(inputs.begin(), inputs.end(), x_ins); auto gemm_it = std::find(inputs.begin(), inputs.end(), x_ins);
auto gemm_idx = gemm_it - inputs.begin(); auto gemm_idx = gemm_it - inputs.begin();
assert(gemm_it != inputs.end()); assert(gemm_it != inputs.end());
if(ins->get_shape().type() != shape::half_type) if(not contains({shape::half_type, shape::int8_type, shape::int32_type}, ins->get_shape().type()))
return; return;
if(gemm_idx != 0) if(gemm_idx != 0)
{ {
...@@ -140,49 +102,7 @@ struct find_ck_gemm_pointwise ...@@ -140,49 +102,7 @@ struct find_ck_gemm_pointwise
inputs.erase(gemm_it); inputs.erase(gemm_it);
inputs.insert(inputs.begin(), gemm_ins->inputs().begin(), gemm_ins->inputs().end()); inputs.insert(inputs.begin(), gemm_ins->inputs().begin(), gemm_ins->inputs().end());
mpm.get_module().replace_instruction(ins, ck_gemm{}, inputs, {pm}); mpm.get_module().replace_instruction(ins, ck_gemm{gemm_ins->get_operator()}, inputs, {pm});
}
};
// Rewrite rule: a "pointwise" instruction that consumes a "quant_dot" is replaced
// by a single ck_gemm_int8 instruction carrying the pointwise submodule.
struct find_ck_gemm_pointwise_int8
{
// Find a gemm followed by a pointwise operation.
// The "quant_dot" must satisfy is_ck_gemm and is bound as "gemm"; "contiguous"
// instructions in between are skipped. The matching pointwise input is bound as "x".
auto matcher() const
{
auto gemm = match::skip(match::name("contiguous"))(
match::name("quant_dot")(is_ck_gemm().bind("gemm")));
return match::name("pointwise")(match::any_of[match::inputs()](gemm.bind("x")));
}
// Replaces the matched pointwise instruction with ck_gemm_int8, moving the gemm's
// inputs to the front of the input list and reusing the pointwise submodule.
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto ins = r.result;
auto gemm_ins = r.instructions["gemm"];
auto x_ins = r.instructions["x"]; // input after contiguous
// NOTE(review): next_ins is computed but never used in this function.
auto next_ins = std::next(ins);
// Submodule of the pointwise instruction and its parameter names, sorted.
auto* pm = ins->module_inputs().front();
auto names = pm->get_parameter_names();
std::sort(names.begin(), names.end());
// Position of the gemm result among the pointwise inputs.
auto inputs = ins->inputs();
auto gemm_it = std::find(inputs.begin(), inputs.end(), x_ins);
auto gemm_idx = gemm_it - inputs.begin();
assert(gemm_it != inputs.end());
// When the gemm result is not the first input, swap the roles of the first
// parameter and the gemm parameter inside the submodule.
// NOTE(review): presumably the i-th sorted parameter name corresponds to the
// i-th instruction input - confirm against how pointwise modules are built.
if(gemm_idx != 0)
{
auto first_param = pm->get_parameter(names[0]);
auto gemm_param = pm->get_parameter(names[gemm_idx]);
// New parameters take each other's name (plus a "_0" suffix) and shape.
auto new_gemm_param = pm->add_parameter(names[0] + "_0", gemm_param->get_shape());
auto new_first_param =
pm->add_parameter(names[gemm_idx] + "_0", first_param->get_shape());
pm->replace_instruction(gemm_param, new_gemm_param);
pm->replace_instruction(first_param, new_first_param);
pm->remove_instruction(first_param);
pm->remove_instruction(gemm_param);
}
// Drop the gemm result from the input list and put the gemm's own inputs first.
inputs.erase(gemm_it);
inputs.insert(inputs.begin(), gemm_ins->inputs().begin(), gemm_ins->inputs().end());
mpm.get_module().replace_instruction(ins, ck_gemm_int8{}, inputs, {pm});
} }
}; };
...@@ -197,30 +117,14 @@ struct find_ck_gemm ...@@ -197,30 +117,14 @@ struct find_ck_gemm
} }
}; };
// Rewrite rule: turns a standalone "quant_dot" accepted by is_ck_gemm into a
// ck_gemm_int8 instruction that wraps the original operator.
struct find_ck_gemm_int8
{
    // Matches "quant_dot" instructions that satisfy is_ck_gemm, bound as "gemm".
    auto matcher() const
    {
        return match::name("quant_dot")(is_ck_gemm().bind("gemm"));
    }

    // Swaps the matched instruction for ck_gemm_int8, keeping its inputs unchanged.
    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
        auto quant_dot_ins = r.result;
        mpm.get_module().replace_instruction(
            quant_dot_ins, ck_gemm_int8{quant_dot_ins->get_operator()}, quant_dot_ins->inputs());
    }
};
} // namespace } // namespace
void fuse_ck::apply(module_pass_manager& mpm) const void fuse_ck::apply(module_pass_manager& mpm) const
{ {
if(not enabled(MIGRAPHX_DISABLE_CK_GEMM_FUSION{})) if (enabled(MIGRAPHX_ENABLE_CK_GEMM{}))
{ {
match::find_matches(mpm, find_ck_gemm_pointwise{}); match::find_matches(mpm, find_ck_gemm_pointwise{});
match::find_matches(mpm, find_ck_gemm_pointwise_int8{});
}
if(not enabled(MIGRAPHX_DISABLE_CK_GEMM{}))
{
match::find_matches(mpm, find_ck_gemm{}); match::find_matches(mpm, find_ck_gemm{});
match::find_matches(mpm, find_ck_gemm_int8{});
} }
} }
......
...@@ -222,7 +222,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -222,7 +222,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
std::vector<std::string> names() const std::vector<std::string> names() const
{ {
return {"ck_gemm", "gpu::ck_gemm", "ck_gemm_int8", "gpu::ck_gemm_int8"}; return {"ck_gemm", "gpu::ck_gemm"};
} }
operation compile_op(context& /* ctx */, const std::vector<shape>& inputs, const value& v) const operation compile_op(context& /* ctx */, const std::vector<shape>& inputs, const value& v) const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment