Commit 75bb9a6f authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

formatting

parent eb0c5099
......@@ -65,7 +65,8 @@ void auto_contiguous::apply(module& m) const
if(ins->outputs().empty() and ins != last)
continue;
shape s = ins->get_shape();
if((ins->name() == "pooling" or ins->name() == "dot") and not s.dynamic() and not s.standard() and s.elements() != 0)
if((ins->name() == "pooling" or ins->name() == "dot") and not s.dynamic() and
not s.standard() and s.elements() != 0)
{
auto c = m.insert_instruction(std::next(ins), make_op("contiguous"), ins);
m.replace_instruction(ins, c);
......
......@@ -129,9 +129,9 @@ void remove_layout(module& m)
// m.debug_print(convs[i]->outputs().front());
// m.debug_print(convs[i]->outputs().front()->outputs().front());
// m.replace_instruction(convs[i]->outputs().front(), convs[i]->outputs().front()->outputs().front());
// std::cout << "HERE" << std::endl;
// m.debug_print(convs[i]->outputs().front());
// m.replace_instruction(convs[i]->outputs().front(),
// convs[i]->outputs().front()->outputs().front()); std::cout << "HERE" <<
// std::endl; m.debug_print(convs[i]->outputs().front());
// // m.debug_print(convs[i]->outputs().front());
// // m.debug_print(convs[i]->outputs().front()->outputs().front());
......@@ -148,18 +148,17 @@ void remove_layout(module& m)
// std::cout << "HERE2" << std::endl;
// continue;
// }
// m.replace_instruction(convs[i + 1]->inputs()[j], convs[i + 1]->inputs()[j]->inputs().front());
// m.debug_print(convs[i+1]);
// m.replace_instruction(convs[i + 1]->inputs()[j], convs[i +
// 1]->inputs()[j]->inputs().front()); m.debug_print(convs[i+1]);
// }
// break;
// }
// }
// }
// }
// }
// void remove_layout(module& m, const std::unordered_set<instruction_ref>& output_layouts)
// {
// for(auto ins : iterator_for(m))
......@@ -178,8 +177,8 @@ void remove_layout(module& m)
// m.debug_print(ins);
// if(ins->get_shape() != ins->inputs().front()->get_shape())
// {
// std::cout << ins->get_shape() << " " << ins->inputs().front()->get_shape() << std::endl;
// continue;
// std::cout << ins->get_shape() << " " << ins->inputs().front()->get_shape() <<
// std::endl; continue;
// }
// if(contains(output_layouts, ins))
// continue;
......@@ -190,8 +189,9 @@ void remove_layout(module& m)
void eliminate_layout::apply(module_pass_manager& mpm) const
{
// std::unordered_set<instruction_ref> output_layouts = preserve_output_layout(mpm.get_module());
// remove_layout(mpm.get_module(), find_convs(mpm.get_module()));
// std::unordered_set<instruction_ref> output_layouts =
// preserve_output_layout(mpm.get_module()); remove_layout(mpm.get_module(),
// find_convs(mpm.get_module()));
remove_layout(mpm.get_module());
mpm.run_pass(dead_code_elimination{});
}
......
......@@ -99,9 +99,8 @@ void transform_convolutions(module& m, bool skip_elim_contiguous)
auto conv = m.insert_instruction(ins, ins->get_operator(), args);
// m.debug_print(conv);
// auto c = conv;
// auto nchw = m.insert_instruction(ins, make_op("layout", {{"permutation", {0, 1, 2, 3}}}), conv);
// m.debug_print();
// if(not skip_elim_contiguous)
// auto nchw = m.insert_instruction(ins, make_op("layout", {{"permutation", {0, 1, 2, 3}}}),
// conv); m.debug_print(); if(not skip_elim_contiguous)
// c = m.insert_instruction(ins, make_op("contiguous"), conv);
m.replace_instruction(ins, conv);
}
......@@ -113,7 +112,7 @@ void insert_contiguous(module& m)
{
if(ins->name() != "reshape" and ins->name() != "pooling")
continue;
auto c = m.insert_instruction(ins, make_op("contiguous"), ins->inputs().front());
auto c = m.insert_instruction(ins, make_op("contiguous"), ins->inputs().front());
auto reshape = m.insert_instruction(ins, ins->get_operator(), c);
m.replace_instruction(ins, reshape);
}
......
......@@ -188,8 +188,8 @@ MIGRAPHX_DEVICE_MATH_BINARY_FOR(float, max, ::max)
MIGRAPHX_DEVICE_MATH_BINARY_FOR(float, min, ::min)
MIGRAPHX_DEVICE_MATH_BINARY_FOR(double, max, ::max)
MIGRAPHX_DEVICE_MATH_BINARY_FOR(double, min, ::min)
MIGRAPHX_DEVICE_MATH_BINARY_FOR(migraphx::half, max, ::fmaxf)
MIGRAPHX_DEVICE_MATH_BINARY_FOR(migraphx::half, min, ::fminf)
MIGRAPHX_DEVICE_MATH_BINARY_FOR(migraphx::half, max, ::__hmax)
MIGRAPHX_DEVICE_MATH_BINARY_FOR(migraphx::half, min, ::__hmin)
template <class T, MIGRAPHX_REQUIRES(not is_any_vec<T>())>
constexpr auto max(const T& a, const T& b)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment