Commit ec794967 authored by Khalique's avatar Khalique
Browse files

set pad attributes instead of padding_mode

parent db70de8e
...@@ -270,29 +270,6 @@ struct tf_parser ...@@ -270,29 +270,6 @@ struct tf_parser
parse_conv(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) parse_conv(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{ {
op::convolution op; op::convolution op;
if(contains(attributes, "padding"))
{
const std::string& pad_mode = attributes.at("padding").s();
if(pad_mode.find("SAME") != std::string::npos)
{
op.padding_mode = op::padding_mode_t::same;
}
else if(pad_mode.find("EXPLICIT") != std::string::npos)
{
std::vector<size_t> padding;
copy(attributes.at("explicit_paddings").list().i(), std::back_inserter(padding));
if(padding.size() != 4)
{
MIGRAPHX_THROW("padding should have 4 values");
}
if(padding[0] != padding[2] || padding[1] != padding[3])
{
MIGRAPHX_THROW("migraphx does not support asymetric padding");
}
op.padding[0] = padding[0];
op.padding[1] = padding[1];
}
}
if(contains(attributes, "strides")) if(contains(attributes, "strides"))
{ {
std::vector<size_t> stride; std::vector<size_t> stride;
...@@ -332,6 +309,39 @@ struct tf_parser ...@@ -332,6 +309,39 @@ struct tf_parser
} }
} }
if(contains(attributes, "padding"))
{
const std::string& pad_mode = attributes.at("padding").s();
std::vector<size_t> weight_dims = weights->get_shape().lens();
size_t weight_h = weight_dims[2];
size_t weight_w = weight_dims[3];
if(pad_mode.find("SAME") != std::string::npos)
{
op.padding[0] = static_cast<size_t>(std::ceil(static_cast<double>(-op.stride[0] + op.dilation[0]*(weight_h-1)+1))/2);
op.padding[1] = static_cast<size_t>(std::ceil(static_cast<double>(-op.stride[1] + op.dilation[1]*(weight_w-1)+1))/2);
}
else if(pad_mode.find("VALID") != std::string::npos)
{
op.padding[0] = static_cast<size_t>(std::ceil(static_cast<double>(-weight_h - op.stride[0] + op.dilation[0]*(weight_h-1)+1))/2);
op.padding[1] = static_cast<size_t>(std::ceil(static_cast<double>(-weight_w - op.stride[1] + op.dilation[1]*(weight_w-1)+1))/2);
}
else if(pad_mode.find("EXPLICIT") != std::string::npos)
{
std::vector<size_t> padding;
copy(attributes.at("explicit_paddings").list().i(), std::back_inserter(padding));
if(padding.size() != 4)
{
MIGRAPHX_THROW("padding should have 4 values");
}
if(padding[0] != padding[2] || padding[1] != padding[3])
{
MIGRAPHX_THROW("migraphx does not support asymetric padding");
}
op.padding[0] = padding[0];
op.padding[1] = padding[1];
}
}
return prog.add_instruction(op, {args[0], weights}); return prog.add_instruction(op, {args[0], weights});
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment