Commit 634ea0f2 authored by Khalique's avatar Khalique
Browse files

continued tf pb progress, adjusting dims for conv

parent b12844ec
...@@ -7,7 +7,7 @@ int main(int argc, char const* argv[]) ...@@ -7,7 +7,7 @@ int main(int argc, char const* argv[])
bool is_nhwc = true; bool is_nhwc = true;
if(argc > 2) if(argc > 2)
{ {
if(argv[2] == "nchw") if(strcmp(argv[2], "nchw") == 0)
is_nhwc = false; is_nhwc = false;
} }
std::string file = argv[1]; std::string file = argv[1];
......
...@@ -50,10 +50,23 @@ struct tf_parser ...@@ -50,10 +50,23 @@ struct tf_parser
add_mem_op("FusedBatchNorm", &tf_parser::parse_batchnorm); add_mem_op("FusedBatchNorm", &tf_parser::parse_batchnorm);
} }
template <class F>
void add_op(std::string name, F f)
{
ops.emplace(name, f);
}
// Multi output op
template <class F>
void add_multi_op(std::string name, F f)
{
ops.emplace(name, f);
}
template <class F> template <class F>
void add_mem_op(std::string name, F f) void add_mem_op(std::string name, F f)
{ {
ops.emplace(name, [=](auto&&... xs) { add_op(name, [=](auto&&... xs) {
return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...); return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...);
}); });
} }
...@@ -61,7 +74,7 @@ struct tf_parser ...@@ -61,7 +74,7 @@ struct tf_parser
template <class T> template <class T>
void add_binary_op(std::string name, T x) void add_binary_op(std::string name, T x)
{ {
ops.emplace(name, [this, x](attribute_map, std::vector<instruction_ref> args) { add_op(name, [this, x](attribute_map, std::vector<instruction_ref> args) {
if(args.size() != 2) if(args.size() != 2)
MIGRAPHX_THROW("binary operators should have 2 operands"); MIGRAPHX_THROW("binary operators should have 2 operands");
return add_broadcastable_binary_op(args[0], args[1], x); return add_broadcastable_binary_op(args[0], args[1], x);
...@@ -115,7 +128,7 @@ struct tf_parser ...@@ -115,7 +128,7 @@ struct tf_parser
template <class T> template <class T>
void add_generic_op(std::string name, T x) void add_generic_op(std::string name, T x)
{ {
ops.emplace(name, [this, x](attribute_map, std::vector<instruction_ref> args) { add_op(name, [this, x](attribute_map, std::vector<instruction_ref> args) {
return prog.add_instruction(x, args); return prog.add_instruction(x, args);
}); });
} }
...@@ -125,7 +138,7 @@ struct tf_parser ...@@ -125,7 +138,7 @@ struct tf_parser
{ {
float epsilon = 1e-4f; float epsilon = 1e-4f;
float momentum = 1.f; float momentum = 1.f;
op::batch_norm_inference::bn_infer_mode_t bn_mode = op::batch_norm_inference::spatial; op::batch_norm_inference::bn_infer_mode_t bn_mode = op::batch_norm_inference::per_activation;
if(contains(attributes, "epsilon")) if(contains(attributes, "epsilon"))
{ {
epsilon = attributes.at("epsilon").f(); epsilon = attributes.at("epsilon").f();
...@@ -182,16 +195,32 @@ struct tf_parser ...@@ -182,16 +195,32 @@ struct tf_parser
} }
if(contains(attributes, "strides")) if(contains(attributes, "strides"))
{ {
copy(attributes.at("strides").list().i(), op.stride.begin()); std::vector<std::size_t> stride(4);
copy(attributes.at("strides").list().i(), stride.begin());
if(stride.size() != 4)
{
MIGRAPHX_THROW("stride should have 4 values");
}
op.stride[0] = stride[0];
op.stride[1] = stride[3];
op.stride[2] = stride[1];
op.stride[3] = stride[2];
} }
if(contains(attributes, "dilations")) if(contains(attributes, "dilations"))
{ {
copy(attributes.at("dilations").list().i(), op.dilation.begin()); std::vector<std::size_t> dilation(4);
copy(attributes.at("dilations").list().i(), dilation.begin());
if(dilation.size() != 4)
{
MIGRAPHX_THROW("dilation should have 4 values");
}
op.dilation[0] = dilation[0];
op.dilation[1] = dilation[3];
op.dilation[2] = dilation[1];
op.dilation[3] = dilation[2];
} }
auto l0 = args[1]; auto l0 = prog.add_instruction(op::transpose{{2, 3, 0, 1}}, args[1]);
if(is_nhwc)
l0 = prog.add_instruction(op::transpose{{0, 3, 1, 2}}, l0);
return prog.add_instruction(op, {args[0], l0}); return prog.add_instruction(op, {args[0], l0});
} }
...@@ -245,7 +274,7 @@ struct tf_parser ...@@ -245,7 +274,7 @@ struct tf_parser
} }
else else
{ {
throw std::runtime_error("Failed reading"); throw std::runtime_error("Failed reading tf file");
} }
} }
...@@ -268,7 +297,7 @@ struct tf_parser ...@@ -268,7 +297,7 @@ struct tf_parser
} }
for(auto&& p : nodes) for(auto&& p : nodes)
{ {
this->parse_node(get_name(p.second)); this->parse_node(p.first);
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment