Commit 634ea0f2 authored by Khalique's avatar Khalique
Browse files

continued tf pb progress, adjusting dims for conv

parent b12844ec
......@@ -7,7 +7,7 @@ int main(int argc, char const* argv[])
bool is_nhwc = true;
if(argc > 2)
{
if(argv[2] == "nchw")
if(strcmp(argv[2], "nchw") == 0)
is_nhwc = false;
}
std::string file = argv[1];
......
......@@ -50,10 +50,23 @@ struct tf_parser
add_mem_op("FusedBatchNorm", &tf_parser::parse_batchnorm);
}
template <class F>
void add_op(std::string name, F f)
{
ops.emplace(name, f);
}
// Multi output op
template <class F>
void add_multi_op(std::string name, F f)
{
ops.emplace(name, f);
}
template <class F>
void add_mem_op(std::string name, F f)
{
ops.emplace(name, [=](auto&&... xs) {
add_op(name, [=](auto&&... xs) {
return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...);
});
}
......@@ -61,7 +74,7 @@ struct tf_parser
template <class T>
void add_binary_op(std::string name, T x)
{
ops.emplace(name, [this, x](attribute_map, std::vector<instruction_ref> args) {
add_op(name, [this, x](attribute_map, std::vector<instruction_ref> args) {
if(args.size() != 2)
MIGRAPHX_THROW("binary operators should have 2 operands");
return add_broadcastable_binary_op(args[0], args[1], x);
......@@ -115,7 +128,7 @@ struct tf_parser
template <class T>
void add_generic_op(std::string name, T x)
{
ops.emplace(name, [this, x](attribute_map, std::vector<instruction_ref> args) {
add_op(name, [this, x](attribute_map, std::vector<instruction_ref> args) {
return prog.add_instruction(x, args);
});
}
......@@ -125,7 +138,7 @@ struct tf_parser
{
float epsilon = 1e-4f;
float momentum = 1.f;
op::batch_norm_inference::bn_infer_mode_t bn_mode = op::batch_norm_inference::spatial;
op::batch_norm_inference::bn_infer_mode_t bn_mode = op::batch_norm_inference::per_activation;
if(contains(attributes, "epsilon"))
{
epsilon = attributes.at("epsilon").f();
......@@ -182,16 +195,32 @@ struct tf_parser
}
if(contains(attributes, "strides"))
{
copy(attributes.at("strides").list().i(), op.stride.begin());
std::vector<std::size_t> stride(4);
copy(attributes.at("strides").list().i(), stride.begin());
if(stride.size() != 4)
{
MIGRAPHX_THROW("stride should have 4 values");
}
op.stride[0] = stride[0];
op.stride[1] = stride[3];
op.stride[2] = stride[1];
op.stride[3] = stride[2];
}
if(contains(attributes, "dilations"))
{
copy(attributes.at("dilations").list().i(), op.dilation.begin());
std::vector<std::size_t> dilation(4);
copy(attributes.at("dilations").list().i(), dilation.begin());
if(dilation.size() != 4)
{
MIGRAPHX_THROW("dilation should have 4 values");
}
op.dilation[0] = dilation[0];
op.dilation[1] = dilation[3];
op.dilation[2] = dilation[1];
op.dilation[3] = dilation[2];
}
auto l0 = args[1];
if(is_nhwc)
l0 = prog.add_instruction(op::transpose{{0, 3, 1, 2}}, l0);
auto l0 = prog.add_instruction(op::transpose{{2, 3, 0, 1}}, args[1]);
return prog.add_instruction(op, {args[0], l0});
}
......@@ -245,7 +274,7 @@ struct tf_parser
}
else
{
throw std::runtime_error("Failed reading");
throw std::runtime_error("Failed reading tf file");
}
}
......@@ -268,7 +297,7 @@ struct tf_parser
}
for(auto&& p : nodes)
{
this->parse_node(get_name(p.second));
this->parse_node(p.first);
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment