Unverified Commit 785307c3 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Handle removing contiguous on operators that use modules (#1005)

Currently, eliminate_contiguous will never remove contiguous for operators that use module inputs due to the fact that it doesn't pass the module inputs to compute_shape.

- Update to pass the module inputs correctly to compute_shape
- Fix the overloads of compute_shape so that when passed an empty vector of module inputs it will call the overload without module inputs
- Add tests with contiguous and pointwise module function.
- Move add_pointwise function to a seperate header to reuse across different tests
parent 19f65e7e
...@@ -11,11 +11,13 @@ ...@@ -11,11 +11,13 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
static bool try_compute_shape(instruction_ref ins, const std::vector<shape>& inputs) static bool try_compute_shape(instruction_ref ins,
const std::vector<shape>& inputs,
const std::vector<module_ref>& mods)
{ {
try try
{ {
shape new_shape = ins->get_operator().compute_shape(inputs); shape new_shape = ins->get_operator().compute_shape(inputs, mods);
// If the output shape is a standard shape, no need to try its output // If the output shape is a standard shape, no need to try its output
if(new_shape.standard()) if(new_shape.standard())
{ {
...@@ -45,7 +47,7 @@ static bool try_compute_shape(instruction_ref ins, const std::vector<shape>& inp ...@@ -45,7 +47,7 @@ static bool try_compute_shape(instruction_ref ins, const std::vector<shape>& inp
return (arg == ins) ? new_shape : arg->get_shape(); return (arg == ins) ? new_shape : arg->get_shape();
}); });
if(!try_compute_shape(output, input_shapes)) if(!try_compute_shape(output, input_shapes, mods))
{ {
return false; return false;
} }
...@@ -59,10 +61,12 @@ static bool try_compute_shape(instruction_ref ins, const std::vector<shape>& inp ...@@ -59,10 +61,12 @@ static bool try_compute_shape(instruction_ref ins, const std::vector<shape>& inp
return true; return true;
} }
static bool try_compute_shape(instruction_ref ins, const std::vector<instruction_ref>& args) static bool try_compute_shape(instruction_ref ins,
const std::vector<instruction_ref>& args,
const std::vector<module_ref>& mods)
{ {
auto inputs = to_shapes(args); auto inputs = to_shapes(args);
return try_compute_shape(ins, inputs); return try_compute_shape(ins, inputs, mods);
} }
void eliminate_contiguous::apply(module& p) const void eliminate_contiguous::apply(module& p) const
...@@ -82,7 +86,7 @@ void eliminate_contiguous::apply(module& p) const ...@@ -82,7 +86,7 @@ void eliminate_contiguous::apply(module& p) const
auto new_args = args; auto new_args = args;
auto prev = arg->inputs().front(); auto prev = arg->inputs().front();
replace(new_args, arg, prev); replace(new_args, arg, prev);
if(try_compute_shape(ins, new_args)) if(try_compute_shape(ins, new_args, ins->module_inputs()))
{ {
instruction::replace_argument(ins, arg, prev); instruction::replace_argument(ins, arg, prev);
} }
......
...@@ -103,7 +103,14 @@ auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name()) ...@@ -103,7 +103,14 @@ auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
} // namespace operation_operators } // namespace operation_operators
template <class T> template <class T>
auto normalize_compute_shape_op(rank<2>, const T& x, const std::vector<shape>& inputs) auto compute_shape_op(rank<3>, const T& x, const std::vector<shape>& inputs)
-> decltype(x.compute_shape(inputs))
{
return x.compute_shape(inputs);
}
template <class T>
auto compute_shape_op(rank<2>, const T& x, const std::vector<shape>& inputs)
-> decltype(x.normalize_compute_shape(inputs)) -> decltype(x.normalize_compute_shape(inputs))
{ {
dependent_type<operation, T> y = x; dependent_type<operation, T> y = x;
...@@ -112,27 +119,27 @@ auto normalize_compute_shape_op(rank<2>, const T& x, const std::vector<shape>& i ...@@ -112,27 +119,27 @@ auto normalize_compute_shape_op(rank<2>, const T& x, const std::vector<shape>& i
} }
template <class T> template <class T>
auto normalize_compute_shape_op(rank<1>, const T& x, const std::vector<shape>& inputs) auto compute_shape_op(rank<1>, const T& x, const std::vector<shape>& inputs)
-> decltype(x.compute_shape(inputs, {})) -> decltype(x.compute_shape(inputs, {}))
{ {
return x.compute_shape(inputs, {}); return x.compute_shape(inputs, {});
} }
template <class T> template <class T>
shape normalize_compute_shape_op(rank<0>, const T& x, const std::vector<shape>&) shape compute_shape_op(rank<0>, const T& x, const std::vector<shape>&)
{ {
std::string name = x.name(); std::string name = x.name();
MIGRAPHX_THROW("Shape not computable: " + name); MIGRAPHX_THROW("Shape not computable: " + name);
} }
template <class T> template <class T>
shape normalize_compute_shape_op(const T& x, const std::vector<shape>& inputs) shape compute_shape_op(const T& x, const std::vector<shape>& inputs)
{ {
return normalize_compute_shape_op(rank<2>{}, x, inputs); return compute_shape_op(rank<3>{}, x, inputs);
} }
template <class T> template <class T>
auto compute_shape_op(rank<1>, auto mod_compute_shape_op(rank<1>,
const T& x, const T& x,
const std::vector<shape>& inputs, const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args) const std::vector<module_ref>& mod_args)
...@@ -142,47 +149,23 @@ auto compute_shape_op(rank<1>, ...@@ -142,47 +149,23 @@ auto compute_shape_op(rank<1>,
} }
template <class T> template <class T>
shape shape mod_compute_shape_op(rank<0>,
compute_shape_op(rank<0>, const T& x, const std::vector<shape>&, const std::vector<module_ref>&)
{
std::string name = x.name();
MIGRAPHX_THROW("Shape not computable: " + name);
}
template <class T>
shape compute_shape_op(const T& x,
const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args)
{
return compute_shape_op(rank<1>{}, x, inputs, mod_args);
}
template <class T>
auto normalize_compute_shape_op(rank<1>,
const T& x, const T& x,
const std::vector<shape>& inputs, const std::vector<shape>& inputs,
std::vector<module_ref>& mod_args) const std::vector<module_ref>& mod_args)
-> decltype(x.normalize_compute_shape(inputs, mod_args))
{
return x.normalize_compute_shape(inputs, mod_args);
}
template <class T>
shape normalize_compute_shape_op(rank<0>,
const T& x,
const std::vector<shape>&,
const std::vector<module_ref>&)
{ {
if(mod_args.empty())
return compute_shape_op(x, inputs);
std::string name = x.name(); std::string name = x.name();
MIGRAPHX_THROW("Shape not computable: " + name); MIGRAPHX_THROW("Shape not computable: " + name);
} }
template <class T> template <class T>
shape normalize_compute_shape_op(const T& x, shape mod_compute_shape_op(const T& x,
const std::vector<shape>& inputs, const std::vector<shape>& inputs,
std::vector<module_ref>& mod_args) const std::vector<module_ref>& mod_args)
{ {
return normalize_compute_shape_op(rank<1>{}, x, inputs, mod_args); return mod_compute_shape_op(rank<1>{}, x, inputs, mod_args);
} }
template <class T> template <class T>
...@@ -855,7 +838,7 @@ struct operation ...@@ -855,7 +838,7 @@ struct operation
T&& private_detail_te_self, T&& private_detail_te_self,
const std::vector<shape>& input) const std::vector<shape>& input)
{ {
return detail::normalize_compute_shape_op(private_detail_te_self, input); return detail::compute_shape_op(private_detail_te_self, input);
} }
template <class T> template <class T>
...@@ -874,7 +857,7 @@ struct operation ...@@ -874,7 +857,7 @@ struct operation
const std::vector<shape>& inputs, const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args) const std::vector<module_ref>& mod_args)
{ {
return detail::compute_shape_op(private_detail_te_self, inputs, mod_args); return detail::mod_compute_shape_op(private_detail_te_self, inputs, mod_args);
} }
template <class T> template <class T>
...@@ -1276,7 +1259,7 @@ template <class T> ...@@ -1276,7 +1259,7 @@ template <class T>
inline auto compute_shape(const T& op, const std::vector<shape>& inputs) inline auto compute_shape(const T& op, const std::vector<shape>& inputs)
-> decltype(op.normalize_compute_shape(inputs)) -> decltype(op.normalize_compute_shape(inputs))
{ {
return detail::normalize_compute_shape_op(op, inputs); return detail::compute_shape_op(op, inputs);
} }
inline shape compute_shape(const operation& op, inline shape compute_shape(const operation& op,
...@@ -1301,7 +1284,7 @@ inline auto compute_shape(const T& op, ...@@ -1301,7 +1284,7 @@ inline auto compute_shape(const T& op,
const std::vector<module_ref>& mod_args) const std::vector<module_ref>& mod_args)
-> decltype(op.normalize_compute_shape(inputs, mod_args)) -> decltype(op.normalize_compute_shape(inputs, mod_args))
{ {
return detail::normalize_compute_shape_op(op, inputs, mod_args); return detail::compute_shape_op(op, inputs, mod_args);
} }
inline bool is_context_free(const operation& op) { return op.is_context_free(); } inline bool is_context_free(const operation& op) { return op.is_context_free(); }
......
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_PASSES);
void validate_pass(module& mod, const pass& p, tracer trace) void validate_pass(module& mod, const pass& p, tracer trace)
{ {
(void)mod; (void)mod;
...@@ -82,6 +84,8 @@ module& get_module(module_pass_manager& mpm) { return mpm.get_module(); } ...@@ -82,6 +84,8 @@ module& get_module(module_pass_manager& mpm) { return mpm.get_module(); }
void run_passes(module& mod, const std::vector<pass>& passes, tracer trace) void run_passes(module& mod, const std::vector<pass>& passes, tracer trace)
{ {
if(enabled(MIGRAPHX_TRACE_PASSES{}))
trace = tracer{std::cout};
for(const auto& p : passes) for(const auto& p : passes)
{ {
module_pm{&mod, nullptr, &trace}.run_pass(p); module_pm{&mod, nullptr, &trace}.run_pass(p);
...@@ -90,6 +94,8 @@ void run_passes(module& mod, const std::vector<pass>& passes, tracer trace) ...@@ -90,6 +94,8 @@ void run_passes(module& mod, const std::vector<pass>& passes, tracer trace)
void run_passes(program& prog, const std::vector<pass>& passes, tracer trace) void run_passes(program& prog, const std::vector<pass>& passes, tracer trace)
{ {
if(enabled(MIGRAPHX_TRACE_PASSES{}))
trace = tracer{std::cout};
for(const auto& p : passes) for(const auto& p : passes)
{ {
auto mods = prog.get_modules(); auto mods = prog.get_modules();
......
#include <migraphx/eliminate_contiguous.hpp> #include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
#include <migraphx/instruction.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <pointwise.hpp>
#include <test.hpp> #include <test.hpp>
void run_pass(migraphx::module& m) void run_pass(migraphx::module& m)
...@@ -159,4 +161,25 @@ TEST_CASE(standard_flatten_op) ...@@ -159,4 +161,25 @@ TEST_CASE(standard_flatten_op)
EXPECT(std::distance(m.begin(), m.end()) == (count - 1)); EXPECT(std::distance(m.begin(), m.end()) == (count - 1));
} }
TEST_CASE(contiguous_pointwise)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3, 8, 8}};
migraphx::program p;
auto* mm = p.get_main_module();
{
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {3}});
auto yb = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 3, 8, 8}}}), y);
auto yc = mm->add_instruction(migraphx::make_op("contiguous"), yb);
auto add = add_pointwise(p, "main:pointwise0", {x, yc}, single_pointwise("add"));
mm->add_instruction(pass_op{}, add);
}
auto count = std::distance(mm->begin(), mm->end());
run_pass(*mm);
EXPECT(std::distance(mm->begin(), mm->end()) == (count - 1));
EXPECT(std::none_of(
mm->begin(), mm->end(), [](auto&& ins) { return ins.name() == "contiguous"; }));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -7,38 +7,13 @@ ...@@ -7,38 +7,13 @@
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <test.hpp> #include <test.hpp>
#include <pointwise.hpp>
void run_pass(migraphx::program& p) void run_pass(migraphx::program& p)
{ {
migraphx::run_passes(p, {migraphx::fuse_pointwise{}, migraphx::dead_code_elimination{}}); migraphx::run_passes(p, {migraphx::fuse_pointwise{}, migraphx::dead_code_elimination{}});
} }
template <class F>
migraphx::instruction_ref add_pointwise(migraphx::program& p,
const std::string& name,
std::vector<migraphx::instruction_ref> inputs,
F f)
{
auto* pm = p.create_module(name);
auto* mm = p.get_main_module();
pm->set_bypass();
std::vector<migraphx::instruction_ref> params;
std::transform(inputs.begin(), inputs.end(), std::back_inserter(params), [&](auto input) {
return pm->add_parameter("x" + std::to_string(params.size()),
migraphx::shape{input->get_shape().type()});
});
auto r = f(pm, params);
pm->add_return({r});
return mm->add_instruction(migraphx::make_op("pointwise"), inputs, {pm});
}
auto single_pointwise(const std::string& name)
{
return [=](auto* pm, const auto& inputs) {
return pm->add_instruction(migraphx::make_op(name), inputs);
};
}
TEST_CASE(single) TEST_CASE(single)
{ {
migraphx::shape s{migraphx::shape::float_type, {2, 3}}; migraphx::shape s{migraphx::shape::float_type, {2, 3}};
......
File mode changed from 100644 to 100755
#ifndef MIGRAPHX_GUARD_TEST_INCLUDE_POINTWISE_HPP
#define MIGRAPHX_GUARD_TEST_INCLUDE_POINTWISE_HPP
#include <migraphx/program.hpp>
#include <migraphx/module.hpp>
#include <migraphx/make_op.hpp>
template <class F>
migraphx::instruction_ref add_pointwise(migraphx::program& p,
const std::string& name,
std::vector<migraphx::instruction_ref> inputs,
F f)
{
auto* pm = p.create_module(name);
auto* mm = p.get_main_module();
pm->set_bypass();
std::vector<migraphx::instruction_ref> params;
std::transform(inputs.begin(), inputs.end(), std::back_inserter(params), [&](auto input) {
return pm->add_parameter("x" + std::to_string(params.size()),
migraphx::shape{input->get_shape().type()});
});
auto r = f(pm, params);
pm->add_return({r});
return mm->add_instruction(migraphx::make_op("pointwise"), inputs, {pm});
}
inline auto single_pointwise(const std::string& name)
{
return [=](auto* pm, const auto& inputs) {
return pm->add_instruction(migraphx::make_op(name), inputs);
};
}
#endif // MIGRAPHX_GUARD_TEST_INCLUDE_POINTWISE_HPP
...@@ -103,7 +103,14 @@ auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name()) ...@@ -103,7 +103,14 @@ auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
} // namespace operation_operators } // namespace operation_operators
template <class T> template <class T>
auto normalize_compute_shape_op(rank<2>, const T& x, const std::vector<shape>& inputs) auto compute_shape_op(rank<3>, const T& x, const std::vector<shape>& inputs)
-> decltype(x.compute_shape(inputs))
{
return x.compute_shape(inputs);
}
template <class T>
auto compute_shape_op(rank<2>, const T& x, const std::vector<shape>& inputs)
-> decltype(x.normalize_compute_shape(inputs)) -> decltype(x.normalize_compute_shape(inputs))
{ {
dependent_type<operation, T> y = x; dependent_type<operation, T> y = x;
...@@ -112,27 +119,27 @@ auto normalize_compute_shape_op(rank<2>, const T& x, const std::vector<shape>& i ...@@ -112,27 +119,27 @@ auto normalize_compute_shape_op(rank<2>, const T& x, const std::vector<shape>& i
} }
template <class T> template <class T>
auto normalize_compute_shape_op(rank<1>, const T& x, const std::vector<shape>& inputs) auto compute_shape_op(rank<1>, const T& x, const std::vector<shape>& inputs)
-> decltype(x.compute_shape(inputs, {})) -> decltype(x.compute_shape(inputs, {}))
{ {
return x.compute_shape(inputs, {}); return x.compute_shape(inputs, {});
} }
template <class T> template <class T>
shape normalize_compute_shape_op(rank<0>, const T& x, const std::vector<shape>&) shape compute_shape_op(rank<0>, const T& x, const std::vector<shape>&)
{ {
std::string name = x.name(); std::string name = x.name();
MIGRAPHX_THROW("Shape not computable: " + name); MIGRAPHX_THROW("Shape not computable: " + name);
} }
template <class T> template <class T>
shape normalize_compute_shape_op(const T& x, const std::vector<shape>& inputs) shape compute_shape_op(const T& x, const std::vector<shape>& inputs)
{ {
return normalize_compute_shape_op(rank<2>{}, x, inputs); return compute_shape_op(rank<3>{}, x, inputs);
} }
template <class T> template <class T>
auto compute_shape_op(rank<1>, auto mod_compute_shape_op(rank<1>,
const T& x, const T& x,
const std::vector<shape>& inputs, const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args) const std::vector<module_ref>& mod_args)
...@@ -142,47 +149,23 @@ auto compute_shape_op(rank<1>, ...@@ -142,47 +149,23 @@ auto compute_shape_op(rank<1>,
} }
template <class T> template <class T>
shape shape mod_compute_shape_op(rank<0>,
compute_shape_op(rank<0>, const T& x, const std::vector<shape>&, const std::vector<module_ref>&)
{
std::string name = x.name();
MIGRAPHX_THROW("Shape not computable: " + name);
}
template <class T>
shape compute_shape_op(const T& x,
const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args)
{
return compute_shape_op(rank<1>{}, x, inputs, mod_args);
}
template <class T>
auto normalize_compute_shape_op(rank<1>,
const T& x, const T& x,
const std::vector<shape>& inputs, const std::vector<shape>& inputs,
std::vector<module_ref>& mod_args) const std::vector<module_ref>& mod_args)
-> decltype(x.normalize_compute_shape(inputs, mod_args))
{
return x.normalize_compute_shape(inputs, mod_args);
}
template <class T>
shape normalize_compute_shape_op(rank<0>,
const T& x,
const std::vector<shape>&,
const std::vector<module_ref>&)
{ {
if(mod_args.empty())
return compute_shape_op(x, inputs);
std::string name = x.name(); std::string name = x.name();
MIGRAPHX_THROW("Shape not computable: " + name); MIGRAPHX_THROW("Shape not computable: " + name);
} }
template <class T> template <class T>
shape normalize_compute_shape_op(const T& x, shape mod_compute_shape_op(const T& x,
const std::vector<shape>& inputs, const std::vector<shape>& inputs,
std::vector<module_ref>& mod_args) const std::vector<module_ref>& mod_args)
{ {
return normalize_compute_shape_op(rank<1>{}, x, inputs, mod_args); return mod_compute_shape_op(rank<1>{}, x, inputs, mod_args);
} }
template <class T> template <class T>
...@@ -495,13 +478,13 @@ lifetime get_lifetime_op(const T&) ...@@ -495,13 +478,13 @@ lifetime get_lifetime_op(const T&)
returns = 'shape', returns = 'shape',
input = 'const std::vector<shape>&', input = 'const std::vector<shape>&',
const = True, const = True,
default = 'detail::normalize_compute_shape_op'), default = 'detail::compute_shape_op'),
virtual('compute_shape', virtual('compute_shape',
returns = 'shape', returns = 'shape',
inputs = 'const std::vector<shape>&', inputs = 'const std::vector<shape>&',
mod_args = 'const std::vector<module_ref>&', mod_args = 'const std::vector<module_ref>&',
const = True, const = True,
default = 'detail::compute_shape_op'), default = 'detail::mod_compute_shape_op'),
virtual('compute', virtual('compute',
returns = 'argument', returns = 'argument',
ctx = 'context&', ctx = 'context&',
...@@ -589,7 +572,7 @@ template <class T> ...@@ -589,7 +572,7 @@ template <class T>
inline auto compute_shape(const T& op, const std::vector<shape>& inputs) inline auto compute_shape(const T& op, const std::vector<shape>& inputs)
-> decltype(op.normalize_compute_shape(inputs)) -> decltype(op.normalize_compute_shape(inputs))
{ {
return detail::normalize_compute_shape_op(op, inputs); return detail::compute_shape_op(op, inputs);
} }
inline shape compute_shape(const operation& op, inline shape compute_shape(const operation& op,
...@@ -614,7 +597,7 @@ inline auto compute_shape(const T& op, ...@@ -614,7 +597,7 @@ inline auto compute_shape(const T& op,
const std::vector<module_ref>& mod_args) const std::vector<module_ref>& mod_args)
-> decltype(op.normalize_compute_shape(inputs, mod_args)) -> decltype(op.normalize_compute_shape(inputs, mod_args))
{ {
return detail::normalize_compute_shape_op(op, inputs, mod_args); return detail::compute_shape_op(op, inputs, mod_args);
} }
inline bool is_context_free(const operation& op) { return op.is_context_free(); } inline bool is_context_free(const operation& op) { return op.is_context_free(); }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment