Commit 0adcc72b authored by Khalique's avatar Khalique
Browse files

formatting

parent eb032acb
...@@ -326,10 +326,8 @@ struct tf_parser ...@@ -326,10 +326,8 @@ struct tf_parser
if(pad_mode.find("SAME") != std::string::npos) if(pad_mode.find("SAME") != std::string::npos)
{ {
op.padding_mode = op::padding_mode_t::same; op.padding_mode = op::padding_mode_t::same;
op.padding[0] = op.padding[0] = op::calculate_padding(weight_h, op.dilation[0]);
op::calculate_padding(weight_h, op.dilation[0]); op.padding[1] = op::calculate_padding(weight_w, op.dilation[1]);
op.padding[1] =
op::calculate_padding(weight_w, op.dilation[1]);
} }
else if(pad_mode.find("VALID") != std::string::npos) else if(pad_mode.find("VALID") != std::string::npos)
{ {
...@@ -362,7 +360,7 @@ struct tf_parser ...@@ -362,7 +360,7 @@ struct tf_parser
op::convolution op; op::convolution op;
size_t num_channels = args[0]->get_shape().lens()[1]; size_t num_channels = args[0]->get_shape().lens()[1];
op.group = num_channels; op.group = num_channels;
if(contains(attributes, "strides")) if(contains(attributes, "strides"))
{ {
std::vector<size_t> stride; std::vector<size_t> stride;
...@@ -401,20 +399,18 @@ struct tf_parser ...@@ -401,20 +399,18 @@ struct tf_parser
weights = prog.add_instruction(op::transpose{{3, 2, 0, 1}}, args[1]); weights = prog.add_instruction(op::transpose{{3, 2, 0, 1}}, args[1]);
} }
} }
if(contains(attributes, "padding")) if(contains(attributes, "padding"))
{ {
const std::string& pad_mode = attributes.at("padding").s(); const std::string& pad_mode = attributes.at("padding").s();
std::vector<size_t> weight_dims = weights->get_shape().lens(); std::vector<size_t> weight_dims = weights->get_shape().lens();
size_t weight_h = weight_dims[2]; size_t weight_h = weight_dims[2];
size_t weight_w = weight_dims[3]; size_t weight_w = weight_dims[3];
if(pad_mode.find("SAME") != std::string::npos) if(pad_mode.find("SAME") != std::string::npos)
{ {
op.padding_mode = op::padding_mode_t::same; op.padding_mode = op::padding_mode_t::same;
op.padding[0] = op.padding[0] = op::calculate_padding(weight_h, op.dilation[0]);
op::calculate_padding(weight_h, op.dilation[0]); op.padding[1] = op::calculate_padding(weight_w, op.dilation[1]);
op.padding[1] =
op::calculate_padding(weight_w, op.dilation[1]);
} }
else if(pad_mode.find("VALID") != std::string::npos) else if(pad_mode.find("VALID") != std::string::npos)
{ {
...@@ -577,10 +573,8 @@ struct tf_parser ...@@ -577,10 +573,8 @@ struct tf_parser
if(pad_mode.find("SAME") != std::string::npos) if(pad_mode.find("SAME") != std::string::npos)
{ {
op.padding_mode = op::padding_mode_t::same; op.padding_mode = op::padding_mode_t::same;
op.padding[0] = op.padding[0] = op::calculate_padding(op.lengths[0], 0);
op::calculate_padding(op.lengths[0], 0); op.padding[1] = op::calculate_padding(op.lengths[1], 0);
op.padding[1] =
op::calculate_padding(op.lengths[1], 0);
} }
else if(pad_mode.find("VALID") != std::string::npos) else if(pad_mode.find("VALID") != std::string::npos)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment