Unverified Commit 9a5e0c06 authored by kahmed10, committed by GitHub

Asym pad refactor (#791)



* alternative impl

* formatting

* add gpu pass to insert pad

* formatting

* update onnx test, still need cleanup

* formatting

* update tf_test

* modify existing tests

* formatting

* remove print

* code cleanup

* formatting

* code cleanup

* formatting

* fix tidy and cppcheck

* remove variable

* add test

* formatting

* add test and address comments

* formatting
Co-authored-by: Shucai Xiao <shucai@gmail.com>
Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 98486807
@@ -23,6 +23,7 @@ add_library(migraphx
     eliminate_data_type.cpp
     eliminate_identity.cpp
     eliminate_pad.cpp
+    insert_pad.cpp
     file_buffer.cpp
     rewrite_batchnorm.cpp
     rewrite_rnn.cpp
...
@@ -5,76 +5,86 @@
 #include <migraphx/op/im2col.hpp>
 #include <migraphx/op/pooling.hpp>
 #include <migraphx/op/pad.hpp>
+#include <migraphx/make_op.hpp>
 #include <migraphx/iterator_for.hpp>
 #include <migraphx/stringutils.hpp>
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
-void eliminate_pad::apply(module& p) const
-{
-    for(auto ins : iterator_for(p))
-    {
-        const std::string& op_name = ins->name();
-        if(op_name != "convolution" and op_name != "im2col" and op_name != "pooling")
-            continue;
-        auto input = ins->inputs().front();
-        if(input->name() != "pad")
-            continue;
-        if(op_name == "convolution" or op_name == "im2col")
-            update_op(input, ins, p);
-        else if(op_name == "pooling")
-            update_pooling(input, ins, p);
-    }
-}
-void eliminate_pad::update_op(const instruction_ref& input,
-                              const instruction_ref& ins,
-                              module& p) const
+static void update_op(const instruction_ref& input, const instruction_ref& ins, module& m)
 {
     auto pad_op = any_cast<op::pad>(input->get_operator());
-    if(!pad_op.symmetric())
-        return;
     auto kdims    = input->get_shape().lens().size() - 2;
     auto kdims_it = pad_op.pads.begin() + 2;
-    std::vector<size_t> new_pads(kdims_it, kdims_it + kdims);
+    std::vector<size_t> pads_l(kdims_it, kdims_it + kdims);
+    std::vector<size_t> pads_r(kdims_it + kdims + 2, pad_op.pads.end());
     auto op = ins->get_operator();
-    op.from_value({{"padding", new_pads}});
+    std::vector<size_t> padding(kdims * 2, 0);
+    std::transform(
+        pads_l.begin(), pads_l.end(), padding.begin(), padding.begin(), std::plus<size_t>());
+    std::transform(pads_r.begin(),
+                   pads_r.end(),
+                   padding.begin() + kdims,
+                   padding.begin() + kdims,
+                   std::plus<size_t>());
+    op.from_value({{"padding", padding}});
     std::vector<instruction_ref> new_inputs{ins->inputs()};
     new_inputs.front() = input->inputs().front();
-    p.replace_instruction(ins, op, new_inputs);
+    m.replace_instruction(ins, op, new_inputs);
 }
-void eliminate_pad::update_pooling(const instruction_ref& input,
-                                   const instruction_ref& ins,
-                                   module& p) const
+static void update_pooling(const instruction_ref& input, const instruction_ref& ins, module& m)
 {
-    auto pad_op = any_cast<op::pad>(input->get_operator());
-    if(!pad_op.symmetric())
-        return;
-    auto kdims    = input->get_shape().lens().size() - 2;
-    auto kdims_it = pad_op.pads.begin() + 2;
-    std::vector<size_t> new_pads(kdims_it, kdims_it + kdims);
     auto op = any_cast<op::pooling>(ins->get_operator());
     if(op.mode == "average")
     {
         return;
     }
+    auto pad_op   = any_cast<op::pad>(input->get_operator());
+    auto kdims    = input->get_shape().lens().size() - 2;
+    auto kdims_it = pad_op.pads.begin() + 2;
+    std::vector<size_t> pads_l(kdims_it, kdims_it + kdims);
+    std::vector<size_t> pads_r(kdims_it + kdims + 2, pad_op.pads.end());
-    op.padding = new_pads;
+    std::transform(
+        pads_l.begin(), pads_l.end(), op.padding.begin(), op.padding.begin(), std::plus<size_t>());
+    std::transform(pads_r.begin(),
+                   pads_r.end(),
+                   op.padding.begin() + kdims,
+                   op.padding.begin() + kdims,
+                   std::plus<size_t>());
     std::vector<instruction_ref> new_inputs{ins->inputs()};
     new_inputs.front() = input->inputs().front();
-    p.replace_instruction(ins, op, new_inputs);
+    m.replace_instruction(ins, op, new_inputs);
+}
+void eliminate_pad::apply(module& m) const
+{
+    for(auto ins : iterator_for(m))
+    {
+        const std::string& op_name = ins->name();
+        if(op_name != "convolution" and op_name != "im2col" and op_name != "pooling")
+            continue;
+        auto input = ins->inputs().front();
+        if(input->name() != "pad")
+            continue;
+        if(op_name == "convolution" or op_name == "im2col")
+            update_op(input, ins, m);
+        else if(op_name == "pooling")
+            update_pooling(input, ins, m);
+    }
 }
 } // namespace MIGRAPHX_INLINE_NS
...
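The reworked `update_op`/`update_pooling` above no longer bail out on asymmetric pads; they split the pad instruction's spatial pads into left and right halves and add them onto the op's own padding attribute. A minimal standalone sketch of that accumulation, using plain `std::vector` and made-up pad values rather than migraphx types:

```cpp
#include <algorithm>
#include <cstddef>
#include <functional>
#include <iostream>
#include <vector>

int main()
{
    std::size_t kdims = 2; // two spatial dims (H, W)
    // pad instruction over NCHW: nothing on N/C, {1, 2} on the left, {3, 4} on the right
    std::vector<std::size_t> pad_pads = {0, 0, 1, 2, 0, 0, 3, 4};
    // existing pooling padding attribute, {left_h, left_w, right_h, right_w}
    std::vector<std::size_t> padding = {0, 1, 0, 1};

    auto kdims_it = pad_pads.begin() + 2;
    std::vector<std::size_t> pads_l(kdims_it, kdims_it + kdims);           // {1, 2}
    std::vector<std::size_t> pads_r(kdims_it + kdims + 2, pad_pads.end()); // {3, 4}

    // left pads accumulate onto the first half, right pads onto the second half
    std::transform(pads_l.begin(), pads_l.end(), padding.begin(), padding.begin(),
                   std::plus<std::size_t>());
    std::transform(pads_r.begin(), pads_r.end(), padding.begin() + kdims,
                   padding.begin() + kdims, std::plus<std::size_t>());

    for(auto p : padding)
        std::cout << p << ' '; // prints: 1 3 3 5
    std::cout << '\n';
}
```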
@@ -20,10 +20,7 @@ struct eliminate_pad
 {
     std::string name() const { return "eliminate_pad"; }
-    void apply(module& p) const;
-    void update_op(const instruction_ref& input, const instruction_ref& ins, module& p) const;
-    void update_pooling(const instruction_ref& input, const instruction_ref& ins, module& p) const;
+    void apply(module& m) const;
 };
 } // namespace MIGRAPHX_INLINE_NS
...
#ifndef MIGRAPHX_GUARD_RTGLIB_INSERT_PAD_HPP
#define MIGRAPHX_GUARD_RTGLIB_INSERT_PAD_HPP

#include <string>
#include <vector>
#include <array>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct module;

/**
 * Insert an explicit pad instruction when an op's padding attribute is asymmetrical.
 */
struct insert_pad
{
    std::string name() const { return "insert_pad"; }
    void apply(module& m) const;
};

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
@@ -9,6 +9,8 @@
 #include <migraphx/literal.hpp>
 #include <migraphx/shape_for_each.hpp>
 #include <migraphx/config.hpp>
+#include <migraphx/value.hpp>
+#include <migraphx/op/normalize_attribute.hpp>
 #include <cmath>
 #include <utility>
@@ -39,25 +41,30 @@ struct convolution
     void check_attribute_size() const
     {
-        if(not(padding.size() == stride.size() and padding.size() == dilation.size()))
+        if(not((padding.size() == stride.size() or (padding.size() / 2) == stride.size()) and
+               stride.size() == dilation.size()))
         {
             MIGRAPHX_THROW("CONVOLUTION: inconsistent attribute sizes");
         }
     }
-    shape compute_shape(std::vector<shape> inputs) const
+    value attributes() const { return {{"normalize_padding", "padding"}}; }
+    shape normalize_compute_shape(std::vector<shape> inputs) const
     {
         check_shapes{inputs, *this}.has(2).same_type().same_ndims().min_ndims(3);
         check_attribute_size();
         // dim num of input and attribute should match
-        if(inputs[0].lens().size() != padding.size() + 2)
+        auto input_size   = inputs[0].lens().size();
+        auto padding_size = padding.size();
+        if(not(input_size == padding_size / 2 + 2 or input_size == padding_size + 2))
         {
             MIGRAPHX_THROW("CONVOLUTION: input and attribute size mismatch!");
         }
         const shape& input   = inputs.at(0);
         const shape& weights = inputs.at(1);
-        size_t kdims = input.lens().size() - 2;
+        size_t kdims = input_size - 2;
         if(kdims != this->kdims())
         {
             MIGRAPHX_THROW("convolution: input k-dims does not match attribute size");
@@ -70,10 +77,13 @@ struct convolution
         for(size_t i = 0; i < kdims; i++)
         {
+            auto padding_factor = 2 * padding[i];
+            if(padding_size == 2 * kdims)
+                padding_factor = padding[i] + padding[i + kdims];
             output_lens.push_back(std::size_t(std::max<std::ptrdiff_t>(
                 1,
                 (input.lens()[i + 2] - (1 + dilation[i] * (weights.lens()[i + 2] - 1)) +
-                 2 * padding[i]) /
+                 padding_factor) /
                     stride[i] +
                 1)));
         }
@@ -84,7 +94,7 @@ struct convolution
     size_t kdims() const
     {
         check_attribute_size();
-        return padding.size();
+        return stride.size();
     }
 };
...
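With `normalize_compute_shape`, the convolution accepts a padding attribute of either `kdims` entries (symmetric) or `2 * kdims` entries laid out as `{left..., right...}`, and the output length uses `padding_factor` accordingly. A small worked sketch of that formula with hypothetical numbers (not taken from any test in this change):

```cpp
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

// padding holds kdims entries (symmetric) or 2 * kdims entries ({left..., right...})
std::size_t conv_output_len(std::size_t in_len,
                            std::size_t weight_len,
                            std::size_t dilation,
                            std::size_t stride,
                            const std::vector<std::size_t>& padding,
                            std::size_t i,
                            std::size_t kdims)
{
    auto padding_factor = 2 * padding[i];
    if(padding.size() == 2 * kdims)
        padding_factor = padding[i] + padding[i + kdims];
    // same formula as the diff, written with signed arithmetic to keep the sketch safe
    std::ptrdiff_t numerator = std::ptrdiff_t(in_len) + std::ptrdiff_t(padding_factor) -
                               std::ptrdiff_t(1 + dilation * (weight_len - 1));
    return std::size_t(std::max<std::ptrdiff_t>(1, numerator / std::ptrdiff_t(stride) + 1));
}

int main()
{
    // 1-D example: input length 10, kernel 3, dilation 1, stride 1
    std::cout << conv_output_len(10, 3, 1, 1, {1}, 0, 1) << '\n';    // symmetric pad 1 -> 10
    std::cout << conv_output_len(10, 3, 1, 1, {1, 2}, 0, 1) << '\n'; // left 1, right 2 -> 11
}
```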
@@ -39,7 +39,8 @@ struct deconvolution
     void check_attribute_size() const
     {
-        if(not(padding.size() == stride.size() and padding.size() == dilation.size()))
+        if(not((padding.size() == stride.size() or (padding.size() / 2) == stride.size()) and
+               stride.size() == dilation.size()))
         {
             MIGRAPHX_THROW("deconvolution: inconsistent attribute sizes");
         }
@@ -72,7 +73,7 @@ struct deconvolution
     size_t kdims() const
     {
         check_attribute_size();
-        return padding.size();
+        return stride.size();
     }
 };
...
@@ -31,7 +31,9 @@ struct im2col
     std::string name() const { return "im2col"; }
-    shape compute_shape(std::vector<shape> inputs) const
+    value attributes() const { return {{"normalize_padding", "padding"}}; }
+    shape normalize_compute_shape(std::vector<shape> inputs) const
     {
         auto input   = inputs[0];
         auto weights = inputs[1];
@@ -42,16 +44,23 @@ struct im2col
         check_shapes{inputs, *this}.has(2);
         if(batch_size != 1)
             MIGRAPHX_THROW("im2col only support batch_size 1");
+        auto padding_h = 2 * padding[0];
+        auto padding_w = 2 * padding[1];
+        if(padding.size() == 2 * stride.size())
+        {
+            padding_h = padding[0] + padding[2];
+            padding_w = padding[1] + padding[3];
+        }
         auto output_height = std::size_t(std::max<std::ptrdiff_t>(
             1,
-            (input.lens()[2] - (1 + dilation[0] * (kernel_height - 1)) + 2 * padding[0]) /
-                stride[0] +
+            (input.lens()[2] - (1 + dilation[0] * (kernel_height - 1)) + padding_h) / stride[0] +
             1));
         auto output_width = std::size_t(std::max<std::ptrdiff_t>(
             1,
-            (input.lens()[3] - (1 + dilation[1] * (kernel_width - 1)) + 2 * padding[1]) /
-                stride[1] +
+            (input.lens()[3] - (1 + dilation[1] * (kernel_width - 1)) + padding_w) / stride[1] +
             1));
         auto channels_col = kernel_height * kernel_width * input_channels;
         return {input.type(), {output_height * output_width, channels_col}};
     }
...
@@ -15,6 +15,7 @@ namespace op {
 // 3.1) include_min(default)/exclude_min
 // 4) clip_max(default)/not_clip_max
 // 4.1) exclude_max(default)/include_max
+// 5) normalize padding
 enum class normalize_attribute
 {
     use_len,
@@ -22,7 +23,8 @@ enum class normalize_attribute
     clip_max,
     clip_min,
     include_max,
-    include_min
+    include_min,
+    normalize_padding
 };
 } // namespace op
...
@@ -8,6 +8,7 @@
 #include <migraphx/streamutils.hpp>
 #include <migraphx/functional.hpp>
 #include <migraphx/literal.hpp>
+#include <migraphx/value.hpp>
 #include <migraphx/shape_for_each.hpp>
 #include <migraphx/int_divide.hpp>
 #include <migraphx/config.hpp>
@@ -40,29 +41,39 @@ struct pooling
     void check_attribute_size() const
     {
-        if(not(padding.size() == stride.size() and padding.size() == lengths.size()))
+        if(not((padding.size() == stride.size() or (padding.size() / 2) == stride.size()) and
+               stride.size() == lengths.size()))
         {
             MIGRAPHX_THROW("POOLING: inconsistent attribute sizes");
         }
     }
-    shape compute_shape(std::vector<shape> inputs) const
+    value attributes() const { return {{"normalize_padding", "padding"}}; }
+    shape normalize_compute_shape(std::vector<shape> inputs) const
     {
         check_shapes{inputs, *this}.has(1);
         const shape& input = inputs.at(0);
         auto input_lens    = input.lens();
         size_t kdims       = input_lens.size() - 2;
-        if(kdims != this->kdims())
+        auto input_size   = inputs[0].lens().size();
+        auto padding_size = padding.size();
+        if(not(input_size == padding_size / 2 + 2 or input_size == padding_size + 2))
         {
-            MIGRAPHX_THROW("pooling: input k-dims does not match attribute size");
+            MIGRAPHX_THROW("POOLING: input and attribute size mismatch!");
         }
         std::vector<std::size_t> output_lens(input_lens.begin(), input_lens.begin() + 2);
         for(size_t i = 0; i < kdims; i++)
         {
-            std::ptrdiff_t dim_size = input_lens[i + 2] + 2 * padding[i] - lengths[i];
+            std::ptrdiff_t dim_size;
+            auto padding_factor = 2 * padding[i];
+            if(padding_size == 2 * kdims)
+                padding_factor = padding[i] + padding[i + kdims];
+            dim_size = input_lens[i + 2] + padding_factor - lengths[i];
             assert(dim_size >= 0);
             std::size_t len = (ceil_mode) ? ceil_divide<std::ptrdiff_t>(dim_size, stride[i])
                                           : floor_divide<std::ptrdiff_t>(dim_size, stride[i]);
@@ -75,7 +86,7 @@ struct pooling
     size_t kdims() const
     {
         check_attribute_size();
-        return padding.size();
+        return stride.size();
    }
 };
...
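Pooling follows the same convention: `padding_factor` is built from the symmetric or asymmetric attribute, the window length is subtracted, and the result goes through a ceil or floor division by the stride. A standalone sketch with stand-in `ceil_divide`/`floor_divide` helpers and illustrative numbers; the final `+ 1` mirrors the convolution formula and is an assumption here, since the hunk is truncated before the `push_back`:

```cpp
#include <cassert>
#include <cstddef>
#include <iostream>
#include <vector>

// stand-ins for migraphx's ceil_divide / floor_divide helpers (non-negative inputs)
std::ptrdiff_t ceil_divide(std::ptrdiff_t a, std::ptrdiff_t b) { return (a + b - 1) / b; }
std::ptrdiff_t floor_divide(std::ptrdiff_t a, std::ptrdiff_t b) { return a / b; }

std::size_t pool_output_len(std::size_t in_len,
                            std::size_t length, // pooling window
                            std::size_t stride,
                            const std::vector<std::size_t>& padding,
                            std::size_t i,
                            std::size_t kdims,
                            bool ceil_mode)
{
    auto padding_factor = 2 * padding[i];
    if(padding.size() == 2 * kdims)
        padding_factor = padding[i] + padding[i + kdims];
    std::ptrdiff_t dim_size =
        std::ptrdiff_t(in_len) + std::ptrdiff_t(padding_factor) - std::ptrdiff_t(length);
    assert(dim_size >= 0);
    auto len = ceil_mode ? ceil_divide(dim_size, std::ptrdiff_t(stride))
                         : floor_divide(dim_size, std::ptrdiff_t(stride));
    return std::size_t(len + 1); // assumed final "+ 1"; the hunk ends before that line
}

int main()
{
    // input 7, window 3, stride 2, asymmetric pad: 0 on the left, 1 on the right
    std::cout << pool_output_len(7, 3, 2, {0, 1}, 0, 1, false) << '\n'; // floor mode -> 3
    std::cout << pool_output_len(7, 3, 2, {0, 1}, 0, 1, true) << '\n';  // ceil mode  -> 4
}
```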
@@ -36,19 +36,23 @@ struct quant_convolution
                     f(self.group, "group"));
     }
-    value attributes() const { return {{"general_data_type", "convolution"}}; }
+    value attributes() const
+    {
+        return {{"general_data_type", "convolution"}, {"normalize_padding", "padding"}};
+    }
     std::string name() const { return "quant_convolution"; }
     void check_attribute_size() const
     {
-        if(not(padding.size() == stride.size() and padding.size() == dilation.size()))
+        if(not((padding.size() == stride.size() or (padding.size() / 2) == stride.size()) and
+               stride.size() == dilation.size()))
         {
-            MIGRAPHX_THROW("quant_convolution: inconsistent attribute sizes");
+            MIGRAPHX_THROW("QUANT_CONVOLUTION: inconsistent attribute sizes");
         }
     }
-    shape compute_shape(std::vector<shape> inputs) const
+    shape normalize_compute_shape(std::vector<shape> inputs) const
     {
         check_shapes{inputs, *this}.has(2).same_type().same_ndims().min_ndims(3);
         check_attribute_size();
@@ -70,13 +74,16 @@ struct quant_convolution
         t = shape::int32_type;
         std::vector<size_t> output_lens{input.lens()[0], weights.lens()[0]};
+        auto padding_size = padding.size();
         for(size_t i = 0; i < kdims; i++)
         {
+            auto padding_factor = 2 * padding[i];
+            if(padding_size == 2 * kdims)
+                padding_factor = padding[i] + padding[i + kdims];
             output_lens.push_back(std::size_t(std::max<std::ptrdiff_t>(
                 1,
                 (input.lens()[i + 2] - (1 + dilation[i] * (weights.lens()[i + 2] - 1)) +
-                 2 * padding[i]) /
+                 padding_factor) /
                     stride[i] +
                 1)));
         }
@@ -87,7 +94,7 @@ struct quant_convolution
     size_t kdims() const
     {
         check_attribute_size();
-        return padding.size();
+        return stride.size();
     }
 };
...
#include <migraphx/insert_pad.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/im2col.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/pad.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/stringutils.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

static void update_op(const instruction_ref& input, const instruction_ref& ins, module& m)
{
    auto op         = ins->get_operator();
    auto val        = op.to_value();
    auto op_padding = val.at("padding").to_vector<size_t>();
    auto kdims      = input->get_shape().lens().size() - 2;
    if(std::equal(op_padding.begin(),
                  op_padding.begin() + kdims,
                  op_padding.begin() + kdims,
                  op_padding.end()))
        return;

    std::vector<int64_t> padding(input->get_shape().lens().size() * 2, 0);
    std::vector<size_t> pads_l(op_padding.begin(), op_padding.begin() + kdims);
    std::vector<size_t> pads_r(op_padding.begin() + kdims, op_padding.end());
    op_padding = std::vector<size_t>(kdims * 2, 0);
    op.from_value({{"padding", op_padding}});

    std::copy(pads_l.begin(), pads_l.end(), padding.begin() + 2);
    std::copy(pads_r.begin(), pads_r.end(), padding.begin() + kdims + 2 + 2);

    auto pad_op = m.insert_instruction(ins, op::pad{padding}, input);

    auto new_inputs    = ins->inputs();
    new_inputs.front() = pad_op;
    m.replace_instruction(ins, op, new_inputs);
}

static void update_pooling(const instruction_ref& input, const instruction_ref& ins, module& m)
{
    auto op = any_cast<op::pooling>(ins->get_operator());
    if(op.mode == "average")
    {
        return;
    }
    auto kdims = input->get_shape().lens().size() - 2;
    if(std::equal(op.padding.begin(),
                  op.padding.begin() + kdims,
                  op.padding.begin() + kdims,
                  op.padding.end()))
        return;

    std::vector<int64_t> padding(input->get_shape().lens().size() * 2, 0);
    std::vector<size_t> pads_l(op.padding.begin(), op.padding.begin() + kdims);
    std::vector<size_t> pads_r(op.padding.begin() + kdims, op.padding.end());
    op.padding = std::vector<size_t>(kdims * 2, 0);

    std::copy(pads_l.begin(), pads_l.end(), padding.begin() + 2);
    std::copy(pads_r.begin(), pads_r.end(), padding.begin() + kdims + 2 + 2);

    // maxpool uses lowest value for padding
    float pad_val = std::numeric_limits<float>::lowest();
    auto pad_op   = m.insert_instruction(ins, op::pad{padding, pad_val}, input);

    auto new_inputs    = ins->inputs();
    new_inputs.front() = pad_op;
    m.replace_instruction(ins, op, new_inputs);
}

void insert_pad::apply(module& m) const
{
    for(auto ins : iterator_for(m))
    {
        const std::string& op_name = ins->name();
        if(op_name != "convolution" and op_name != "im2col" and op_name != "pooling")
            continue;
        auto input = ins->inputs().front();
        if(op_name == "convolution" or op_name == "im2col")
            update_op(input, ins, m);
        else if(op_name == "pooling")
            update_pooling(input, ins, m);
    }
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
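The new `insert_pad` pass leaves symmetric attributes alone; when the left and right halves differ, it zeroes the op's padding and re-expresses it as an explicit `pad` over all dims, with zeros in the N and C slots. A standalone sketch of that index mapping with plain vectors and made-up values (no migraphx instructions):

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

int main()
{
    std::size_t kdims = 2;         // spatial dims (H, W)
    std::size_t ndims = kdims + 2; // N, C, H, W
    // asymmetric attribute: left {1, 2}, right {3, 4}
    std::vector<std::size_t> op_padding = {1, 2, 3, 4};

    // only rewrite when the left and right halves actually differ
    bool symmetric = std::equal(op_padding.begin(),
                                op_padding.begin() + kdims,
                                op_padding.begin() + kdims,
                                op_padding.end());
    if(not symmetric)
    {
        // pads vector for an explicit pad op: {N_l, C_l, H_l, W_l, N_r, C_r, H_r, W_r}
        std::vector<std::int64_t> pad_pads(ndims * 2, 0);
        std::copy(op_padding.begin(), op_padding.begin() + kdims, pad_pads.begin() + 2);
        std::copy(op_padding.begin() + kdims, op_padding.end(),
                  pad_pads.begin() + kdims + 2 + 2);
        op_padding.assign(kdims * 2, 0); // the op itself no longer pads

        for(auto p : pad_pads)
            std::cout << p << ' '; // prints: 0 0 1 2 0 0 3 4
        std::cout << '\n';
    }
}
```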
@@ -117,14 +117,43 @@ auto tune_attribute(const std::vector<int64_t>& vec,
     return result;
 }
+auto tune_pad_attribute(const value& val)
+{
+    std::vector<size_t> vec_attrs = val.to_vector<size_t>();
+    std::vector<size_t> result(vec_attrs.begin(), vec_attrs.end());
+    std::copy(vec_attrs.begin(), vec_attrs.end(), std::back_inserter(result));
+    return result;
+}
 bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
 {
     bool tuned = false;
     auto attrs = op.attributes();
     auto val   = op.to_value();
+    if(attrs.contains("normalize_padding"))
+    {
+        auto padding      = val.at(attrs.at("normalize_padding").to<std::string>());
+        auto padding_size = padding.size();
+        // for now, assume the dimensions to pad start at dim 2
+        auto padding_start = 2;
+        if(padding_size == 2 * (lens.size() - padding_start))
+            tuned = true;
+        else if(padding_size != (lens.size() - padding_start))
+            MIGRAPHX_THROW("inconsistent padding size");
+        else
+        {
+            auto result    = tune_pad_attribute(padding);
+            val["padding"] = result;
+            op.from_value(val);
+            tuned = true;
+        }
+    }
     if(!attrs.contains("normalize_axes"))
     {
-        return false;
+        return tuned;
     }
     auto attr_v = attrs.at("normalize_axes").without_key();
...
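`tune_pad_attribute` simply duplicates a symmetric padding vector so that downstream code can rely on the `{left..., right...}` layout. A minimal sketch of that duplication, with a plain `std::vector` standing in for `migraphx::value`:

```cpp
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <iterator>
#include <vector>

int main()
{
    // symmetric padding attribute {p_h, p_w} ...
    std::vector<std::size_t> vec_attrs = {1, 2};
    // ... duplicated into the {left..., right...} layout
    std::vector<std::size_t> result(vec_attrs.begin(), vec_attrs.end());
    std::copy(vec_attrs.begin(), vec_attrs.end(), std::back_inserter(result));

    for(auto p : result)
        std::cout << p << ' '; // prints: 1 2 1 2
    std::cout << '\n';
}
```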
@@ -7,7 +7,7 @@ namespace onnx {
 void recalc_conv_attributes(value& v, size_t kdims)
 {
-    if(v["padding"].size() != kdims)
+    if(not(v["padding"].size() == kdims or v["padding"].size() == kdims * 2))
     {
         v["padding"].resize(kdims);
         std::fill_n(v["padding"].begin(), kdims, 0);
...
@@ -126,7 +126,7 @@ void check_asym_padding(const onnx_parser::node_info& info,
     auto left_pad_it  = padding.begin();
     auto right_pad_it = left_pad_it + pad_ndims;
-    if(is_asym_padding(padding) or count_include_pad == 1)
+    if(count_include_pad == 1)
     {
         std::vector<int64_t> asym_pads{0, 0, 0, 0}; // don't pad N and C
         // add left pads
@@ -134,10 +134,19 @@ void check_asym_padding(const onnx_parser::node_info& info,
         // add right pads
         asym_pads.insert(asym_pads.begin() + pad_ndims + 4, right_pad_it, padding.end());
         ins = info.add_instruction(make_op("pad", {{"pads", asym_pads}, {"value", pad_val}}), ins);
-    }
-    else
-    {
-        v["padding"] = std::vector<size_t>(left_pad_it, right_pad_it);
+        std::vector<size_t> new_padding(padding.size());
+        // subtract asym padding originally found from parsing the operator
+        std::transform(padding.begin(),
+                       left_pad_it,
+                       asym_pads.begin() + 2,
+                       new_padding.begin(),
+                       std::minus<size_t>());
+        std::transform(right_pad_it,
+                       padding.end(),
+                       asym_pads.begin() + pad_ndims + 4,
+                       new_padding.begin() + pad_ndims,
+                       std::minus<size_t>());
+        v["padding"] = new_padding;
     }
 }
...
@@ -73,7 +73,7 @@ struct parse_convolution : op_parser<parse_convolution>
                 values["padding_mode"] = to_value(op::padding_mode_t::same);
             }
         }
-        check_asym_padding(info, l0, padding, values);
+        values["padding"] = std::vector<size_t>(padding.begin(), padding.end());
         if(contains(info.attributes, "group"))
         {
...
@@ -133,18 +133,11 @@ struct parse_pooling : op_parser<parse_pooling>
                            slice_end.begin(),
                            [](auto i, auto j) { return i + j; });
         }
+        values["padding"] = std::vector<size_t>(paddings.begin(), paddings.end());
         check_asym_padding(info, l0, paddings, values, count_include_pad, pad_val);
-        in_lens = l0->get_shape().lens();
-        for(size_t i = 0; i < kdims; i++)
-        {
-            if(values["lengths"][i].to<int64_t>() >
-               in_lens[i + 2] + 2 * values["padding"][i].to<int64_t>())
-            {
-                MIGRAPHX_THROW("PARSE_POOLING: kernel shape is too large");
-            }
-        }
         op.from_value(values);
         auto l1 = info.add_instruction(op, l0);
         if(!slice_start.empty())
         {
...
@@ -40,6 +40,9 @@ struct dnnl_convolution
         auto dilation = op.dilation;
         std::transform(
             dilation.begin(), dilation.end(), dilation.begin(), [](auto x) { return x - 1; });
+        auto kdims = op.kdims();
+        std::vector<size_t> padding_l(op.padding.begin(), op.padding.begin() + kdims);
+        std::vector<size_t> padding_r(op.padding.begin() + kdims, op.padding.end());
         return {dnnl::prop_kind::forward_inference,
                 dnnl::algorithm::convolution_auto,
                 m.at(DNNL_ARG_SRC),
@@ -47,8 +50,8 @@ struct dnnl_convolution
                 m.at(DNNL_ARG_DST),
                 to_dnnl_dims(op.stride),
                 to_dnnl_dims(dilation),
-                to_dnnl_dims(op.padding),
-                to_dnnl_dims(op.padding)};
+                to_dnnl_dims(padding_l),
+                to_dnnl_dims(padding_r)};
     }
 };
...
@@ -66,7 +66,10 @@ struct cpu_im2col
     }
     static std::string name() { return "cpu::im2col"; }
-    shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
+    shape compute_shape(const std::vector<shape>& inputs) const
+    {
+        return op.normalize_compute_shape(inputs);
+    }
     argument compute(context&, const shape& output_shape, std::vector<argument> args) const
     {
...
@@ -63,7 +63,7 @@ struct cpu_pooling : auto_register_op<cpu_pooling<Op>>
     shape compute_shape(std::vector<shape> inputs) const
     {
         inputs.pop_back();
-        return op.compute_shape(inputs);
+        return op.normalize_compute_shape(inputs);
     }
     std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
@@ -130,14 +130,17 @@ struct dnnl_pooling : dnnl_extend_op<dnnl_pooling, dnnl::pooling_forward, op::po
     dnnl::pooling_forward::desc get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const
     {
         auto algo = op.mode == "max" ? dnnl::algorithm::pooling_max : dnnl::algorithm::pooling_avg;
+        auto kdims = op.kdims();
+        std::vector<size_t> padding_l(op.padding.begin(), op.padding.begin() + kdims);
+        std::vector<size_t> padding_r(op.padding.begin() + kdims, op.padding.end());
         return {dnnl::prop_kind::forward_inference,
                 algo,
                 m.at(DNNL_ARG_SRC),
                 m.at(DNNL_ARG_DST),
                 to_dnnl_dims(op.stride),
                 to_dnnl_dims(op.lengths),
-                to_dnnl_dims(op.padding),
-                to_dnnl_dims(op.padding)};
+                to_dnnl_dims(padding_l),
+                to_dnnl_dims(padding_r)};
     }
 };
...
@@ -11,7 +11,7 @@ shape miopen_convolution::compute_shape(const std::vector<shape>& inputs) const
     check_shapes{inputs, *this}.has(4).standard();
     std::vector<shape> conv_inputs(inputs.begin(), inputs.begin() + 2);
     check_shapes{conv_inputs, *this}.max_ndims(5);
-    return op.compute_shape(conv_inputs);
+    return op.normalize_compute_shape(conv_inputs);
 }
 inline shape reshape_if_1d(const shape& input)
...