Commit eb0c5099 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

workaround to get mlperf model, needs general functionality

parent d7b1895a
......@@ -65,7 +65,7 @@ void auto_contiguous::apply(module& m) const
if(ins->outputs().empty() and ins != last)
continue;
shape s = ins->get_shape();
if(not s.dynamic() and not s.standard() and s.elements() != 0)
if((ins->name() == "pooling" or ins->name() == "dot") and not s.dynamic() and not s.standard() and s.elements() != 0)
{
auto c = m.insert_instruction(std::next(ins), make_op("contiguous"), ins);
m.replace_instruction(ins, c);
......
......@@ -21,6 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "migraphx/instruction_ref.hpp"
#include <cstdio>
#include <migraphx/eliminate_layout.hpp>
#include <migraphx/module.hpp>
#include <migraphx/instruction.hpp>
......@@ -32,6 +34,8 @@
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <unordered_set>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -81,38 +85,114 @@ std::unordered_set<instruction_ref> preserve_output_layout(module& m)
return result;
}
void remove_layout(module& m, const std::unordered_set<instruction_ref>& output_layouts)
void remove_layout(module& m)
{
for(auto ins : iterator_for(m))
{
if(ins->name() != "gpu::precompile_op")
if(ins->name() != "layout")
continue;
auto in_shape = ins->inputs().front()->get_shape();
if(in_shape == ins->get_shape())
m.replace_instruction(ins, ins->inputs().front());
}
}
auto precompile_op = ins->get_operator();
auto val = precompile_op.to_value();
// std::vector<instruction_ref> find_convs(const module& m)
// {
// std::vector<instruction_ref> convs;
// for(auto ins : iterator_for(m))
// {
// if(ins->name() == "gpu::miopen_op")
// convs.push_back(ins);
// }
// return convs;
// }
if(val["op"].at("name").to<std::string>() != "layout")
{
// std::cout << val["op"].at("name").to<std::string>() << std::endl;
continue;
}
m.debug_print(ins);
if(ins->get_shape() != ins->inputs().front()->get_shape())
{
std::cout << ins->get_shape() << " " << ins->inputs().front()->get_shape() << std::endl;
continue;
}
if(contains(output_layouts, ins))
continue;
// void remove_layout(module& m, const std::vector<instruction_ref>& convs)
// {
// if(convs.size() < 2) return;
// m.debug_print();
m.replace_instruction(ins, ins->inputs().front());
}
}
// for(auto i = 0; i < convs.size() - 1; i++)
// {
// bool reached_start = false;
// for(auto ins : iterator_for(m))
// {
// if(ins == convs[i])
// reached_start = true;
// if(reached_start)
// {
// if(ins->name() == "gpu::pooling")
// break;
// if(ins == convs[i + 1])
// {
// m.debug_print(convs[i]->outputs().front());
// m.debug_print(convs[i]->outputs().front()->outputs().front());
// m.replace_instruction(convs[i]->outputs().front(), convs[i]->outputs().front()->outputs().front());
// std::cout << "HERE" << std::endl;
// m.debug_print(convs[i]->outputs().front());
// // m.debug_print(convs[i]->outputs().front());
// // m.debug_print(convs[i]->outputs().front()->outputs().front());
// std::cout << std::endl;
// m.debug_print(convs[i]->inputs());
// std::cout << std::endl;
// m.debug_print(convs[i + 1]->inputs());
// for(auto j = 0; j < convs[i + 1]->inputs().size(); j++)
// {
// if(convs[i]->inputs()[j] == convs[i + 1]->inputs()[j])
// {
// std::cout << "HERE2" << std::endl;
// continue;
// }
// m.replace_instruction(convs[i + 1]->inputs()[j], convs[i + 1]->inputs()[j]->inputs().front());
// m.debug_print(convs[i+1]);
// }
// break;
// }
// }
// }
// }
// }
// void remove_layout(module& m, const std::unordered_set<instruction_ref>& output_layouts)
// {
// for(auto ins : iterator_for(m))
// {
// if(ins->name() != "gpu::precompile_op")
// continue;
// auto precompile_op = ins->get_operator();
// auto val = precompile_op.to_value();
// if(val["op"].at("name").to<std::string>() != "layout")
// {
// // std::cout << val["op"].at("name").to<std::string>() << std::endl;
// continue;
// }
// m.debug_print(ins);
// if(ins->get_shape() != ins->inputs().front()->get_shape())
// {
// std::cout << ins->get_shape() << " " << ins->inputs().front()->get_shape() << std::endl;
// continue;
// }
// if(contains(output_layouts, ins))
// continue;
// m.replace_instruction(ins, ins->inputs().front());
// }
// }
void eliminate_layout::apply(module_pass_manager& mpm) const
{
std::unordered_set<instruction_ref> output_layouts = preserve_output_layout(mpm.get_module());
remove_layout(mpm.get_module(), output_layouts);
// std::unordered_set<instruction_ref> output_layouts = preserve_output_layout(mpm.get_module());
// remove_layout(mpm.get_module(), find_convs(mpm.get_module()));
remove_layout(mpm.get_module());
mpm.run_pass(dead_code_elimination{});
}
......
......@@ -81,17 +81,13 @@ void transform_convolutions(module& m, bool skip_elim_contiguous)
auto args = ins->inputs();
if(skip_elim_contiguous)
{
// std::cout << "HERE" << std::endl;
for(auto i = 0; i < args.size(); i++)
{
// std::cout << args[i]->name() << std::endl;
if(args[i]->name() != "layout" and args[i]->get_shape().standard())
{
// std::cout << "HERE2" << std::endl;
args[i] = m.insert_instruction(
ins, make_op("layout", {{"permutation", {0, 2, 3, 1}}}), args[i]);
// m.debug_print(args);
}
}
}
......@@ -101,13 +97,30 @@ void transform_convolutions(module& m, bool skip_elim_contiguous)
ins, make_op("layout", {{"permutation", {0, 2, 3, 1}}}), i);
});
auto conv = m.insert_instruction(ins, ins->get_operator(), args);
auto c = conv;
if(not skip_elim_contiguous)
c = m.insert_instruction(ins, make_op("contiguous"), conv);
m.replace_instruction(ins, c);
// m.debug_print(conv);
// auto c = conv;
// auto nchw = m.insert_instruction(ins, make_op("layout", {{"permutation", {0, 1, 2, 3}}}), conv);
// m.debug_print();
// if(not skip_elim_contiguous)
// c = m.insert_instruction(ins, make_op("contiguous"), conv);
m.replace_instruction(ins, conv);
}
}
void insert_contiguous(module& m)
{
for(auto ins : iterator_for(m))
{
if(ins->name() != "reshape" and ins->name() != "pooling")
continue;
auto c = m.insert_instruction(ins, make_op("contiguous"), ins->inputs().front());
auto reshape = m.insert_instruction(ins, ins->get_operator(), c);
m.replace_instruction(ins, reshape);
}
std::cout << "after" << std::endl;
// m.debug_print();
}
// void remove_layout(module& m, const std::unordered_set<instruction_ref>& output_layouts)
// {
// for(auto ins : iterator_for(m))
......@@ -126,6 +139,9 @@ void layout_nhwc::apply(module_pass_manager& mpm) const
{
// std::unordered_set<instruction_ref> output_layouts =
// preserve_output_layout(mpm.get_module());
insert_contiguous(mpm.get_module());
mpm.run_pass(dead_code_elimination{});
mpm.get_module().debug_print();
transform_convolutions(mpm.get_module(), this->skip_elim_contiguous);
mpm.run_pass(dead_code_elimination{});
if(not this->skip_elim_contiguous)
......
......@@ -188,8 +188,8 @@ MIGRAPHX_DEVICE_MATH_BINARY_FOR(float, max, ::max)
MIGRAPHX_DEVICE_MATH_BINARY_FOR(float, min, ::min)
MIGRAPHX_DEVICE_MATH_BINARY_FOR(double, max, ::max)
MIGRAPHX_DEVICE_MATH_BINARY_FOR(double, min, ::min)
MIGRAPHX_DEVICE_MATH_BINARY_FOR(migraphx::half, max, ::__hmax)
MIGRAPHX_DEVICE_MATH_BINARY_FOR(migraphx::half, min, ::__hmin)
MIGRAPHX_DEVICE_MATH_BINARY_FOR(migraphx::half, max, ::fmaxf)
MIGRAPHX_DEVICE_MATH_BINARY_FOR(migraphx::half, min, ::fminf)
template <class T, MIGRAPHX_REQUIRES(not is_any_vec<T>())>
constexpr auto max(const T& a, const T& b)
......
......@@ -129,6 +129,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
optimize_module{},
enable_pass(enabled(MIGRAPHX_ENABLE_NHWC{}), layout_nhwc{}),
dead_code_elimination{},
enable_pass(enabled(MIGRAPHX_ENABLE_NHWC{}), eliminate_layout{}),
prefuse_ops{},
dead_code_elimination{},
auto_contiguous{},
......@@ -144,8 +145,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
lowering{&ctx, options.offload_copy},
eliminate_contiguous{"gpu::contiguous"},
dead_code_elimination{},
enable_pass(enabled(MIGRAPHX_ENABLE_NHWC{}), eliminate_layout{}),
dead_code_elimination{},
// enable_pass(enabled(MIGRAPHX_ENABLE_NHWC{}), eliminate_layout{}),
// dead_code_elimination{},
eliminate_concat{concat_gpu_optimization{}},
dead_code_elimination{},
compile_miopen{&gctx},
......
......@@ -42,3 +42,20 @@ struct test_conv_relu : verify_program<test_conv_relu>
return p;
}
};
struct test_conv_relu2 : verify_program<test_conv_relu2>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
auto input =
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 4, 3, 3}});
auto weights =
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 4, 1, 1}});
auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights);
auto relu = mm->add_instruction(migraphx::make_op("relu"), conv);
mm->add_instruction(migraphx::make_op("convolution"), relu, weights);
return p;
}
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment