Commit eb0c5099 authored by Khalique Ahmed

Workaround to get the MLPerf model running; the general-purpose functionality still needs to be implemented.

parent d7b1895a
...
@@ -65,7 +65,7 @@ void auto_contiguous::apply(module& m) const
         if(ins->outputs().empty() and ins != last)
             continue;
         shape s = ins->get_shape();
-        if(not s.dynamic() and not s.standard() and s.elements() != 0)
+        if((ins->name() == "pooling" or ins->name() == "dot") and not s.dynamic() and not s.standard() and s.elements() != 0)
         {
             auto c = m.insert_instruction(std::next(ins), make_op("contiguous"), ins);
             m.replace_instruction(ins, c);
...
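
Note on the hunk above: auto_contiguous previously appended a contiguous to every instruction whose output shape was non-standard, non-dynamic, and non-empty; the workaround narrows that to pooling and dot outputs only. A minimal sketch of the new behavior, assuming the same program-building API used in the verify test at the end of this commit (the transpose/relu program is illustrative, not part of the commit):

#include <migraphx/program.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/make_op.hpp>

migraphx::program transpose_example()
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    auto x = mm->add_parameter(
        "x", migraphx::shape{migraphx::shape::float_type, {1, 4, 3, 3}});
    // The transpose yields a non-standard (strided) shape. Before this
    // change auto_contiguous would insert a contiguous after it; now the
    // pass skips it because "transpose" is neither "pooling" nor "dot".
    auto t = mm->add_instruction(
        migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x);
    mm->add_instruction(migraphx::make_op("relu"), t);
    return p;
}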
...
@@ -21,6 +21,8 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  * THE SOFTWARE.
  */
+#include "migraphx/instruction_ref.hpp"
+#include <cstdio>
 #include <migraphx/eliminate_layout.hpp>
 #include <migraphx/module.hpp>
 #include <migraphx/instruction.hpp>
...
@@ -32,6 +34,8 @@
 #include <migraphx/eliminate_contiguous.hpp>
 #include <migraphx/dead_code_elimination.hpp>
 #include <migraphx/pass_manager.hpp>
+#include <unordered_set>
+#include <vector>
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
...
@@ -81,38 +85,114 @@ std::unordered_set<instruction_ref> preserve_output_layout(module& m)
     return result;
 }
 
-void remove_layout(module& m, const std::unordered_set<instruction_ref>& output_layouts)
-{
-    for(auto ins : iterator_for(m))
-    {
-        if(ins->name() != "gpu::precompile_op")
-            continue;
-
-        auto precompile_op = ins->get_operator();
-        auto val           = precompile_op.to_value();
-
-        if(val["op"].at("name").to<std::string>() != "layout")
-        {
-            // std::cout << val["op"].at("name").to<std::string>() << std::endl;
-            continue;
-        }
-        m.debug_print(ins);
-        if(ins->get_shape() != ins->inputs().front()->get_shape())
-        {
-            std::cout << ins->get_shape() << " " << ins->inputs().front()->get_shape() << std::endl;
-            continue;
-        }
-        if(contains(output_layouts, ins))
-            continue;
-        m.replace_instruction(ins, ins->inputs().front());
-    }
-}
+void remove_layout(module& m)
+{
+    for(auto ins : iterator_for(m))
+    {
+        if(ins->name() != "layout")
+            continue;
+        auto in_shape = ins->inputs().front()->get_shape();
+        if(in_shape == ins->get_shape())
+            m.replace_instruction(ins, ins->inputs().front());
+    }
+}
+
+// std::vector<instruction_ref> find_convs(const module& m)
+// {
+//     std::vector<instruction_ref> convs;
+//     for(auto ins : iterator_for(m))
+//     {
+//         if(ins->name() == "gpu::miopen_op")
+//             convs.push_back(ins);
+//     }
+//     return convs;
+// }
+
+// void remove_layout(module& m, const std::vector<instruction_ref>& convs)
+// {
+//     if(convs.size() < 2) return;
+//     m.debug_print();
+//     for(auto i = 0; i < convs.size() - 1; i++)
+//     {
+//         bool reached_start = false;
+//         for(auto ins : iterator_for(m))
+//         {
+//             if(ins == convs[i])
+//                 reached_start = true;
+//             if(reached_start)
+//             {
+//                 if(ins->name() == "gpu::pooling")
+//                     break;
+//                 if(ins == convs[i + 1])
+//                 {
+//                     m.debug_print(convs[i]->outputs().front());
+//                     m.debug_print(convs[i]->outputs().front()->outputs().front());
+//                     m.replace_instruction(convs[i]->outputs().front(),
+//                                           convs[i]->outputs().front()->outputs().front());
+//                     std::cout << "HERE" << std::endl;
+//                     m.debug_print(convs[i]->outputs().front());
+//                     // m.debug_print(convs[i]->outputs().front());
+//                     // m.debug_print(convs[i]->outputs().front()->outputs().front());
+//                     std::cout << std::endl;
+//                     m.debug_print(convs[i]->inputs());
+//                     std::cout << std::endl;
+//                     m.debug_print(convs[i + 1]->inputs());
+//                     for(auto j = 0; j < convs[i + 1]->inputs().size(); j++)
+//                     {
+//                         if(convs[i]->inputs()[j] == convs[i + 1]->inputs()[j])
+//                         {
+//                             std::cout << "HERE2" << std::endl;
+//                             continue;
+//                         }
+//                         m.replace_instruction(convs[i + 1]->inputs()[j],
+//                                               convs[i + 1]->inputs()[j]->inputs().front());
+//                         m.debug_print(convs[i + 1]);
+//                     }
+//                     break;
+//                 }
+//             }
+//         }
+//     }
+// }
+
+// void remove_layout(module& m, const std::unordered_set<instruction_ref>& output_layouts)
+// {
+//     for(auto ins : iterator_for(m))
+//     {
+//         if(ins->name() != "gpu::precompile_op")
+//             continue;
+//         auto precompile_op = ins->get_operator();
+//         auto val           = precompile_op.to_value();
+//         if(val["op"].at("name").to<std::string>() != "layout")
+//         {
+//             // std::cout << val["op"].at("name").to<std::string>() << std::endl;
+//             continue;
+//         }
+//         m.debug_print(ins);
+//         if(ins->get_shape() != ins->inputs().front()->get_shape())
+//         {
+//             std::cout << ins->get_shape() << " " << ins->inputs().front()->get_shape() << std::endl;
+//             continue;
+//         }
+//         if(contains(output_layouts, ins))
+//             continue;
+//         m.replace_instruction(ins, ins->inputs().front());
+//     }
+// }
 
 void eliminate_layout::apply(module_pass_manager& mpm) const
 {
-    std::unordered_set<instruction_ref> output_layouts = preserve_output_layout(mpm.get_module());
-    remove_layout(mpm.get_module(), output_layouts);
+    // std::unordered_set<instruction_ref> output_layouts = preserve_output_layout(mpm.get_module());
+    // remove_layout(mpm.get_module(), find_convs(mpm.get_module()));
+    remove_layout(mpm.get_module());
     mpm.run_pass(dead_code_elimination{});
 }
...
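
The rewritten remove_layout above is deliberately narrow: a "layout" instruction is replaced by its input only when the two shapes compare equal, and shape equality in MIGraphX covers strides as well as lengths, so only relayouts that move no data are dropped. A hedged illustration of that equality test (shapes chosen for this example, not taken from the commit):

#include <migraphx/shape.hpp>

// An NHWC-strided view of {1,4,3,3} does not equal the standard NCHW shape
// with the same lengths, so a layout op between them is kept; a layout op
// whose input already carries the target strides is removed as a no-op.
bool layout_is_noop_example()
{
    migraphx::shape nchw{migraphx::shape::float_type, {1, 4, 3, 3}}; // strides {36, 9, 3, 1}
    migraphx::shape nhwc{
        migraphx::shape::float_type, {1, 4, 3, 3}, {36, 1, 12, 4}};  // NHWC strides
    return nchw == nhwc; // false: this layout moves data
}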
...
@@ -81,17 +81,13 @@ void transform_convolutions(module& m, bool skip_elim_contiguous)
         auto args = ins->inputs();
         if(skip_elim_contiguous)
         {
-            // std::cout << "HERE" << std::endl;
             for(auto i = 0; i < args.size(); i++)
             {
-                // std::cout << args[i]->name() << std::endl;
                 if(args[i]->name() != "layout" and args[i]->get_shape().standard())
                 {
-                    // std::cout << "HERE2" << std::endl;
                     args[i] = m.insert_instruction(
                         ins, make_op("layout", {{"permutation", {0, 2, 3, 1}}}), args[i]);
-                    // m.debug_print(args);
                 }
             }
         }
...
@@ -101,13 +97,30 @@ void transform_convolutions(module& m, bool skip_elim_contiguous)
             ins, make_op("layout", {{"permutation", {0, 2, 3, 1}}}), i);
         });
         auto conv = m.insert_instruction(ins, ins->get_operator(), args);
-        auto c    = conv;
-        if(not skip_elim_contiguous)
-            c = m.insert_instruction(ins, make_op("contiguous"), conv);
-        m.replace_instruction(ins, c);
+        // m.debug_print(conv);
+        // auto c = conv;
+        // auto nchw = m.insert_instruction(ins, make_op("layout", {{"permutation", {0, 1, 2, 3}}}), conv);
+        // m.debug_print();
+        // if(not skip_elim_contiguous)
+        //     c = m.insert_instruction(ins, make_op("contiguous"), conv);
+        m.replace_instruction(ins, conv);
     }
 }
+
+void insert_contiguous(module& m)
+{
+    for(auto ins : iterator_for(m))
+    {
+        if(ins->name() != "reshape" and ins->name() != "pooling")
+            continue;
+        auto c       = m.insert_instruction(ins, make_op("contiguous"), ins->inputs().front());
+        auto reshape = m.insert_instruction(ins, ins->get_operator(), c);
+        m.replace_instruction(ins, reshape);
+    }
+    std::cout << "after" << std::endl;
+    // m.debug_print();
+}
 
 // void remove_layout(module& m, const std::unordered_set<instruction_ref>& output_layouts)
 // {
 //     for(auto ins : iterator_for(m))
...
@@ -126,6 +139,9 @@ void layout_nhwc::apply(module_pass_manager& mpm) const
 {
     // std::unordered_set<instruction_ref> output_layouts =
     //     preserve_output_layout(mpm.get_module());
+    insert_contiguous(mpm.get_module());
+    mpm.run_pass(dead_code_elimination{});
+    mpm.get_module().debug_print();
     transform_convolutions(mpm.get_module(), this->skip_elim_contiguous);
     mpm.run_pass(dead_code_elimination{});
     if(not this->skip_elim_contiguous)
...
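
The new insert_contiguous above rewrites every reshape and pooling to read from an explicit contiguous copy, turning x -> reshape into x -> contiguous -> reshape, and layout_nhwc::apply now runs it (plus dead-code elimination) before transforming convolutions. A hedged sketch of driving the pass directly, with header paths assumed:

#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/layout_nhwc.hpp>
#include <migraphx/dead_code_elimination.hpp>

// Run the NHWC pass (which internally calls insert_contiguous, DCE, and
// transform_convolutions) followed by a final DCE, as the GPU pipeline does.
void run_nhwc(migraphx::module& m)
{
    migraphx::run_passes(
        m, {migraphx::layout_nhwc{}, migraphx::dead_code_elimination{}});
}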
...
@@ -188,8 +188,8 @@ MIGRAPHX_DEVICE_MATH_BINARY_FOR(float, max, ::max)
 MIGRAPHX_DEVICE_MATH_BINARY_FOR(float, min, ::min)
 MIGRAPHX_DEVICE_MATH_BINARY_FOR(double, max, ::max)
 MIGRAPHX_DEVICE_MATH_BINARY_FOR(double, min, ::min)
-MIGRAPHX_DEVICE_MATH_BINARY_FOR(migraphx::half, max, ::__hmax)
-MIGRAPHX_DEVICE_MATH_BINARY_FOR(migraphx::half, min, ::__hmin)
+MIGRAPHX_DEVICE_MATH_BINARY_FOR(migraphx::half, max, ::fmaxf)
+MIGRAPHX_DEVICE_MATH_BINARY_FOR(migraphx::half, min, ::fminf)
 
 template <class T, MIGRAPHX_REQUIRES(not is_any_vec<T>())>
 constexpr auto max(const T& a, const T& b)
...
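
Replacing ::__hmax/::__hmin with ::fmaxf/::fminf swaps the native half-precision intrinsics for the float builtins, so each migraphx::half operand is promoted to float, compared, and the result narrowed back to half. A rough sketch of the effective behavior for max (illustrative only, not the actual macro expansion):

// Approximately what the macro now generates for migraphx::half:
__device__ migraphx::half max(migraphx::half a, migraphx::half b)
{
    return migraphx::half(::fmaxf(float(a), float(b))); // promote, compare, narrow
}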
...
@@ -129,6 +129,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
         optimize_module{},
         enable_pass(enabled(MIGRAPHX_ENABLE_NHWC{}), layout_nhwc{}),
         dead_code_elimination{},
+        enable_pass(enabled(MIGRAPHX_ENABLE_NHWC{}), eliminate_layout{}),
         prefuse_ops{},
         dead_code_elimination{},
         auto_contiguous{},
@@ -144,8 +145,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
         lowering{&ctx, options.offload_copy},
         eliminate_contiguous{"gpu::contiguous"},
         dead_code_elimination{},
-        enable_pass(enabled(MIGRAPHX_ENABLE_NHWC{}), eliminate_layout{}),
-        dead_code_elimination{},
+        // enable_pass(enabled(MIGRAPHX_ENABLE_NHWC{}), eliminate_layout{}),
+        // dead_code_elimination{},
         eliminate_concat{concat_gpu_optimization{}},
         dead_code_elimination{},
         compile_miopen{&gctx},
...
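
With the reordering above, eliminate_layout now runs immediately after layout_nhwc (before lowering) instead of after eliminate_contiguous; both remain gated on the same flag. A hedged sketch of exercising the NHWC path programmatically (env-var handling and header path assumed):

#include <cstdlib>
#include <migraphx/program.hpp>
#include <migraphx/gpu/target.hpp>

// MIGRAPHX_ENABLE_NHWC is read via enabled(MIGRAPHX_ENABLE_NHWC{}) in the
// pass list above, so setting it before compile() enables layout_nhwc and
// the early eliminate_layout.
void compile_nhwc(migraphx::program& p)
{
    setenv("MIGRAPHX_ENABLE_NHWC", "1", 1);
    p.compile(migraphx::gpu::target{});
}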
...
@@ -42,3 +42,20 @@ struct test_conv_relu : verify_program<test_conv_relu>
         return p;
     }
 };
+
+struct test_conv_relu2 : verify_program<test_conv_relu2>
+{
+    migraphx::program create_program() const
+    {
+        migraphx::program p;
+        auto* mm = p.get_main_module();
+        auto input =
+            mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 4, 3, 3}});
+        auto weights =
+            mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 4, 1, 1}});
+        auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights);
+        auto relu = mm->add_instruction(migraphx::make_op("relu"), conv);
+        mm->add_instruction(migraphx::make_op("convolution"), relu, weights);
+        return p;
+    }
+};
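
The new verify test chains convolution -> relu -> convolution with shared {4, 4, 1, 1} weights. A 1x1 convolution with stride 1 and no padding preserves the spatial size, since out = (W - K + 2P)/S + 1 = (3 - 1 + 0)/1 + 1 = 3, so every intermediate keeps shape {1, 4, 3, 3} and the second convolution consumes the output of the first directly, which is presumably the back-to-back-convolution pattern this commit's layout workaround targets.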