Commit 9c6ba1ed authored by Alan Turner's avatar Alan Turner
Browse files

Add disable envvar

parent 8edc10eb
......@@ -49,44 +49,6 @@ struct ck_gemm
};
MIGRAPHX_REGISTER_OP(ck_gemm);
// struct ck_gemm_scale_bias_softmax_gemm
// {
// operation op = make_op("dot");
// template <class Self, class F>
// static auto reflect(Self& self, F f)
// {
// return pack(f(self.op, "op"));
// }
// std::string name() const { return "gpu::ck_gemm_softmax_gemm"; }
// void check_gemm_shape(const shape& s) const
// {
// if(not contains(range(s.strides().rbegin(), s.strides().rbegin() + 3), 1))
// MIGRAPHX_THROW("Invalid shape for ck_gemm_scale_bias_softmax_gemm");
// }
// shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
// {
// check_shapes{inputs, *this}.same_ndims();
// // if(mods.size() != 1)
// // MIGRAPHX_THROW("should have one submodule.");
// if(inputs.size() < 2)
// MIGRAPHX_THROW("should have at least two inputs.");
// auto a = inputs[0];
// auto b = inputs[1];
// auto b1 = inputs[2];
// for(const auto& input : inputs)
// {
// // std::cout << input << std::endl;
// check_gemm_shape(input);
// }
// return op.compute_shape({op.compute_shape({a, b}), b1});
// }
// };
// MIGRAPHX_REGISTER_OP(ck_gemm_scale_bias_softmax_gemm);
namespace {
MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
......@@ -154,77 +116,10 @@ struct find_ck_gemm
}
};
struct find_ck_gemm_scale_bias_softmax_gemm
{
// auto matcher() const
// {
// auto gemm1 =
// match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1")));
// auto pw =
// match::name("pointwise")(match::any_of[match::inputs()](gemm1)).bind("scale_bias");
// auto softmax =
// match::name("softmax")(match::any_of[match::inputs()](pw)).bind("softmax"); return
// match::name("dot")(is_ck_gemm().bind("gemm2"))(
// match::any_of[match::inputs()](softmax));
// }
// void apply(module_pass_manager& mpm, const match::matcher_result& r) const
// {
// std::cout << "Matched" << std::endl;
// auto ins = r.result;
// auto gemm2_ins = r.instructions["gemm2"];
// auto sm_ins = r.instructions["softmax"];
// auto pw_ins = r.instructions["scale_bias"];
// auto gemm1_ins = r.instructions["gemm1"];
// gemm2_ins->debug_print();
// sm_ins->debug_print();
// pw_ins->debug_print();
// gemm1_ins->debug_print();
// auto inputs = gemm1_ins->inputs(); // A, B
// inputs.push_back(gemm2_ins->inputs().back()); // B1
// // inputs.push_back(pw_ins->inputs().back()); // C
// mpm.get_module().replace_instruction(
// ins, ck_gemm_scale_bias_softmax_gemm{gemm2_ins->get_operator()}, inputs);
// }
// auto matcher() const
// {
// auto gemm1 =
// match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1")));
// auto softmax =
// match::name("softmax")(match::any_of[match::inputs()](gemm1)).bind("softmax"); return
// match::name("dot")(is_ck_gemm().bind("gemm2"))(match::any_of[match::inputs()](softmax));
// }
// void apply(module_pass_manager& mpm, const match::matcher_result& r) const
// {
// std::cout << "Matched" << std::endl;
// auto ins = r.result;
// auto gemm2_ins = r.instructions["gemm2"];
// auto sm_ins = r.instructions["softmax"];
// auto gemm1_ins = r.instructions["gemm1"];
// gemm2_ins->debug_print();
// sm_ins->debug_print();
// gemm1_ins->debug_print();
// auto inputs = gemm1_ins->inputs(); // A, B
// inputs.push_back(gemm2_ins->inputs().back()); // B1
// mpm.get_module().replace_instruction(ins,
// ck_gemm_scale_bias_softmax_gemm{gemm2_ins->get_operator()}, inputs);
// }
};
} // namespace
void fuse_ck::apply(module_pass_manager& mpm) const
{
// mpm.get_module().debug_print();
// match::find_matches(mpm, find_ck_gemm_scale_bias_softmax_gemm{});
if(not enabled(MIGRAPHX_DISABLE_CK_GEMM_FUSION{}))
match::find_matches(mpm, find_ck_gemm_pointwise{});
if(not enabled(MIGRAPHX_DISABLE_CK_GEMM{}))
......
......@@ -8,6 +8,8 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_CK_GEMM_SOFTMAX_GEMM);
struct module;
namespace gpu {
......@@ -91,6 +93,7 @@ struct find_gemm_softmax_gemm_gemm
void fuse_ck_gemm_softmax_gemm::apply(module_pass_manager& mpm) const
{
if(not enabled(MIGRAPHX_DISABLE_CK_GEMM_SOFTMAX_GEMM{}))
match::find_matches(mpm, find_gemm_softmax_gemm_gemm{});
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment