Commit a693711d authored by Khalique's avatar Khalique
Browse files

formatting

parent 75d5c660
...@@ -11,12 +11,18 @@ inline std::size_t calculate_padding(std::size_t weight_dim, std::size_t dilatio ...@@ -11,12 +11,18 @@ inline std::size_t calculate_padding(std::size_t weight_dim, std::size_t dilatio
return (dilation * (weight_dim - 1)) / 2; return (dilation * (weight_dim - 1)) / 2;
} }
inline void calculate_padding(int64_t idx, std::vector<int64_t>& pads, int64_t input_dim, int64_t stride, int64_t dilation, int64_t weight_dim) inline void calculate_padding(int64_t idx,
std::vector<int64_t>& pads,
int64_t input_dim,
int64_t stride,
int64_t dilation,
int64_t weight_dim)
{ {
int64_t output_dim = input_dim / stride; int64_t output_dim = input_dim / stride;
int64_t pad = std::max(static_cast<int64_t>(0), (output_dim - 1) * stride + dilation * weight_dim - input_dim); int64_t pad = std::max(static_cast<int64_t>(0),
pads[idx] = pad / 2; (output_dim - 1) * stride + dilation * weight_dim - input_dim);
pads[idx + 2] = pad - pad / 2; pads[idx] = pad / 2;
pads[idx + 2] = pad - pad / 2;
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -352,7 +352,8 @@ struct onnx_parser ...@@ -352,7 +352,8 @@ struct onnx_parser
{ {
// insert zeros for pad op (args[0] has 4 dims) // insert zeros for pad op (args[0] has 4 dims)
padding = {0, 0, padding[0], padding[1], 0, 0, padding[2], padding[3]}; padding = {0, 0, padding[0], padding[1], 0, 0, padding[2], padding[3]};
l0 = prog.add_instruction(op::pad{padding, std::numeric_limits<float>::lowest()}, l0); l0 = prog.add_instruction(op::pad{padding, std::numeric_limits<float>::lowest()},
l0);
} }
else else
{ {
......
...@@ -19,11 +19,12 @@ pad(hipStream_t stream, argument result, argument arg1, float value, std::vector ...@@ -19,11 +19,12 @@ pad(hipStream_t stream, argument result, argument arg1, float value, std::vector
// visit_all(result)([&](auto output) { // visit_all(result)([&](auto output) {
// auto* outptr = output.data(); // auto* outptr = output.data();
// gs_launch(stream, nelements)([=](auto i) { // gs_launch(stream, nelements)([=](auto i) {
// outptr[i] = std::numeric_limits<typename decltype(output)::value_type>::lowest(); // outptr[i] = std::numeric_limits<typename
// decltype(output)::value_type>::lowest();
// }); // });
// }); // });
// } // }
// else // else
// { // {
// visit_all(result)([&](auto output) { // visit_all(result)([&](auto output) {
......
...@@ -329,8 +329,8 @@ struct tf_parser ...@@ -329,8 +329,8 @@ struct tf_parser
size_t weight_w = weight_dims[3]; size_t weight_w = weight_dims[3];
auto input_dims = l0->get_shape().lens(); auto input_dims = l0->get_shape().lens();
size_t input_h = input_dims[2]; size_t input_h = input_dims[2];
size_t input_w = input_dims[3]; size_t input_w = input_dims[3];
std::vector<int64_t> pads(input_dims.size()); std::vector<int64_t> pads(input_dims.size());
calculate_padding(0, pads, input_h, op.stride[0], op.dilation[0], weight_h); calculate_padding(0, pads, input_h, op.stride[0], op.dilation[0], weight_h);
calculate_padding(1, pads, input_w, op.stride[1], op.dilation[1], weight_w); calculate_padding(1, pads, input_w, op.stride[1], op.dilation[1], weight_w);
...@@ -342,8 +342,8 @@ struct tf_parser ...@@ -342,8 +342,8 @@ struct tf_parser
} }
else else
{ {
op.padding[0] = pads[0]; op.padding[0] = pads[0];
op.padding[1] = pads[1]; op.padding[1] = pads[1];
} }
} }
else if(pad_mode.find("VALID") != std::string::npos) else if(pad_mode.find("VALID") != std::string::npos)
...@@ -420,8 +420,8 @@ struct tf_parser ...@@ -420,8 +420,8 @@ struct tf_parser
auto l0 = args[0]; auto l0 = args[0];
if(contains(attributes, "padding")) if(contains(attributes, "padding"))
{ {
const std::string& pad_mode = attributes.at("padding").s(); const std::string& pad_mode = attributes.at("padding").s();
if(pad_mode.find("SAME") != std::string::npos) if(pad_mode.find("SAME") != std::string::npos)
{ {
// op.padding_mode = op::padding_mode_t::same; // op.padding_mode = op::padding_mode_t::same;
...@@ -430,8 +430,8 @@ struct tf_parser ...@@ -430,8 +430,8 @@ struct tf_parser
size_t weight_w = weight_dims[3]; size_t weight_w = weight_dims[3];
auto input_dims = l0->get_shape().lens(); auto input_dims = l0->get_shape().lens();
size_t input_h = input_dims[2]; size_t input_h = input_dims[2];
size_t input_w = input_dims[3]; size_t input_w = input_dims[3];
std::vector<int64_t> pads(input_dims.size()); std::vector<int64_t> pads(input_dims.size());
calculate_padding(0, pads, input_h, op.stride[0], op.dilation[0], weight_h); calculate_padding(0, pads, input_h, op.stride[0], op.dilation[0], weight_h);
calculate_padding(1, pads, input_w, op.stride[1], op.dilation[1], weight_w); calculate_padding(1, pads, input_w, op.stride[1], op.dilation[1], weight_w);
...@@ -443,8 +443,8 @@ struct tf_parser ...@@ -443,8 +443,8 @@ struct tf_parser
} }
else else
{ {
op.padding[0] = pads[0]; op.padding[0] = pads[0];
op.padding[1] = pads[1]; op.padding[1] = pads[1];
} }
} }
else if(pad_mode.find("VALID") != std::string::npos) else if(pad_mode.find("VALID") != std::string::npos)
...@@ -609,10 +609,10 @@ struct tf_parser ...@@ -609,10 +609,10 @@ struct tf_parser
const std::string& pad_mode = attributes.at("padding").s(); const std::string& pad_mode = attributes.at("padding").s();
if(pad_mode.find("SAME") != std::string::npos) if(pad_mode.find("SAME") != std::string::npos)
{ {
//op.padding_mode = op::padding_mode_t::same; // op.padding_mode = op::padding_mode_t::same;
auto input_dims = l0->get_shape().lens(); auto input_dims = l0->get_shape().lens();
size_t input_h = input_dims[2]; size_t input_h = input_dims[2];
size_t input_w = input_dims[3]; size_t input_w = input_dims[3];
std::vector<int64_t> pads(input_dims.size()); std::vector<int64_t> pads(input_dims.size());
calculate_padding(0, pads, input_h, op.stride[0], 1, op.lengths[0]); calculate_padding(0, pads, input_h, op.stride[0], 1, op.lengths[0]);
calculate_padding(1, pads, input_w, op.stride[1], 1, op.lengths[1]); calculate_padding(1, pads, input_w, op.stride[1], 1, op.lengths[1]);
...@@ -624,12 +624,13 @@ struct tf_parser ...@@ -624,12 +624,13 @@ struct tf_parser
if(pads[0] != pads[2] || pads[1] != pads[3]) if(pads[0] != pads[2] || pads[1] != pads[3])
{ {
std::vector<int64_t> padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]}; std::vector<int64_t> padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]};
l0 = prog.add_instruction(migraphx::op::pad{padding, std::numeric_limits<float>::lowest()}, l0); l0 = prog.add_instruction(
migraphx::op::pad{padding, std::numeric_limits<float>::lowest()}, l0);
} }
else else
{ {
op.padding[0] = pads[0]; op.padding[0] = pads[0];
op.padding[1] = pads[1]; op.padding[1] = pads[1];
} }
} }
else if(pad_mode.find("VALID") != std::string::npos) else if(pad_mode.find("VALID") != std::string::npos)
......
...@@ -1581,7 +1581,7 @@ void pad_test() ...@@ -1581,7 +1581,7 @@ void pad_test()
auto l0 = p.add_literal(migraphx::literal{s0, data0}); auto l0 = p.add_literal(migraphx::literal{s0, data0});
migraphx::op::pad op{}; migraphx::op::pad op{};
op.value = std::numeric_limits<int8_t>::lowest(); op.value = std::numeric_limits<int8_t>::lowest();
op.pads = {0, 0, 1, 1}; op.pads = {0, 0, 1, 1};
p.add_instruction(op, l0); p.add_instruction(op, l0);
p.compile(migraphx::gpu::target{}); p.compile(migraphx::gpu::target{});
migraphx::program::parameter_map m; migraphx::program::parameter_map m;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment