"vscode:/vscode.git/clone" did not exist on "9eca5cbee221399c7abfbc003fd98ec9dd82884d"
Commit 2d666954 authored by Paul's avatar Paul
Browse files

Fix parse constant

parent cdf96caf
...@@ -43,28 +43,28 @@ struct tf_parser ...@@ -43,28 +43,28 @@ struct tf_parser
instruction_ref to_nhwc(instruction_ref ins) instruction_ref to_nhwc(instruction_ref ins)
{ {
if (should_transpose(ins)) if(should_transpose(ins))
return prog.add_instruction(op::transpose{{0, 2, 3, 1}}, ins); return prog.add_instruction(op::transpose{{0, 2, 3, 1}}, ins);
return ins; return ins;
} }
instruction_ref to_nchw(instruction_ref ins) instruction_ref to_nchw(instruction_ref ins)
{ {
if (should_transpose(ins)) if(should_transpose(ins))
return prog.add_instruction(op::transpose{{0, 3, 1, 2}}, ins); return prog.add_instruction(op::transpose{{0, 3, 1, 2}}, ins);
return ins; return ins;
} }
instruction_ref to_kcxy(instruction_ref ins) instruction_ref to_kcxy(instruction_ref ins)
{ {
if (should_transpose(ins)) if(should_transpose(ins))
return prog.add_instruction(op::transpose{{3, 2, 0, 1}}, ins); return prog.add_instruction(op::transpose{{3, 2, 0, 1}}, ins);
return ins; return ins;
} }
instruction_ref make_contiguous(instruction_ref ins) instruction_ref make_contiguous(instruction_ref ins)
{ {
if (ins->get_shape().standard()) if(ins->get_shape().standard())
return ins; return ins;
else else
return prog.add_instruction(op::contiguous{}, ins); return prog.add_instruction(op::contiguous{}, ins);
...@@ -73,9 +73,8 @@ struct tf_parser ...@@ -73,9 +73,8 @@ struct tf_parser
std::vector<instruction_ref> to_nchw(const std::vector<instruction_ref>& args) std::vector<instruction_ref> to_nchw(const std::vector<instruction_ref>& args)
{ {
std::vector<instruction_ref> result(args.size()); std::vector<instruction_ref> result(args.size());
std::transform(args.begin(), args.end(), result.begin(), [&](auto ins) { std::transform(
return to_nchw(ins); args.begin(), args.end(), result.begin(), [&](auto ins) { return to_nchw(ins); });
});
return result; return result;
} }
...@@ -161,7 +160,7 @@ struct tf_parser ...@@ -161,7 +160,7 @@ struct tf_parser
add_mem_op("BiasAdd", &tf_parser::parse_biasadd); add_mem_op("BiasAdd", &tf_parser::parse_biasadd);
add_mem_op("ConcatV2", &tf_parser::parse_concat, false); add_mem_op("ConcatV2", &tf_parser::parse_concat, false);
add_mem_op("Const", &tf_parser::parse_constant); add_mem_op("Const", &tf_parser::parse_constant);
add_mem_op("Conv2D", &tf_parser::parse_conv, false); add_mem_op("Conv2D", &tf_parser::parse_conv);
add_mem_op("DepthwiseConv2dNative", &tf_parser::parse_depthwiseconv, false); add_mem_op("DepthwiseConv2dNative", &tf_parser::parse_depthwiseconv, false);
add_mem_op("FusedBatchNorm", &tf_parser::parse_batchnorm); add_mem_op("FusedBatchNorm", &tf_parser::parse_batchnorm);
add_mem_op("MatMul", &tf_parser::parse_matmul, false); add_mem_op("MatMul", &tf_parser::parse_matmul, false);
...@@ -176,13 +175,15 @@ struct tf_parser ...@@ -176,13 +175,15 @@ struct tf_parser
} }
template <class F> template <class F>
void add_op(std::string name, F f, bool transpose=true) void add_op(std::string name, F f, bool transpose = true)
{ {
if (transpose) if(transpose)
{ {
ops.emplace(name, op_func{[=](const attribute_map& attributes, std::vector<instruction_ref> args) -> instruction_ref { ops.emplace(name,
return to_nhwc(f(attributes, to_nchw(args))); op_func{[=](const attribute_map& attributes,
}}); std::vector<instruction_ref> args) -> instruction_ref {
return to_nhwc(f(attributes, to_nchw(args)));
}});
} }
else else
{ {
...@@ -191,11 +192,13 @@ struct tf_parser ...@@ -191,11 +192,13 @@ struct tf_parser
} }
template <class F> template <class F>
void add_mem_op(std::string name, F f, bool transpose=true) void add_mem_op(std::string name, F f, bool transpose = true)
{ {
add_op(name, [=](auto&&... xs) { add_op(name,
return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...); [=](auto&&... xs) {
}, transpose); return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...);
},
transpose);
} }
template <class T> template <class T>
...@@ -261,11 +264,13 @@ struct tf_parser ...@@ -261,11 +264,13 @@ struct tf_parser
} }
template <class T> template <class T>
void add_generic_op(std::string name, T x, bool transpose=true) void add_generic_op(std::string name, T x, bool transpose = true)
{ {
add_op(name, [this, x](const attribute_map&, std::vector<instruction_ref> args) { add_op(name,
return prog.add_instruction(x, args); [this, x](const attribute_map&, std::vector<instruction_ref> args) {
}, transpose); return prog.add_instruction(x, args);
},
transpose);
} }
instruction_ref instruction_ref
...@@ -307,15 +312,7 @@ struct tf_parser ...@@ -307,15 +312,7 @@ struct tf_parser
const std::vector<instruction_ref>&) const std::vector<instruction_ref>&)
{ {
literal v = parse_tensor(attributes.at("value").tensor()); literal v = parse_tensor(attributes.at("value").tensor());
auto l0 = prog.add_literal(v); return prog.add_literal(v);
size_t num_axes = l0->get_shape().lens().size();
if(num_axes >= 4)
{
std::vector<int64_t> transpose_axes = get_axes(num_axes);
reorder_data(transpose_axes);
l0 = prog.add_instruction(op::transpose{transpose_axes}, l0);
}
return l0;
} }
instruction_ref instruction_ref
...@@ -369,7 +366,7 @@ struct tf_parser ...@@ -369,7 +366,7 @@ struct tf_parser
op.dilation[0] = dilation[2]; op.dilation[0] = dilation[2];
op.dilation[1] = dilation[3]; op.dilation[1] = dilation[3];
} }
return prog.add_instruction(op, {to_nchw(args[0]), to_kcxy(to_nchw(args[1]))}); return prog.add_instruction(op, {args[0], to_kcxy(args[1])});
} }
instruction_ref parse_depthwiseconv(const std::string&, instruction_ref parse_depthwiseconv(const std::string&,
...@@ -487,7 +484,8 @@ struct tf_parser ...@@ -487,7 +484,8 @@ struct tf_parser
args.end(), args.end(),
std::back_inserter(unsqueezed_args), std::back_inserter(unsqueezed_args),
[&](instruction_ref arg) { return prog.add_instruction(op::unsqueeze{{axis}}, arg); }); [&](instruction_ref arg) { return prog.add_instruction(op::unsqueeze{{axis}}, arg); });
return to_nhwc(prog.add_instruction(op::concat{static_cast<size_t>(axis)}, unsqueezed_args)); return to_nhwc(
prog.add_instruction(op::concat{static_cast<size_t>(axis)}, unsqueezed_args));
} }
instruction_ref instruction_ref
...@@ -514,7 +512,7 @@ struct tf_parser ...@@ -514,7 +512,7 @@ struct tf_parser
pads[i + ndims] = pad_per_dim[i].second; pads[i + ndims] = pad_per_dim[i].second;
} }
op.pads = pads; op.pads = pads;
return prog.add_instruction(op, args.front()); return to_nhwc(prog.add_instruction(op, args.front()));
} }
instruction_ref parse_pooling(const std::string& name, instruction_ref parse_pooling(const std::string& name,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment