Commit 5e59bf45 authored by Khalique's avatar Khalique
Browse files

continued progress on tf, working on pool op

parent 39496181
...@@ -16,6 +16,13 @@ namespace migraphx { ...@@ -16,6 +16,13 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
enum padding_mode_t
{
default_, // NOLINT
same,
valid
};
struct not_computable struct not_computable
{ {
argument compute(const shape&, const std::vector<argument>&) const argument compute(const shape&, const std::vector<argument>&) const
...@@ -58,12 +65,7 @@ struct convolution ...@@ -58,12 +65,7 @@ struct convolution
std::array<std::size_t, 2> padding = {{0, 0}}; std::array<std::size_t, 2> padding = {{0, 0}};
std::array<std::size_t, 2> stride = {{1, 1}}; std::array<std::size_t, 2> stride = {{1, 1}};
std::array<std::size_t, 2> dilation = {{1, 1}}; std::array<std::size_t, 2> dilation = {{1, 1}};
enum padding_mode_t
{
default_, // NOLINT
same,
valid
};
padding_mode_t padding_mode = default_; padding_mode_t padding_mode = default_;
int group = 1; int group = 1;
...@@ -189,12 +191,14 @@ struct pooling ...@@ -189,12 +191,14 @@ struct pooling
std::array<std::size_t, 2> padding = {{0, 0}}; std::array<std::size_t, 2> padding = {{0, 0}};
std::array<std::size_t, 2> stride = {{1, 1}}; std::array<std::size_t, 2> stride = {{1, 1}};
std::array<std::size_t, 2> lengths = {{1, 1}}; std::array<std::size_t, 2> lengths = {{1, 1}};
padding_mode_t padding_mode = default_;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
{ {
return pack(f(self.mode, "mode"), return pack(f(self.mode, "mode"),
f(self.padding, "padding"), f(self.padding, "padding"),
f(self.padding, "padding_mode"),
f(self.stride, "stride"), f(self.stride, "stride"),
f(self.lengths, "lengths")); f(self.lengths, "lengths"));
} }
...@@ -211,6 +215,8 @@ struct pooling ...@@ -211,6 +215,8 @@ struct pooling
assert(lengths[0] <= (input.lens()[2] + 2 * padding[0])); assert(lengths[0] <= (input.lens()[2] + 2 * padding[0]));
assert(lengths[1] <= (input.lens()[3] + 2 * padding[1])); assert(lengths[1] <= (input.lens()[3] + 2 * padding[1]));
if(padding_mode == default_)
{
return {t, return {t,
{ {
input.lens()[0], input.lens()[0],
...@@ -227,6 +233,39 @@ struct pooling ...@@ -227,6 +233,39 @@ struct pooling
1)), 1)),
}}; }};
} }
else if(padding_mode == same)
{
return {t,
{input.lens()[0],
input.lens()[1],
static_cast<std::size_t>(
std::ceil(static_cast<double>(input.lens()[2]) / stride[0])),
static_cast<std::size_t>(
std::ceil(static_cast<double>(input.lens()[3]) / stride[1]))}};
}
else if(padding_mode == valid)
{
return {t,
{
input.lens()[0],
input.lens()[1],
std::size_t(std::max<std::ptrdiff_t>(
1,
std::ptrdiff_t(std::floor((input.lens()[2] - lengths[0]) /
static_cast<float>(stride[0]))) +
1)),
std::size_t(std::max<std::ptrdiff_t>(
1,
std::ptrdiff_t(std::floor((input.lens()[3] - lengths[1]) /
static_cast<float>(stride[1]))) +
1)),
}};
}
else
{
MIGRAPHX_THROW("Invalid padding mode");
}
}
}; };
struct leaky_relu struct leaky_relu
......
...@@ -251,7 +251,7 @@ struct onnx_parser ...@@ -251,7 +251,7 @@ struct onnx_parser
if(s.find("SAME") != std::string::npos) if(s.find("SAME") != std::string::npos)
{ {
op.padding_mode = op::convolution::same; op.padding_mode = op::padding_mode_t::same;
} }
} }
if(contains(attributes, "group")) if(contains(attributes, "group"))
......
...@@ -43,7 +43,7 @@ struct tf_parser ...@@ -43,7 +43,7 @@ struct tf_parser
add_binary_op("BiasAdd", op::add{}); add_binary_op("BiasAdd", op::add{});
// add_mem_op("AvgPool", &tf_parser::parse_pooling); add_mem_op("AvgPool", &tf_parser::parse_pooling);
// add_mem_op("ConcatV2", &tf_parser::parse_concat); // add_mem_op("ConcatV2", &tf_parser::parse_concat);
add_mem_op("Const", &tf_parser::parse_constant); add_mem_op("Const", &tf_parser::parse_constant);
add_mem_op("Conv2D", &tf_parser::parse_conv); add_mem_op("Conv2D", &tf_parser::parse_conv);
...@@ -176,12 +176,12 @@ struct tf_parser ...@@ -176,12 +176,12 @@ struct tf_parser
const std::string& pad_mode = attributes.at("padding").s(); const std::string& pad_mode = attributes.at("padding").s();
if(pad_mode.find("SAME") != std::string::npos) if(pad_mode.find("SAME") != std::string::npos)
{ {
op.padding_mode = op::convolution::same; op.padding_mode = op::padding_mode_t::same;
} }
else if(pad_mode.find("EXPLICIT") != std::string::npos) else if(pad_mode.find("EXPLICIT") != std::string::npos)
{ {
std::vector<std::size_t> padding(4); std::vector<std::size_t> padding;
copy(attributes.at("explicit_paddings").list().i(), padding.begin()); copy(attributes.at("explicit_paddings").list().i(), std::back_inserter(padding));
if(padding.size() != 4) if(padding.size() != 4)
{ {
MIGRAPHX_THROW("padding should have 4 values"); MIGRAPHX_THROW("padding should have 4 values");
...@@ -196,75 +196,110 @@ struct tf_parser ...@@ -196,75 +196,110 @@ struct tf_parser
} }
if(contains(attributes, "strides")) if(contains(attributes, "strides"))
{ {
std::vector<std::size_t> stride(4); std::vector<std::size_t> stride;
copy(attributes.at("strides").list().i(), stride.begin()); copy(attributes.at("strides").list().i(), std::back_inserter(stride));
if(stride.size() != 4) if(stride.size() != 4)
{ {
MIGRAPHX_THROW("stride should have 4 values"); MIGRAPHX_THROW("strides should have 4 values");
} }
op.stride[0] = stride[0]; if(is_nhwc)
{
op.stride[0] = stride[1];
op.stride[1] = stride[2];
}
else
{
op.stride[0] = stride[2];
op.stride[1] = stride[3]; op.stride[1] = stride[3];
op.stride[2] = stride[1]; }
op.stride[3] = stride[2];
} }
if(contains(attributes, "dilations")) if(contains(attributes, "dilations"))
{ {
std::vector<std::size_t> dilation(4); std::vector<std::size_t> dilation;
copy(attributes.at("dilations").list().i(), dilation.begin()); copy(attributes.at("dilations").list().i(), std::back_inserter(dilation));
if(dilation.size() != 4) if(dilation.size() != 4)
{ {
MIGRAPHX_THROW("dilation should have 4 values"); MIGRAPHX_THROW("dilation should have 4 values");
} }
op.dilation[0] = dilation[0]; if(is_nhwc)
{
op.dilation[0] = dilation[1];
op.dilation[1] = dilation[2];
}
else
{
op.dilation[0] = dilation[2];
op.dilation[1] = dilation[3]; op.dilation[1] = dilation[3];
op.dilation[2] = dilation[1]; }
op.dilation[3] = dilation[2];
} }
auto l0 = prog.add_instruction(op::transpose{{2, 3, 0, 1}}, args[1]); auto l0 = prog.add_instruction(op::transpose{{2, 3, 0, 1}}, args[1]);
return prog.add_instruction(op, {args[0], l0}); return prog.add_instruction(op, {args[0], l0});
} }
// instruction_ref parse_pooling(const std::string& name, instruction_ref parse_pooling(const std::string& name,
// attribute_map attributes, attribute_map attributes,
// std::vector<instruction_ref> args) std::vector<instruction_ref> args)
// { {
// op::pooling op{starts_with(name, "Max") ? "max" : "average"}; op::pooling op{starts_with(name, "Max") ? "max" : "average"};
// if(contains(attributes, "pads")) if(contains(attributes, "padding"))
// { {
// std::vector<std::size_t> padding(4); const std::string& pad_mode = attributes.at("padding").s();
// copy(attributes["pads"].ints(), padding.begin()); if(pad_mode.find("SAME") != std::string::npos)
// if(padding.size() != 4) {
// { op.padding_mode = op::padding_mode_t::same;
// MIGRAPHX_THROW("padding should have 4 values"); }
// } else if(pad_mode.find("VALID") != std::string::npos)
// if(padding[0] != padding[2] || padding[1] != padding[3]) {
// { op.padding_mode = op::padding_mode_t::valid;
// MIGRAPHX_THROW("migraphx does not support asymetric padding"); }
// } }
// op.padding[0] = padding[0]; if(contains(attributes, "strides"))
// op.padding[1] = padding[1]; {
// } std::vector<std::size_t> stride;
// if(contains(attributes, "strides")) copy(attributes.at("stride").list().i(), std::back_inserter(stride));
// { if(stride.size() != 4)
// copy(attributes["strides"].ints(), op.stride.begin()); {
// } MIGRAPHX_THROW("strides should have 4 values");
// if(contains(attributes, "kernel_shape")) }
// { if(is_nhwc)
// copy(attributes["kernel_shape"].ints(), op.lengths.begin()); {
// } op.stride[0] = stride[1];
// if(contains(attributes, "auto_pad")) op.stride[1] = stride[2];
// { }
// auto s = attributes["auto_pad"].s(); else
// if(to_upper(s) != "NOTSET") {
// { op.stride[0] = stride[2];
// MIGRAPHX_THROW("auto_pad is not supported for pooling"); op.stride[1] = stride[3];
// } }
// }
}
// return prog.add_instruction(op, std::move(args)); if(contains(attributes, "ksize"))
// } {
std::vector<std::size_t> ksize;
copy(attributes.at("ksize").list().i(), std::back_inserter(ksize));
if(ksize.size() != 4)
{
MIGRAPHX_THROW("ksize should have 4 values");
}
if(is_nhwc)
{
op.lengths[0] = ksize[1];
op.lengths[1] = ksize[2];
}
else
{
op.lengths[0] = ksize[2];
op.lengths[1] = ksize[3];
}
}
return prog.add_instruction(op, std::move(args));
}
void parse_from(std::istream& is) void parse_from(std::istream& is)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment