Commit 02326004 authored by Khalique's avatar Khalique
Browse files

formatting

parent 9f25ffb7
...@@ -717,10 +717,10 @@ struct pad ...@@ -717,10 +717,10 @@ struct pad
bool symmetric() const bool symmetric() const
{ {
std::size_t num_dims = pads.size()/2; std::size_t num_dims = pads.size() / 2;
for(std::size_t i = 0; i < num_dims; i++) for(std::size_t i = 0; i < num_dims; i++)
{ {
if(pads.at(i) != pads.at(i+num_dims)) if(pads.at(i) != pads.at(i + num_dims))
return false; return false;
} }
return true; return true;
......
...@@ -14,7 +14,7 @@ void pad_rewrite::apply(program& p) const ...@@ -14,7 +14,7 @@ void pad_rewrite::apply(program& p) const
{ {
if(ins->name() != "pad") if(ins->name() != "pad")
continue; continue;
for (auto output : ins->outputs()) for(auto output : ins->outputs())
{ {
auto op_name = output->name(); auto op_name = output->name();
if(op_name == "convolution") if(op_name == "convolution")
...@@ -27,25 +27,25 @@ void pad_rewrite::apply(program& p) const ...@@ -27,25 +27,25 @@ void pad_rewrite::apply(program& p) const
} }
} }
template<class T> template <class T>
void pad_rewrite::update_op(T, instruction_ref ins, instruction_ref output, program& p) const void pad_rewrite::update_op(T, instruction_ref ins, instruction_ref output, program& p) const
{ {
auto pad_op = any_cast<op::pad>(ins->get_operator()); auto pad_op = any_cast<op::pad>(ins->get_operator());
if(!pad_op.symmetric()) if(!pad_op.symmetric())
return; return;
std::vector<int64_t> pads = pad_op.pads; std::vector<int64_t> pads = pad_op.pads;
assert(pads.size() == 8); // ensure input being padded has 4 dims (*2 for font and back padding) assert(pads.size() == 8); // ensure input being padded has 4 dims (*2 for font and back padding)
std::array<size_t, 2> new_pads{static_cast<size_t>(pads[2]),static_cast<size_t>(pads[3])}; std::array<size_t, 2> new_pads{static_cast<size_t>(pads[2]), static_cast<size_t>(pads[3])};
T op = any_cast<T>(output->get_operator()); T op = any_cast<T>(output->get_operator());
op.padding = new_pads; op.padding = new_pads;
std::vector<instruction_ref> new_inputs{output->inputs()}; std::vector<instruction_ref> new_inputs{output->inputs()};
new_inputs.front() = ins->inputs().front(); new_inputs.front() = ins->inputs().front();
p.replace_instruction(output, op, new_inputs); p.replace_instruction(output, op, new_inputs);
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment