Commit eb032acb authored by Khalique's avatar Khalique
Browse files

separate pad calc to function

parent e41cbd34
#ifndef MIGRAPHX_GUARD_OPERATORS_PAD_CALC_HPP
#define MIGRAPHX_GUARD_OPERATORS_PAD_CALC_HPP
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
size_t calculate_padding(size_t weight_dim, size_t dilation)
{
return (dilation * (weight_dim - 1)) / 2;
}
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/tf.hpp> #include <migraphx/tf.hpp>
#include <migraphx/op/pad_calc.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -326,13 +327,9 @@ struct tf_parser ...@@ -326,13 +327,9 @@ struct tf_parser
{ {
op.padding_mode = op::padding_mode_t::same; op.padding_mode = op::padding_mode_t::same;
op.padding[0] = op.padding[0] =
static_cast<size_t>(std::ceil(static_cast<double>( op::calculate_padding(weight_h, op.dilation[0]);
-op.stride[0] + op.dilation[0] * (weight_h - 1) + 1)) /
2);
op.padding[1] = op.padding[1] =
static_cast<size_t>(std::ceil(static_cast<double>( op::calculate_padding(weight_w, op.dilation[1]);
-op.stride[1] + op.dilation[1] * (weight_w - 1) + 1)) /
2);
} }
else if(pad_mode.find("VALID") != std::string::npos) else if(pad_mode.find("VALID") != std::string::npos)
{ {
...@@ -365,14 +362,7 @@ struct tf_parser ...@@ -365,14 +362,7 @@ struct tf_parser
op::convolution op; op::convolution op;
size_t num_channels = args[0]->get_shape().lens()[1]; size_t num_channels = args[0]->get_shape().lens()[1];
op.group = num_channels; op.group = num_channels;
if(contains(attributes, "padding"))
{
const std::string& pad_mode = attributes.at("padding").s();
if(pad_mode.find("SAME") != std::string::npos)
{
op.padding_mode = op::padding_mode_t::same;
}
}
if(contains(attributes, "strides")) if(contains(attributes, "strides"))
{ {
std::vector<size_t> stride; std::vector<size_t> stride;
...@@ -385,6 +375,19 @@ struct tf_parser ...@@ -385,6 +375,19 @@ struct tf_parser
op.stride[0] = stride[2]; op.stride[0] = stride[2];
op.stride[1] = stride[3]; op.stride[1] = stride[3];
} }
if(contains(attributes, "dilations"))
{
std::vector<size_t> dilation;
copy(attributes.at("dilations").list().i(), std::back_inserter(dilation));
reorder_data(dilation);
if(dilation.size() != 4)
{
MIGRAPHX_THROW("dilation should have 4 values");
}
op.dilation[0] = dilation[2];
op.dilation[1] = dilation[3];
}
auto weights = args[1]; auto weights = args[1];
// check if weights are from a constant // check if weights are from a constant
if(weights->name() != "@param") if(weights->name() != "@param")
...@@ -399,6 +402,26 @@ struct tf_parser ...@@ -399,6 +402,26 @@ struct tf_parser
} }
} }
if(contains(attributes, "padding"))
{
const std::string& pad_mode = attributes.at("padding").s();
std::vector<size_t> weight_dims = weights->get_shape().lens();
size_t weight_h = weight_dims[2];
size_t weight_w = weight_dims[3];
if(pad_mode.find("SAME") != std::string::npos)
{
op.padding_mode = op::padding_mode_t::same;
op.padding[0] =
op::calculate_padding(weight_h, op.dilation[0]);
op.padding[1] =
op::calculate_padding(weight_w, op.dilation[1]);
}
else if(pad_mode.find("VALID") != std::string::npos)
{
op.padding_mode = op::padding_mode_t::valid;
}
}
std::vector<int64_t> new_weights_shape; std::vector<int64_t> new_weights_shape;
copy(weights->get_shape().lens(), std::back_inserter(new_weights_shape)); copy(weights->get_shape().lens(), std::back_inserter(new_weights_shape));
...@@ -524,18 +547,6 @@ struct tf_parser ...@@ -524,18 +547,6 @@ struct tf_parser
{ {
op::pooling op{starts_with(name, "Max") ? "max" : "average"}; op::pooling op{starts_with(name, "Max") ? "max" : "average"};
if(contains(attributes, "padding"))
{
const std::string& pad_mode = attributes.at("padding").s();
if(pad_mode.find("SAME") != std::string::npos)
{
op.padding_mode = op::padding_mode_t::same;
}
else if(pad_mode.find("VALID") != std::string::npos)
{
op.padding_mode = op::padding_mode_t::valid;
}
}
if(contains(attributes, "strides")) if(contains(attributes, "strides"))
{ {
std::vector<size_t> stride; std::vector<size_t> stride;
...@@ -560,6 +571,22 @@ struct tf_parser ...@@ -560,6 +571,22 @@ struct tf_parser
op.lengths[0] = ksize[2]; op.lengths[0] = ksize[2];
op.lengths[1] = ksize[3]; op.lengths[1] = ksize[3];
} }
if(contains(attributes, "padding"))
{
const std::string& pad_mode = attributes.at("padding").s();
if(pad_mode.find("SAME") != std::string::npos)
{
op.padding_mode = op::padding_mode_t::same;
op.padding[0] =
op::calculate_padding(op.lengths[0], 0);
op.padding[1] =
op::calculate_padding(op.lengths[1], 0);
}
else if(pad_mode.find("VALID") != std::string::npos)
{
op.padding_mode = op::padding_mode_t::valid;
}
}
return prog.add_instruction(op, args[0]); return prog.add_instruction(op, args[0]);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment