Commit 805b6f09 authored by Khalique's avatar Khalique
Browse files

formatting

parent ec794967
...@@ -269,7 +269,7 @@ struct tf_parser ...@@ -269,7 +269,7 @@ struct tf_parser
instruction_ref instruction_ref
parse_conv(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) parse_conv(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{ {
op::convolution op; op::convolution op;
if(contains(attributes, "strides")) if(contains(attributes, "strides"))
{ {
std::vector<size_t> stride; std::vector<size_t> stride;
...@@ -311,19 +311,31 @@ struct tf_parser ...@@ -311,19 +311,31 @@ struct tf_parser
if(contains(attributes, "padding")) if(contains(attributes, "padding"))
{ {
const std::string& pad_mode = attributes.at("padding").s(); const std::string& pad_mode = attributes.at("padding").s();
std::vector<size_t> weight_dims = weights->get_shape().lens(); std::vector<size_t> weight_dims = weights->get_shape().lens();
size_t weight_h = weight_dims[2]; size_t weight_h = weight_dims[2];
size_t weight_w = weight_dims[3]; size_t weight_w = weight_dims[3];
if(pad_mode.find("SAME") != std::string::npos) if(pad_mode.find("SAME") != std::string::npos)
{ {
op.padding[0] = static_cast<size_t>(std::ceil(static_cast<double>(-op.stride[0] + op.dilation[0]*(weight_h-1)+1))/2); op.padding[0] =
op.padding[1] = static_cast<size_t>(std::ceil(static_cast<double>(-op.stride[1] + op.dilation[1]*(weight_w-1)+1))/2); static_cast<size_t>(std::ceil(static_cast<double>(
-op.stride[0] + op.dilation[0] * (weight_h - 1) + 1)) /
2);
op.padding[1] =
static_cast<size_t>(std::ceil(static_cast<double>(
-op.stride[1] + op.dilation[1] * (weight_w - 1) + 1)) /
2);
} }
else if(pad_mode.find("VALID") != std::string::npos) else if(pad_mode.find("VALID") != std::string::npos)
{ {
op.padding[0] = static_cast<size_t>(std::ceil(static_cast<double>(-weight_h - op.stride[0] + op.dilation[0]*(weight_h-1)+1))/2); op.padding[0] = static_cast<size_t>(
op.padding[1] = static_cast<size_t>(std::ceil(static_cast<double>(-weight_w - op.stride[1] + op.dilation[1]*(weight_w-1)+1))/2); std::ceil(static_cast<double>(-weight_h - op.stride[0] +
op.dilation[0] * (weight_h - 1) + 1)) /
2);
op.padding[1] = static_cast<size_t>(
std::ceil(static_cast<double>(-weight_w - op.stride[1] +
op.dilation[1] * (weight_w - 1) + 1)) /
2);
} }
else if(pad_mode.find("EXPLICIT") != std::string::npos) else if(pad_mode.find("EXPLICIT") != std::string::npos)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment