Commit 02326004 authored by Khalique's avatar Khalique
Browse files

formatting

parent 9f25ffb7
......@@ -717,10 +717,10 @@ struct pad
bool symmetric() const
{
std::size_t num_dims = pads.size()/2;
std::size_t num_dims = pads.size() / 2;
for(std::size_t i = 0; i < num_dims; i++)
{
if(pads.at(i) != pads.at(i+num_dims))
if(pads.at(i) != pads.at(i + num_dims))
return false;
}
return true;
......
......@@ -14,7 +14,7 @@ void pad_rewrite::apply(program& p) const
{
if(ins->name() != "pad")
continue;
for (auto output : ins->outputs())
for(auto output : ins->outputs())
{
auto op_name = output->name();
if(op_name == "convolution")
......@@ -27,25 +27,25 @@ void pad_rewrite::apply(program& p) const
}
}
template<class T>
template <class T>
void pad_rewrite::update_op(T, instruction_ref ins, instruction_ref output, program& p) const
{
auto pad_op = any_cast<op::pad>(ins->get_operator());
if(!pad_op.symmetric())
return;
std::vector<int64_t> pads = pad_op.pads;
assert(pads.size() == 8); // ensure input being padded has 4 dims (*2 for font and back padding)
std::array<size_t, 2> new_pads{static_cast<size_t>(pads[2]),static_cast<size_t>(pads[3])};
std::array<size_t, 2> new_pads{static_cast<size_t>(pads[2]), static_cast<size_t>(pads[3])};
T op = any_cast<T>(output->get_operator());
T op = any_cast<T>(output->get_operator());
op.padding = new_pads;
std::vector<instruction_ref> new_inputs{output->inputs()};
new_inputs.front() = ins->inputs().front();
p.replace_instruction(output, op, new_inputs);
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment