"git@developer.sourcefind.cn:modelzoo/gpt2_onnxruntime.git" did not exist on "01a5955c2f5f185595b569790422381dede5e955"
Unverified Commit 157935ff authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Conditionally enable pointwise fusion (#992)

This enables the pointwise fusions using the MIGRAPHX_ENABLE_POINTWISE_FUSION environment variable. It is disabled by default since the MIOpen fusions need to be refactored.

This also adds a compile_ops pass to compile the pointwise modules. All tests except test_gpu_fast_math pass with MIGRAPHX_ENABLE_POINTWISE_FUSION=1 set.
parent 38287064
...@@ -26,16 +26,18 @@ cpp_generator::function::set_body(const module& m, const cpp_generator::generate ...@@ -26,16 +26,18 @@ cpp_generator::function::set_body(const module& m, const cpp_generator::generate
{ {
names[ins] = names[ins] =
migraphx::any_cast<migraphx::builtin::param>(ins->get_operator()).parameter; migraphx::any_cast<migraphx::builtin::param>(ins->get_operator()).parameter;
continue;
} }
if(ins->name() == "@return") else if(ins->name() == "@return")
{ {
assert(ins->inputs().size() == 1); assert(ins->inputs().size() == 1);
return_ins = ins->inputs().front(); return_ins = ins->inputs().front();
} }
std::string n = "z" + std::to_string(names.size()); else
names[ins] = n; {
ss << "auto " << n << " = " << g(ins, names) << ";\n"; std::string n = "z" + std::to_string(names.size());
names[ins] = n;
ss << "auto " << n << " = " << g(ins, names) << ";\n";
}
} }
ss << "return " << names.at(return_ins) << ";\n"; ss << "return " << names.at(return_ins) << ";\n";
body = ss.str(); body = ss.str();
...@@ -84,8 +86,11 @@ void cpp_generator::fmap(const std::function<std::string(std::string)>& f) { imp ...@@ -84,8 +86,11 @@ void cpp_generator::fmap(const std::function<std::string(std::string)>& f) { imp
std::string cpp_generator::generate_point_op(const operation& op, std::string cpp_generator::generate_point_op(const operation& op,
const std::vector<std::string>& args) const std::vector<std::string>& args)
{ {
auto v = op.to_value(); auto v = op.to_value();
return interpolate_string(op.attributes()["point_op"].to<std::string>(), auto attributes = op.attributes();
if(not attributes.contains("point_op"))
MIGRAPHX_THROW("op is missing point_op attribute: " + op.name());
return interpolate_string(attributes["point_op"].to<std::string>(),
[&](auto start, auto last) -> std::string { [&](auto start, auto last) -> std::string {
auto key = trim({start, last}); auto key = trim({start, last});
if(key.empty()) if(key.empty())
...@@ -120,7 +125,12 @@ std::string cpp_generator::str() const { return impl->fs.str(); } ...@@ -120,7 +125,12 @@ std::string cpp_generator::str() const { return impl->fs.str(); }
cpp_generator::function cpp_generator::generate_module(const module& m) cpp_generator::function cpp_generator::generate_module(const module& m)
{ {
function f; function f;
f.set_name(m.name()).set_types(m).set_body( auto name = transform_string(m.name(), [](char c) {
if(with_char(::isalnum)(c) or c == '_')
return c;
return '_';
});
f.set_name(name).set_types(m).set_body(
m, [&](instruction_ref ins, const auto& names) -> std::string { m, [&](instruction_ref ins, const auto& names) -> std::string {
if(ins->name() == "@literal") if(ins->name() == "@literal")
return shape::cpp_type(ins->get_shape().type()) + "(" + return shape::cpp_type(ins->get_shape().type()) + "(" +
...@@ -130,7 +140,6 @@ cpp_generator::function cpp_generator::generate_module(const module& m) ...@@ -130,7 +140,6 @@ cpp_generator::function cpp_generator::generate_module(const module& m)
ins->inputs().end(), ins->inputs().end(),
std::back_inserter(args), std::back_inserter(args),
[&](auto i) { return names.at(i); }); [&](auto i) { return names.at(i); });
auto s = this->generate_point_op(ins->get_operator(), args);
return this->generate_point_op(ins->get_operator(), args); return this->generate_point_op(ins->get_operator(), args);
}); });
return f; return f;
......
...@@ -13,6 +13,8 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -13,6 +13,8 @@ inline namespace MIGRAPHX_INLINE_NS {
static literal get_scalar(instruction_ref ins) static literal get_scalar(instruction_ref ins)
{ {
if(ins->name() == "contiguous")
return get_scalar(ins->inputs().front());
const auto& s = ins->get_shape(); const auto& s = ins->get_shape();
if(not(s.elements() == 1 or s.scalar())) if(not(s.elements() == 1 or s.scalar()))
return {}; return {};
...@@ -31,11 +33,16 @@ static void create_pointwise_modules(module_pass_manager& mpm) ...@@ -31,11 +33,16 @@ static void create_pointwise_modules(module_pass_manager& mpm)
{ {
if(not ins->get_operator().attributes().get("pointwise", false)) if(not ins->get_operator().attributes().get("pointwise", false))
continue; continue;
auto* pm = mpm.create_module("pointwise" + std::to_string(n++)); // Skip convert op for now
if(ins->name() == "convert")
continue;
assert(ins->get_operator().attributes().contains("point_op"));
auto* pm = mpm.create_module(mpm.get_module().name() + ":pointwise" + std::to_string(n++));
pm->set_bypass(); pm->set_bypass();
std::unordered_map<instruction_ref, instruction_ref> param_map; std::unordered_map<instruction_ref, instruction_ref> param_map;
std::vector<instruction_ref> pointwise_inputs; std::vector<instruction_ref> pointwise_inputs;
std::size_t i = 0;
for(auto input : ins->inputs()) for(auto input : ins->inputs())
{ {
if(contains(param_map, input)) if(contains(param_map, input))
...@@ -44,8 +51,9 @@ static void create_pointwise_modules(module_pass_manager& mpm) ...@@ -44,8 +51,9 @@ static void create_pointwise_modules(module_pass_manager& mpm)
if(scalar.empty()) if(scalar.empty())
{ {
pointwise_inputs.push_back(input); pointwise_inputs.push_back(input);
param_map[input] = pm->add_parameter("x" + std::to_string(param_map.size()), param_map[input] =
shape{input->get_shape().type()}); pm->add_parameter("x" + std::to_string(i), shape{input->get_shape().type()});
i++;
} }
else else
{ {
...@@ -68,6 +76,7 @@ static void create_pointwise_modules(module_pass_manager& mpm) ...@@ -68,6 +76,7 @@ static void create_pointwise_modules(module_pass_manager& mpm)
static std::vector<instruction_ref> append_pointwise_module(instruction_ref ins, static std::vector<instruction_ref> append_pointwise_module(instruction_ref ins,
instruction_ref output) instruction_ref output)
{ {
assert(contains(output->inputs(), ins));
module_ref pm = ins->module_inputs().at(0); module_ref pm = ins->module_inputs().at(0);
module_ref xm = output->module_inputs().at(0); module_ref xm = output->module_inputs().at(0);
...@@ -75,14 +84,18 @@ static std::vector<instruction_ref> append_pointwise_module(instruction_ref ins, ...@@ -75,14 +84,18 @@ static std::vector<instruction_ref> append_pointwise_module(instruction_ref ins,
assert(last->name() == "@return"); assert(last->name() == "@return");
assert(last->inputs().size() == 1); assert(last->inputs().size() == 1);
assert(pm->get_parameter_names().size() == ins->inputs().size());
assert(xm->get_parameter_names().size() == output->inputs().size());
std::vector<instruction_ref> inputs = ins->inputs(); std::vector<instruction_ref> inputs = ins->inputs();
std::unordered_map<instruction_ref, instruction_ref> map_ins; std::unordered_map<instruction_ref, instruction_ref> map_ins;
std::unordered_map<instruction_ref, instruction_ref> input_map; std::unordered_map<instruction_ref, instruction_ref> input_map;
// Copy inputs to input_map // Copy inputs to input_map
for(auto i : range(inputs.size())) for(auto i : range(inputs.size()))
{ {
auto input = inputs[i]; auto input = inputs[i];
auto param = pm->get_parameter("x" + std::to_string(i)); auto param = pm->get_parameter("x" + std::to_string(i));
assert(param != pm->end());
input_map[input] = param; input_map[input] = param;
} }
// Add the new parameter and additional inputs // Add the new parameter and additional inputs
...@@ -90,6 +103,7 @@ static std::vector<instruction_ref> append_pointwise_module(instruction_ref ins, ...@@ -90,6 +103,7 @@ static std::vector<instruction_ref> append_pointwise_module(instruction_ref ins,
{ {
auto input = output->inputs()[i]; auto input = output->inputs()[i];
auto param = xm->get_parameter("x" + std::to_string(i)); auto param = xm->get_parameter("x" + std::to_string(i));
assert(param != xm->end());
if(input == ins) if(input == ins)
{ {
map_ins[param] = last->inputs().front(); map_ins[param] = last->inputs().front();
......
...@@ -26,19 +26,17 @@ struct pointwise ...@@ -26,19 +26,17 @@ struct pointwise
auto pnames = pm->get_parameter_names(); auto pnames = pm->get_parameter_names();
std::sort(pnames.begin(), pnames.end()); std::sort(pnames.begin(), pnames.end());
check_shapes{inputs, *this}.has(pnames.size()).same_dims(); check_shapes{inputs, *this}.has(pnames.size()).same_dims();
for(auto i : range(pnames.size()))
{
auto s1 = pm->get_parameter(pnames[i])->get_shape();
auto s2 = inputs[i];
if(s1.type() != s2.type())
MIGRAPHX_THROW("Mismatch type");
}
if(pm->get_output_shapes().size() != 1) if(pm->get_output_shapes().size() != 1)
MIGRAPHX_THROW("submodule should have only one output."); MIGRAPHX_THROW("submodule should have only one output.");
auto type = pm->get_output_shapes().front().type(); auto type = pm->get_output_shapes().front().type();
// Scalar output if all inputs are scalar
if(inputs.front().elements() == 1 and
all_of(inputs, [](const auto& s) { return s.scalar(); }))
return shape{type};
return shape::from_permutation(type, inputs.front().lens(), find_permutation(inputs)); return shape::from_permutation(type, inputs.front().lens(), find_permutation(inputs));
} }
......
...@@ -9,6 +9,7 @@ namespace op { ...@@ -9,6 +9,7 @@ namespace op {
struct prelu : binary<prelu> struct prelu : binary<prelu>
{ {
std::string point_op() const { return "(${0} < 0) ? (${0} * ${1}) : ${0}"; }
auto apply() const auto apply() const
{ {
return [](auto x, auto slope) { return ((x < 0) ? (x * slope) : x); }; return [](auto x, auto slope) { return ((x < 0) ? (x * slope) : x); };
......
...@@ -9,6 +9,7 @@ namespace op { ...@@ -9,6 +9,7 @@ namespace op {
struct recip : unary<recip> struct recip : unary<recip>
{ {
std::string point_op() const { return "1 / ${0}"; }
auto apply() const auto apply() const
{ {
return [](auto x) { return 1 / x; }; return [](auto x) { return 1 / x; };
......
...@@ -18,6 +18,7 @@ namespace op { ...@@ -18,6 +18,7 @@ namespace op {
struct sigmoid : unary<sigmoid> struct sigmoid : unary<sigmoid>
{ {
std::string point_op() const { return "1.f / (1.f + ${function:exp}(-${0}))"; }
auto apply() const auto apply() const
{ {
return [](auto x) { return 1.f / (1.f + std::exp(-x)); }; return [](auto x) { return 1.f / (1.f + std::exp(-x)); };
......
...@@ -18,6 +18,7 @@ namespace op { ...@@ -18,6 +18,7 @@ namespace op {
struct sign : unary<sign> struct sign : unary<sign>
{ {
std::string point_op() const { return "(${0} > 0 ? 1 : ((${0} < 0) ? -1 : 0))"; }
auto apply() const auto apply() const
{ {
return [](auto x) { return (x > 0 ? 1 : ((x < 0) ? -1 : 0)); }; return [](auto x) { return (x > 0 ? 1 : ((x < 0) ? -1 : 0)); };
......
...@@ -103,7 +103,7 @@ auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name()) ...@@ -103,7 +103,7 @@ auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
} // namespace operation_operators } // namespace operation_operators
template <class T> template <class T>
auto normalize_compute_shape_op(rank<1>, const T& x, const std::vector<shape>& inputs) auto normalize_compute_shape_op(rank<2>, const T& x, const std::vector<shape>& inputs)
-> decltype(x.normalize_compute_shape(inputs)) -> decltype(x.normalize_compute_shape(inputs))
{ {
dependent_type<operation, T> y = x; dependent_type<operation, T> y = x;
...@@ -111,6 +111,13 @@ auto normalize_compute_shape_op(rank<1>, const T& x, const std::vector<shape>& i ...@@ -111,6 +111,13 @@ auto normalize_compute_shape_op(rank<1>, const T& x, const std::vector<shape>& i
return any_cast<T>(y).normalize_compute_shape(inputs); return any_cast<T>(y).normalize_compute_shape(inputs);
} }
template <class T>
auto normalize_compute_shape_op(rank<1>, const T& x, const std::vector<shape>& inputs)
-> decltype(x.compute_shape(inputs, {}))
{
return x.compute_shape(inputs, {});
}
template <class T> template <class T>
shape normalize_compute_shape_op(rank<0>, const T& x, const std::vector<shape>&) shape normalize_compute_shape_op(rank<0>, const T& x, const std::vector<shape>&)
{ {
...@@ -121,7 +128,7 @@ shape normalize_compute_shape_op(rank<0>, const T& x, const std::vector<shape>&) ...@@ -121,7 +128,7 @@ shape normalize_compute_shape_op(rank<0>, const T& x, const std::vector<shape>&)
template <class T> template <class T>
shape normalize_compute_shape_op(const T& x, const std::vector<shape>& inputs) shape normalize_compute_shape_op(const T& x, const std::vector<shape>& inputs)
{ {
return normalize_compute_shape_op(rank<1>{}, x, inputs); return normalize_compute_shape_op(rank<2>{}, x, inputs);
} }
template <class T> template <class T>
......
...@@ -18,7 +18,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -18,7 +18,7 @@ inline namespace MIGRAPHX_INLINE_NS {
template <class F> template <class F>
auto with_char(F f) auto with_char(F f)
{ {
return [=](unsigned char c) { return f(c); }; return [=](unsigned char c) -> bool { return f(c); };
} }
inline std::string inline std::string
...@@ -120,22 +120,27 @@ interpolate_string(const std::string& input, F f, std::string start = "${", std: ...@@ -120,22 +120,27 @@ interpolate_string(const std::string& input, F f, std::string start = "${", std:
result.append(it, next_start); result.append(it, next_start);
if(next_start == input.end()) if(next_start == input.end())
break; break;
auto r = f(next_start + start.size(), next_end - end.size() + 1); auto r = f(next_start + start.size(), next_end);
result.append(r.begin(), r.end()); result.append(r.begin(), r.end());
it = next_end + 1; it = next_end + end.size();
} }
return result; return result;
} }
inline std::string interpolate_string(const std::string& input, inline std::string interpolate_string(const std::string& input,
const std::unordered_map<std::string, std::string>& vars) const std::unordered_map<std::string, std::string>& vars,
{ std::string start = "${",
return interpolate_string(input, [&](auto start, auto last) { std::string end = "}")
auto key = trim({start, last}); {
auto it = vars.find(key); return interpolate_string(input,
if(it == vars.end()) [&](auto start_it, auto last_it) {
throw std::runtime_error("Unknown key: " + key); auto key = trim({start_it, last_it});
return it->second; auto it = vars.find(key);
}); if(it == vars.end())
throw std::runtime_error("Unknown key: " + key);
return it->second;
},
std::move(start),
std::move(end));
} }
template <class Iterator> template <class Iterator>
......
...@@ -122,6 +122,7 @@ add_library(migraphx_gpu ...@@ -122,6 +122,7 @@ add_library(migraphx_gpu
batch_norm_inference.cpp batch_norm_inference.cpp
clip.cpp clip.cpp
code_object_op.cpp code_object_op.cpp
compile_ops.cpp
compile_hip.cpp compile_hip.cpp
compile_hip_code_object.cpp compile_hip_code_object.cpp
compile_pointwise.cpp compile_pointwise.cpp
......
#include <migraphx/gpu/allocation_model.hpp> #include <migraphx/gpu/allocation_model.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/module.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
......
#include <migraphx/gpu/compile_ops.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/module.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/gpu/compile_pointwise.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
// Wrapper operation holding an op (e.g. a pointwise module op) whose GPU code
// object has not been generated yet. The compile_ops pass later replaces each
// gpu::precompile_op instruction with the actual compiled operation.
struct precompile_op
{
    // The operation to be compiled later; defaults to identity.
    operation op = op::identity{};
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        // Expose the wrapped op for serialization/value conversion.
        return pack(f(self.op, "op"));
    }
    std::string name() const { return "gpu::precompile_op"; }
    shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
    {
        // The last input is the preallocated output buffer (appended during
        // lowering); drop it before delegating to the wrapped op.
        inputs.pop_back();
        return op.compute_shape(inputs, mods);
    }
    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
    {
        // The result aliases the last input, i.e. the output allocation.
        return shapes.size() - 1;
    }
};
MIGRAPHX_REGISTER_OP(precompile_op);
// Compiler for "pointwise" operations: compiles the instruction's pointwise
// submodule into a GPU code-object operation.
struct pointwise_compiler
{
    // Key used to look this compiler up in the compilers map.
    std::string name() const { return "pointwise"; }
    operation apply(context& ctx, instruction_ref ins, const operation&) const
    {
        assert(not ins->module_inputs().empty());
        // The submodule describing the fused pointwise computation.
        auto* pm = ins->module_inputs().front();
        return compile_pointwise(ctx, to_shapes(ins->inputs()), *pm);
    }
};
// Type-erased signature shared by all op compilers: given a context, the
// instruction to compile, and the wrapped (pre-compile) operation, produce
// the compiled replacement operation.
using compiler_function = std::function<operation(context&, instruction_ref, operation)>;
// Wraps a compiler object (anything with an apply() member, such as
// pointwise_compiler) in a compiler_function.
template <class T>
compiler_function make_compiler_function(T x)
{
    return {[=](auto&&... xs) { return x.apply(xs...); }};
}
// Builds a name -> compiler_function lookup table from a list of compiler
// objects, keyed on each object's name().
template <class... Ts>
std::unordered_map<std::string, compiler_function> make_compilers(Ts... xs)
{
    return {{xs.name(), make_compiler_function(xs)}...};
}
// Replaces every gpu::precompile_op instruction in the module with the
// operation produced by the matching compiler (looked up by the wrapped
// op's name).
void compile_ops::apply(module& m) const
{
    auto compilers = make_compilers(pointwise_compiler{});
    for(auto ins : iterator_for(m))
    {
        if(ins->name() != "gpu::precompile_op")
            continue;
        // Unwrap the operation stored inside the precompile wrapper.
        operation preop = any_cast<precompile_op>(ins->get_operator()).op;
        assert(contains(compilers, preop.name()));
        // Use at() rather than operator[]: with asserts compiled out
        // (NDEBUG), operator[] would silently insert an empty std::function
        // for an unknown op name and invoking it would throw a confusing
        // bad_function_call; at() fails fast with out_of_range instead.
        auto op = compilers.at(preop.name())(*ctx, ins, preop);
        m.replace_instruction(ins, op, ins->inputs());
    }
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -2,9 +2,14 @@ ...@@ -2,9 +2,14 @@
#include <migraphx/gpu/compile_hip_code_object.hpp> #include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp> #include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp> #include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -17,6 +22,8 @@ static const char* const pointwise_kernel = R"__migraphx__( ...@@ -17,6 +22,8 @@ static const char* const pointwise_kernel = R"__migraphx__(
using namespace migraphx; using namespace migraphx;
${preamble}
extern "C" { extern "C" {
__global__ void kernel(${params}) __global__ void kernel(${params})
{ {
...@@ -29,7 +36,10 @@ int main() {} ...@@ -29,7 +36,10 @@ int main() {}
)__migraphx__"; )__migraphx__";
operation compile_pointwise(context&, const std::vector<shape>& inputs, const std::string& lambda) operation compile_pointwise(context&,
const std::vector<shape>& inputs,
const std::string& lambda,
const std::string& preamble)
{ {
hip_compile_options options; hip_compile_options options;
options.global = compute_global(inputs.front().elements()); options.global = compute_global(inputs.front().elements());
...@@ -37,13 +47,23 @@ operation compile_pointwise(context&, const std::vector<shape>& inputs, const st ...@@ -37,13 +47,23 @@ operation compile_pointwise(context&, const std::vector<shape>& inputs, const st
options.inputs = inputs; options.inputs = inputs;
options.output = inputs.back(); options.output = inputs.back();
options.reduced_inputs = reduce_dims(inputs); options.reduced_inputs = reduce_dims(inputs);
options.params = "-Wno-float-equal";
auto src = interpolate_string(pointwise_kernel, auto src = interpolate_string(pointwise_kernel,
{{"params", enum_params(inputs.size(), "void * private_p")}, {{"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")}, {"args", enum_params(inputs.size(), "private_p")},
{"lambda", lambda}}); {"lambda", lambda},
{"preamble", preamble}});
return compile_hip_code_object(src, options); return compile_hip_code_object(src, options);
} }
operation compile_pointwise(context& ctx, const std::vector<shape>& inputs, module m)
{
run_passes(m, {eliminate_common_subexpression{}, dead_code_elimination{}});
cpp_generator g;
auto name = g.create_function(g.generate_module(m).set_attributes({"__device__"}));
return compile_pointwise((ctx), inputs, "&" + name, g.str());
}
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/operation.hpp> #include <migraphx/operation.hpp>
#include <migraphx/instruction_ref.hpp>
#include <string> #include <string>
namespace migraphx { namespace migraphx {
......
#ifndef MIGRAPHX_GUARD_GPU_COMPILE_OPS_HPP
#define MIGRAPHX_GUARD_GPU_COMPILE_OPS_HPP
#include <migraphx/config.hpp>
#include <string>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
namespace gpu {
struct context;
// Pass that compiles gpu::precompile_op instructions into their final GPU
// operations (see compile_ops.cpp).
struct compile_ops
{
    // Non-owning pointer to the GPU context used for compilation; must
    // outlive the pass.
    context* ctx = nullptr;
    std::string name() const { return "gpu::compile_ops"; }
    void apply(module& m) const;
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_GPU_COMPILE_OPS_HPP
...@@ -6,11 +6,17 @@ ...@@ -6,11 +6,17 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct module;
namespace gpu { namespace gpu {
struct context; struct context;
operation operation compile_pointwise(context& ctx,
compile_pointwise(context& ctx, const std::vector<shape>& inputs, const std::string& lambda); const std::vector<shape>& inputs,
const std::string& lambda,
const std::string& preamble = "");
operation compile_pointwise(context& ctx, const std::vector<shape>& inputs, module m);
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -23,7 +23,7 @@ __device__ void pointwise_tensor(index idx, F f, T out, Ts... xs) ...@@ -23,7 +23,7 @@ __device__ void pointwise_tensor(index idx, F f, T out, Ts... xs)
template <class F, class... Ts> template <class F, class... Ts>
__device__ void pointwise(F f, Ts*... ps) __device__ void pointwise(F f, Ts*... ps)
{ {
auto t = transform_args(make_tensors(), rotate_last(), auto_vectorize()); auto t = transform_args(make_tensors(), rotate_last());
t(ps...)([&](auto... xs) { t(ps...)([&](auto... xs) {
auto idx = make_index(); auto idx = make_index();
pointwise_tensor(idx, f, xs...); pointwise_tensor(idx, f, xs...);
......
...@@ -12,6 +12,8 @@ using index_int = std::uint32_t; ...@@ -12,6 +12,8 @@ using index_int = std::uint32_t;
template <class T, index_int N> template <class T, index_int N>
using vec = T __attribute__((ext_vector_type(N))); using vec = T __attribute__((ext_vector_type(N)));
using half = _Float16;
} // namespace migraphx } // namespace migraphx
#endif #endif
...@@ -183,6 +183,8 @@ struct miopen_apply ...@@ -183,6 +183,8 @@ struct miopen_apply
add_extend_op("softmax"); add_extend_op("softmax");
add_extend_op("topk"); add_extend_op("topk");
add_precompile_op("pointwise");
add_batch_norm_inference_op(); add_batch_norm_inference_op();
add_convolution_op(); add_convolution_op();
add_deconvolution_op(); add_deconvolution_op();
...@@ -381,6 +383,21 @@ struct miopen_apply ...@@ -381,6 +383,21 @@ struct miopen_apply
}); });
} }
// Registers a lowering handler that wraps instructions of the given op name
// (e.g. "pointwise") in a gpu::precompile_op, so the compile_ops pass can
// generate their code objects later.
void add_precompile_op(const std::string& name)
{
    apply_map.emplace(name, [=](instruction_ref ins) {
        // Allocate the output buffer and append it as the last input,
        // which is the input that precompile_op::output_alias points at.
        auto output = insert_allocation(ins, ins->get_shape());
        std::vector<instruction_ref> refs = ins->inputs();
        refs.push_back(output);
        return mod->replace_instruction(
            ins,
            make_op("gpu::precompile_op", {{"op", to_value(ins->get_operator())}}),
            refs,
            ins->module_inputs());
    });
}
void add_batch_norm_inference_op() void add_batch_norm_inference_op()
{ {
apply_map.emplace("batch_norm_inference", [=](instruction_ref ins) { apply_map.emplace("batch_norm_inference", [=](instruction_ref ins) {
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <migraphx/eliminate_data_type.hpp> #include <migraphx/eliminate_data_type.hpp>
#include <migraphx/eliminate_identity.hpp> #include <migraphx/eliminate_identity.hpp>
#include <migraphx/eliminate_pad.hpp> #include <migraphx/eliminate_pad.hpp>
#include <migraphx/fuse_pointwise.hpp>
#include <migraphx/inline_module.hpp> #include <migraphx/inline_module.hpp>
#include <migraphx/insert_pad.hpp> #include <migraphx/insert_pad.hpp>
#include <migraphx/memory_coloring.hpp> #include <migraphx/memory_coloring.hpp>
...@@ -25,6 +26,7 @@ ...@@ -25,6 +26,7 @@
#include <migraphx/simplify_qdq.hpp> #include <migraphx/simplify_qdq.hpp>
#include <migraphx/simplify_reshapes.hpp> #include <migraphx/simplify_reshapes.hpp>
#include <migraphx/gpu/allocation_model.hpp> #include <migraphx/gpu/allocation_model.hpp>
#include <migraphx/gpu/compile_ops.hpp>
#include <migraphx/gpu/concat_gpu_opt.hpp> #include <migraphx/gpu/concat_gpu_opt.hpp>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/eliminate_workspace.hpp> #include <migraphx/gpu/eliminate_workspace.hpp>
...@@ -42,6 +44,20 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -42,6 +44,20 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_POINTWISE_FUSION)
// Do-nothing placeholder pass, used by enable_pass() when a pass is
// conditionally disabled.
struct id_pass
{
    std::string name() const { return "id"; }
    // Fixed typo: this member was misspelled "apple", so id_pass never
    // actually implemented the pass interface's apply() method.
    void apply(const module&) const {}
};
// Returns the given pass when enabled, otherwise a do-nothing id_pass, so a
// pass can be conditionally inserted into a pipeline.
pass enable_pass(bool enabled, pass p)
{
    if(not enabled)
        return id_pass{};
    return p;
}
std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_options& options) const std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_options& options) const
{ {
...@@ -84,6 +100,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -84,6 +100,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
simplify_reshapes{}, simplify_reshapes{},
propagate_constant{}, propagate_constant{},
dead_code_elimination{}, dead_code_elimination{},
enable_pass(enabled(MIGRAPHX_ENABLE_POINTWISE_FUSION{}), fuse_pointwise{}),
dead_code_elimination{},
mlir_conv{&ctx}, mlir_conv{&ctx},
lowering{&ctx, options.offload_copy}, lowering{&ctx, options.offload_copy},
eliminate_contiguous{"gpu::contiguous"}, eliminate_contiguous{"gpu::contiguous"},
...@@ -96,6 +114,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -96,6 +114,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination{}, dead_code_elimination{},
fuse_ops{&ctx, options.fast_math}, fuse_ops{&ctx, options.fast_math},
dead_code_elimination{}, dead_code_elimination{},
compile_ops{&ctx},
dead_code_elimination{},
write_literals{&ctx}, write_literals{&ctx},
schedule{gpu::schedule_model{ctx.get_current_device().nstreams()}, not enabled(MIGRAPHX_DISABLE_SCHEDULE_PASS{})}, schedule{gpu::schedule_model{ctx.get_current_device().nstreams()}, not enabled(MIGRAPHX_DISABLE_SCHEDULE_PASS{})},
memory_coloring{"hip::allocate"}, memory_coloring{"hip::allocate"},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment