Commit 5e59bf45 authored by Khalique's avatar Khalique
Browse files

continued progress on tf, working on pool op

parent 39496181
......@@ -16,6 +16,13 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
enum padding_mode_t
{
default_, // NOLINT
same,
valid
};
struct not_computable
{
argument compute(const shape&, const std::vector<argument>&) const
......@@ -58,12 +65,7 @@ struct convolution
std::array<std::size_t, 2> padding = {{0, 0}};
std::array<std::size_t, 2> stride = {{1, 1}};
std::array<std::size_t, 2> dilation = {{1, 1}};
enum padding_mode_t
{
default_, // NOLINT
same,
valid
};
padding_mode_t padding_mode = default_;
int group = 1;
......@@ -189,12 +191,14 @@ struct pooling
std::array<std::size_t, 2> padding = {{0, 0}};
std::array<std::size_t, 2> stride = {{1, 1}};
std::array<std::size_t, 2> lengths = {{1, 1}};
padding_mode_t padding_mode = default_;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.mode, "mode"),
f(self.padding, "padding"),
f(self.padding, "padding_mode"),
f(self.stride, "stride"),
f(self.lengths, "lengths"));
}
......@@ -211,6 +215,8 @@ struct pooling
assert(lengths[0] <= (input.lens()[2] + 2 * padding[0]));
assert(lengths[1] <= (input.lens()[3] + 2 * padding[1]));
if(padding_mode == default_)
{
return {t,
{
input.lens()[0],
......@@ -227,6 +233,39 @@ struct pooling
1)),
}};
}
else if(padding_mode == same)
{
return {t,
{input.lens()[0],
input.lens()[1],
static_cast<std::size_t>(
std::ceil(static_cast<double>(input.lens()[2]) / stride[0])),
static_cast<std::size_t>(
std::ceil(static_cast<double>(input.lens()[3]) / stride[1]))}};
}
else if(padding_mode == valid)
{
return {t,
{
input.lens()[0],
input.lens()[1],
std::size_t(std::max<std::ptrdiff_t>(
1,
std::ptrdiff_t(std::floor((input.lens()[2] - lengths[0]) /
static_cast<float>(stride[0]))) +
1)),
std::size_t(std::max<std::ptrdiff_t>(
1,
std::ptrdiff_t(std::floor((input.lens()[3] - lengths[1]) /
static_cast<float>(stride[1]))) +
1)),
}};
}
else
{
MIGRAPHX_THROW("Invalid padding mode");
}
}
};
struct leaky_relu
......
......@@ -251,7 +251,7 @@ struct onnx_parser
if(s.find("SAME") != std::string::npos)
{
op.padding_mode = op::convolution::same;
op.padding_mode = op::padding_mode_t::same;
}
}
if(contains(attributes, "group"))
......
......@@ -43,7 +43,7 @@ struct tf_parser
add_binary_op("BiasAdd", op::add{});
// add_mem_op("AvgPool", &tf_parser::parse_pooling);
add_mem_op("AvgPool", &tf_parser::parse_pooling);
// add_mem_op("ConcatV2", &tf_parser::parse_concat);
add_mem_op("Const", &tf_parser::parse_constant);
add_mem_op("Conv2D", &tf_parser::parse_conv);
......@@ -176,12 +176,12 @@ struct tf_parser
const std::string& pad_mode = attributes.at("padding").s();
if(pad_mode.find("SAME") != std::string::npos)
{
op.padding_mode = op::convolution::same;
op.padding_mode = op::padding_mode_t::same;
}
else if(pad_mode.find("EXPLICIT") != std::string::npos)
{
std::vector<std::size_t> padding(4);
copy(attributes.at("explicit_paddings").list().i(), padding.begin());
std::vector<std::size_t> padding;
copy(attributes.at("explicit_paddings").list().i(), std::back_inserter(padding));
if(padding.size() != 4)
{
MIGRAPHX_THROW("padding should have 4 values");
......@@ -196,75 +196,110 @@ struct tf_parser
}
if(contains(attributes, "strides"))
{
std::vector<std::size_t> stride(4);
copy(attributes.at("strides").list().i(), stride.begin());
std::vector<std::size_t> stride;
copy(attributes.at("strides").list().i(), std::back_inserter(stride));
if(stride.size() != 4)
{
MIGRAPHX_THROW("stride should have 4 values");
MIGRAPHX_THROW("strides should have 4 values");
}
op.stride[0] = stride[0];
if(is_nhwc)
{
op.stride[0] = stride[1];
op.stride[1] = stride[2];
}
else
{
op.stride[0] = stride[2];
op.stride[1] = stride[3];
op.stride[2] = stride[1];
op.stride[3] = stride[2];
}
}
if(contains(attributes, "dilations"))
{
std::vector<std::size_t> dilation(4);
copy(attributes.at("dilations").list().i(), dilation.begin());
std::vector<std::size_t> dilation;
copy(attributes.at("dilations").list().i(), std::back_inserter(dilation));
if(dilation.size() != 4)
{
MIGRAPHX_THROW("dilation should have 4 values");
}
op.dilation[0] = dilation[0];
if(is_nhwc)
{
op.dilation[0] = dilation[1];
op.dilation[1] = dilation[2];
}
else
{
op.dilation[0] = dilation[2];
op.dilation[1] = dilation[3];
op.dilation[2] = dilation[1];
op.dilation[3] = dilation[2];
}
}
auto l0 = prog.add_instruction(op::transpose{{2, 3, 0, 1}}, args[1]);
return prog.add_instruction(op, {args[0], l0});
}
// instruction_ref parse_pooling(const std::string& name,
// attribute_map attributes,
// std::vector<instruction_ref> args)
// {
// op::pooling op{starts_with(name, "Max") ? "max" : "average"};
// if(contains(attributes, "pads"))
// {
// std::vector<std::size_t> padding(4);
// copy(attributes["pads"].ints(), padding.begin());
// if(padding.size() != 4)
// {
// MIGRAPHX_THROW("padding should have 4 values");
// }
// if(padding[0] != padding[2] || padding[1] != padding[3])
// {
// MIGRAPHX_THROW("migraphx does not support asymetric padding");
// }
// op.padding[0] = padding[0];
// op.padding[1] = padding[1];
// }
// if(contains(attributes, "strides"))
// {
// copy(attributes["strides"].ints(), op.stride.begin());
// }
// if(contains(attributes, "kernel_shape"))
// {
// copy(attributes["kernel_shape"].ints(), op.lengths.begin());
// }
// if(contains(attributes, "auto_pad"))
// {
// auto s = attributes["auto_pad"].s();
// if(to_upper(s) != "NOTSET")
// {
// MIGRAPHX_THROW("auto_pad is not supported for pooling");
// }
// }
// return prog.add_instruction(op, std::move(args));
// }
instruction_ref parse_pooling(const std::string& name,
attribute_map attributes,
std::vector<instruction_ref> args)
{
op::pooling op{starts_with(name, "Max") ? "max" : "average"};
if(contains(attributes, "padding"))
{
const std::string& pad_mode = attributes.at("padding").s();
if(pad_mode.find("SAME") != std::string::npos)
{
op.padding_mode = op::padding_mode_t::same;
}
else if(pad_mode.find("VALID") != std::string::npos)
{
op.padding_mode = op::padding_mode_t::valid;
}
}
if(contains(attributes, "strides"))
{
std::vector<std::size_t> stride;
copy(attributes.at("stride").list().i(), std::back_inserter(stride));
if(stride.size() != 4)
{
MIGRAPHX_THROW("strides should have 4 values");
}
if(is_nhwc)
{
op.stride[0] = stride[1];
op.stride[1] = stride[2];
}
else
{
op.stride[0] = stride[2];
op.stride[1] = stride[3];
}
}
if(contains(attributes, "ksize"))
{
std::vector<std::size_t> ksize;
copy(attributes.at("ksize").list().i(), std::back_inserter(ksize));
if(ksize.size() != 4)
{
MIGRAPHX_THROW("ksize should have 4 values");
}
if(is_nhwc)
{
op.lengths[0] = ksize[1];
op.lengths[1] = ksize[2];
}
else
{
op.lengths[0] = ksize[2];
op.lengths[1] = ksize[3];
}
}
return prog.add_instruction(op, std::move(args));
}
void parse_from(std::istream& is)
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment