Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
9c6ba1ed
Commit
9c6ba1ed
authored
Nov 18, 2022
by
Alan Turner
Browse files
Add disable envvar
parent
8edc10eb
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
4 additions
and
106 deletions
+4
-106
src/targets/gpu/fuse_ck.cpp
src/targets/gpu/fuse_ck.cpp
+0
-105
src/targets/gpu/fuse_ck_gemm_softmax_gemm.cpp
src/targets/gpu/fuse_ck_gemm_softmax_gemm.cpp
+4
-1
No files found.
src/targets/gpu/fuse_ck.cpp
View file @
9c6ba1ed
...
@@ -49,44 +49,6 @@ struct ck_gemm
...
@@ -49,44 +49,6 @@ struct ck_gemm
};
};
MIGRAPHX_REGISTER_OP
(
ck_gemm
);
MIGRAPHX_REGISTER_OP
(
ck_gemm
);
// struct ck_gemm_scale_bias_softmax_gemm
// {
// operation op = make_op("dot");
// template <class Self, class F>
// static auto reflect(Self& self, F f)
// {
// return pack(f(self.op, "op"));
// }
// std::string name() const { return "gpu::ck_gemm_softmax_gemm"; }
// void check_gemm_shape(const shape& s) const
// {
// if(not contains(range(s.strides().rbegin(), s.strides().rbegin() + 3), 1))
// MIGRAPHX_THROW("Invalid shape for ck_gemm_scale_bias_softmax_gemm");
// }
// shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
// {
// check_shapes{inputs, *this}.same_ndims();
// // if(mods.size() != 1)
// // MIGRAPHX_THROW("should have one submodule.");
// if(inputs.size() < 2)
// MIGRAPHX_THROW("should have at least two inputs.");
// auto a = inputs[0];
// auto b = inputs[1];
// auto b1 = inputs[2];
// for(const auto& input : inputs)
// {
// // std::cout << input << std::endl;
// check_gemm_shape(input);
// }
// return op.compute_shape({op.compute_shape({a, b}), b1});
// }
// };
// MIGRAPHX_REGISTER_OP(ck_gemm_scale_bias_softmax_gemm);
namespace
{
namespace
{
MIGRAPHX_PRED_MATCHER
(
is_ck_gemm
,
instruction_ref
ins
)
MIGRAPHX_PRED_MATCHER
(
is_ck_gemm
,
instruction_ref
ins
)
...
@@ -154,77 +116,10 @@ struct find_ck_gemm
...
@@ -154,77 +116,10 @@ struct find_ck_gemm
}
}
};
};
struct
find_ck_gemm_scale_bias_softmax_gemm
{
// auto matcher() const
// {
// auto gemm1 =
// match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1")));
// auto pw =
// match::name("pointwise")(match::any_of[match::inputs()](gemm1)).bind("scale_bias");
// auto softmax =
// match::name("softmax")(match::any_of[match::inputs()](pw)).bind("softmax"); return
// match::name("dot")(is_ck_gemm().bind("gemm2"))(
// match::any_of[match::inputs()](softmax));
// }
// void apply(module_pass_manager& mpm, const match::matcher_result& r) const
// {
// std::cout << "Matched" << std::endl;
// auto ins = r.result;
// auto gemm2_ins = r.instructions["gemm2"];
// auto sm_ins = r.instructions["softmax"];
// auto pw_ins = r.instructions["scale_bias"];
// auto gemm1_ins = r.instructions["gemm1"];
// gemm2_ins->debug_print();
// sm_ins->debug_print();
// pw_ins->debug_print();
// gemm1_ins->debug_print();
// auto inputs = gemm1_ins->inputs(); // A, B
// inputs.push_back(gemm2_ins->inputs().back()); // B1
// // inputs.push_back(pw_ins->inputs().back()); // C
// mpm.get_module().replace_instruction(
// ins, ck_gemm_scale_bias_softmax_gemm{gemm2_ins->get_operator()}, inputs);
// }
// auto matcher() const
// {
// auto gemm1 =
// match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1")));
// auto softmax =
// match::name("softmax")(match::any_of[match::inputs()](gemm1)).bind("softmax"); return
// match::name("dot")(is_ck_gemm().bind("gemm2"))(match::any_of[match::inputs()](softmax));
// }
// void apply(module_pass_manager& mpm, const match::matcher_result& r) const
// {
// std::cout << "Matched" << std::endl;
// auto ins = r.result;
// auto gemm2_ins = r.instructions["gemm2"];
// auto sm_ins = r.instructions["softmax"];
// auto gemm1_ins = r.instructions["gemm1"];
// gemm2_ins->debug_print();
// sm_ins->debug_print();
// gemm1_ins->debug_print();
// auto inputs = gemm1_ins->inputs(); // A, B
// inputs.push_back(gemm2_ins->inputs().back()); // B1
// mpm.get_module().replace_instruction(ins,
// ck_gemm_scale_bias_softmax_gemm{gemm2_ins->get_operator()}, inputs);
// }
};
}
// namespace
}
// namespace
void
fuse_ck
::
apply
(
module_pass_manager
&
mpm
)
const
void
fuse_ck
::
apply
(
module_pass_manager
&
mpm
)
const
{
{
// mpm.get_module().debug_print();
// match::find_matches(mpm, find_ck_gemm_scale_bias_softmax_gemm{});
if
(
not
enabled
(
MIGRAPHX_DISABLE_CK_GEMM_FUSION
{}))
if
(
not
enabled
(
MIGRAPHX_DISABLE_CK_GEMM_FUSION
{}))
match
::
find_matches
(
mpm
,
find_ck_gemm_pointwise
{});
match
::
find_matches
(
mpm
,
find_ck_gemm_pointwise
{});
if
(
not
enabled
(
MIGRAPHX_DISABLE_CK_GEMM
{}))
if
(
not
enabled
(
MIGRAPHX_DISABLE_CK_GEMM
{}))
...
...
src/targets/gpu/fuse_ck_gemm_softmax_gemm.cpp
View file @
9c6ba1ed
...
@@ -8,6 +8,8 @@
...
@@ -8,6 +8,8 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_DISABLE_CK_GEMM_SOFTMAX_GEMM
);
struct
module
;
struct
module
;
namespace
gpu
{
namespace
gpu
{
...
@@ -91,7 +93,8 @@ struct find_gemm_softmax_gemm_gemm
...
@@ -91,7 +93,8 @@ struct find_gemm_softmax_gemm_gemm
void
fuse_ck_gemm_softmax_gemm
::
apply
(
module_pass_manager
&
mpm
)
const
void
fuse_ck_gemm_softmax_gemm
::
apply
(
module_pass_manager
&
mpm
)
const
{
{
match
::
find_matches
(
mpm
,
find_gemm_softmax_gemm_gemm
{});
if
(
not
enabled
(
MIGRAPHX_DISABLE_CK_GEMM_SOFTMAX_GEMM
{}))
match
::
find_matches
(
mpm
,
find_gemm_softmax_gemm_gemm
{});
}
}
}
// namespace gpu
}
// namespace gpu
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment