Commit 349249c1 authored by Paul
Browse files

Merge branch 'jit-contiguous' into bert-opt

parents 0ad73695 25fcef27
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/register_op.hpp> #include <migraphx/register_op.hpp>
#include <migraphx/array.hpp> #include <migraphx/array.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/clip.hpp> #include <migraphx/op/clip.hpp>
#include <cmath> #include <cmath>
#include <set> #include <set>
...@@ -989,9 +990,43 @@ struct find_commutative_broadcast ...@@ -989,9 +990,43 @@ struct find_commutative_broadcast
} }
}; };
// Rewrite rule for instructions named "gpu::contiguous": each match is
// replaced by a "gpu::precompile_op" that carries the plain "contiguous"
// operator as its "op" value, with the original inputs left untouched.
struct find_contiguous
{
    // Matches every instruction whose name is "gpu::contiguous".
    auto matcher() const { return match::name("gpu::contiguous"); }

    // Swaps the matched instruction for the precompile wrapper.
    // NOTE(review): presumably the wrapped op is later picked up by the
    // compiler registered under the "contiguous" name -- confirm.
    void apply(module& m, const match::matcher_result& r) const
    {
        auto contiguous_ins = r.result;
        auto wrapped_op =
            make_op("gpu::precompile_op", {{"op", to_value(make_op("contiguous"))}});
        m.replace_instruction(contiguous_ins, wrapped_op, contiguous_ins->inputs());
    }
};
// Rewrite rule that folds a "gpu::contiguous" into the precompiled pointwise
// instruction feeding it: the contiguous is replaced by a copy of the
// pointwise operator whose last input is the contiguous' last input.
struct find_contiguous_pointwise
{
    // Matches "gpu::contiguous" whose first argument satisfies
    // precompile_name("pointwise").
    auto matcher() const
    {
        auto pointwise_producer = match::arg(0)(precompile_name("pointwise"));
        return match::name("gpu::contiguous")(pointwise_producer);
    }

    // Re-emits the pointwise operator in place of the contiguous instruction.
    void apply(module& m, const match::matcher_result& r) const
    {
        auto contiguous_ins = r.result;
        auto pointwise_ins  = contiguous_ins->inputs().front();

        // Start from the pointwise inputs and redirect the final one to the
        // contiguous' final input.
        // NOTE(review): the last input is presumably the output allocation
        // in both instructions -- confirm against the precompile_op convention.
        auto new_inputs   = pointwise_ins->inputs();
        new_inputs.back() = contiguous_ins->inputs().back();

        m.replace_instruction(contiguous_ins,
                              pointwise_ins->get_operator(),
                              new_inputs,
                              pointwise_ins->module_inputs());
    }
};
void fuse_ops::apply(module& m) const void fuse_ops::apply(module& m) const
{ {
match::find_matches(m, find_gelu{}, find_gelu_new{fast_math}); match::find_matches(m, find_contiguous_pointwise{}, find_gelu{}, find_gelu_new{fast_math});
run_passes(m, {dead_code_elimination{}}); run_passes(m, {dead_code_elimination{}});
match::find_matches(m, find_triadd{}); match::find_matches(m, find_triadd{});
match::find_matches(m, match::find_matches(m,
...@@ -1013,6 +1048,7 @@ void fuse_ops::apply(module& m) const ...@@ -1013,6 +1048,7 @@ void fuse_ops::apply(module& m) const
find_gemm_add{}, find_gemm_add{},
find_gemm_pointwise{}, find_gemm_pointwise{},
find_commutative_broadcast{}); find_commutative_broadcast{});
match::find_matches(m, find_contiguous{});
} }
} // namespace gpu } // namespace gpu
......
...@@ -53,7 +53,7 @@ static std::vector<std::string> get_op_names(const module& m) ...@@ -53,7 +53,7 @@ static std::vector<std::string> get_op_names(const module& m)
struct pointwise_compiler : compiler<pointwise_compiler> struct pointwise_compiler : compiler<pointwise_compiler>
{ {
std::vector<std::string> names() const { return {"pointwise"}; } std::vector<std::string> names() const { return {"pointwise", "contiguous"}; }
static std::size_t oversubscribe_if(bool b) static std::size_t oversubscribe_if(bool b)
{ {
...@@ -160,34 +160,45 @@ struct pointwise_compiler : compiler<pointwise_compiler> ...@@ -160,34 +160,45 @@ struct pointwise_compiler : compiler<pointwise_compiler>
return compile_hip_code_object(src, options); return compile_hip_code_object(src, options);
} }
compiler_replace compile(context& ctx, instruction_ref ins, const operation&) const compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{ {
assert(not ins->module_inputs().empty()); if(op.name() == "contiguous")
auto* pm = ins->module_inputs().front(); {
run_passes(*pm, {eliminate_common_subexpression{}, dead_code_elimination{}}); return replace(compile_op(
cpp_generator g; ctx,
g.fmap([](const std::string& fname) { return "migraphx::" + fname; }); to_shapes(ins->inputs()),
g.add_point_op("where", "${function:where}(${0}, ${1}, ${2})"); {{"lambda", "[](auto x) { return x; }"}, {"kernel", "contiguous_kernel"}}));
g.add_point_op("prelu", "${function:where}(${0} < 0, ${0} * ${1}, ${0})"); }
g.add_point_op("sign", else
"${function:where}(${0} > 0, 1, ${function:where}(${0} < 0, -1, 0))"); {
g.add_point_op("equal", "migraphx::abs(${0} == ${1})"); assert(not ins->module_inputs().empty());
g.add_point_op("less", "migraphx::abs(${0} < ${1})"); auto* pm = ins->module_inputs().front();
g.add_point_op("greater", "migraphx::abs(${0} > ${1})"); run_passes(*pm, {eliminate_common_subexpression{}, dead_code_elimination{}});
g.add_point_op("not", "migraphx::abs(not ${0})"); cpp_generator g;
// Add explict conversions g.fmap([](const std::string& fname) { return "migraphx::" + fname; });
g.fresult( g.add_point_op("where", "${function:where}(${0}, ${1}, ${2})");
[](const shape& s) { return "migraphx::convert<" + shape::cpp_type(s.type()) + ">"; }); g.add_point_op("prelu", "${function:where}(${0} < 0, ${0} * ${1}, ${0})");
auto name = g.create_function( g.add_point_op("sign",
g.generate_module(*pm).set_attributes({"__device__"}).set_generic_types(*pm)); "${function:where}(${0} > 0, 1, ${function:where}(${0} < 0, -1, 0))");
std::string lambda = "MIGRAPHX_LIFT(" + name + ")"; g.add_point_op("equal", "migraphx::abs(${0} == ${1})");
auto op_names = get_op_names(*pm); g.add_point_op("less", "migraphx::abs(${0} < ${1})");
op_names.push_back("kernel"); g.add_point_op("greater", "migraphx::abs(${0} > ${1})");
auto op_name_string = join_strings(op_names, "_"); g.add_point_op("not", "migraphx::abs(not ${0})");
return replace( // Add explict conversions
compile_op(ctx, g.fresult([](const shape& s) {
to_shapes(ins->inputs()), return "migraphx::convert<" + shape::cpp_type(s.type()) + ">";
{{"lambda", lambda}, {"preamble", g.str()}, {"kernel", op_name_string}})); });
auto name = g.create_function(
g.generate_module(*pm).set_attributes({"__device__"}).set_generic_types(*pm));
std::string lambda = "MIGRAPHX_LIFT(" + name + ")";
auto op_names = get_op_names(*pm);
op_names.push_back("kernel");
auto op_name_string = join_strings(op_names, "_");
return replace(compile_op(
ctx,
to_shapes(ins->inputs()),
{{"lambda", lambda}, {"preamble", g.str()}, {"kernel", op_name_string}}));
}
} }
}; };
} // namespace gpu } // namespace gpu
......
...@@ -18,8 +18,15 @@ struct implicit_conversion_op ...@@ -18,8 +18,15 @@ struct implicit_conversion_op
template <index_int N, class U> template <index_int N, class U>
constexpr operator vec<U, N>() const constexpr operator vec<U, N>() const
{ {
static_assert(vec_size<T>() == N, "Vector mismatch size"); if constexpr(vec_size<T>() == 0)
return __builtin_convertvector(x, vec<U, N>); {
return x;
}
else
{
static_assert(vec_size<T>() == N, "Vector mismatch size");
return __builtin_convertvector(x, vec<U, N>);
}
} }
template <class U> template <class U>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment