"test/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "c5d7524042cd82e975677cbd3185e0e7acfc19b3"
Commit 2d666954 authored by Paul's avatar Paul
Browse files

Fix parse constant

parent cdf96caf
......@@ -43,28 +43,28 @@ struct tf_parser
instruction_ref to_nhwc(instruction_ref ins)
{
if (should_transpose(ins))
if(should_transpose(ins))
return prog.add_instruction(op::transpose{{0, 2, 3, 1}}, ins);
return ins;
}
instruction_ref to_nchw(instruction_ref ins)
{
if (should_transpose(ins))
if(should_transpose(ins))
return prog.add_instruction(op::transpose{{0, 3, 1, 2}}, ins);
return ins;
}
instruction_ref to_kcxy(instruction_ref ins)
{
if (should_transpose(ins))
if(should_transpose(ins))
return prog.add_instruction(op::transpose{{3, 2, 0, 1}}, ins);
return ins;
}
instruction_ref make_contiguous(instruction_ref ins)
{
if (ins->get_shape().standard())
if(ins->get_shape().standard())
return ins;
else
return prog.add_instruction(op::contiguous{}, ins);
......@@ -73,9 +73,8 @@ struct tf_parser
std::vector<instruction_ref> to_nchw(const std::vector<instruction_ref>& args)
{
std::vector<instruction_ref> result(args.size());
std::transform(args.begin(), args.end(), result.begin(), [&](auto ins) {
return to_nchw(ins);
});
std::transform(
args.begin(), args.end(), result.begin(), [&](auto ins) { return to_nchw(ins); });
return result;
}
......@@ -161,7 +160,7 @@ struct tf_parser
add_mem_op("BiasAdd", &tf_parser::parse_biasadd);
add_mem_op("ConcatV2", &tf_parser::parse_concat, false);
add_mem_op("Const", &tf_parser::parse_constant);
add_mem_op("Conv2D", &tf_parser::parse_conv, false);
add_mem_op("Conv2D", &tf_parser::parse_conv);
add_mem_op("DepthwiseConv2dNative", &tf_parser::parse_depthwiseconv, false);
add_mem_op("FusedBatchNorm", &tf_parser::parse_batchnorm);
add_mem_op("MatMul", &tf_parser::parse_matmul, false);
......@@ -176,13 +175,15 @@ struct tf_parser
}
template <class F>
void add_op(std::string name, F f, bool transpose=true)
void add_op(std::string name, F f, bool transpose = true)
{
if (transpose)
if(transpose)
{
ops.emplace(name, op_func{[=](const attribute_map& attributes, std::vector<instruction_ref> args) -> instruction_ref {
return to_nhwc(f(attributes, to_nchw(args)));
}});
ops.emplace(name,
op_func{[=](const attribute_map& attributes,
std::vector<instruction_ref> args) -> instruction_ref {
return to_nhwc(f(attributes, to_nchw(args)));
}});
}
else
{
......@@ -191,11 +192,13 @@ struct tf_parser
}
template <class F>
void add_mem_op(std::string name, F f, bool transpose=true)
void add_mem_op(std::string name, F f, bool transpose = true)
{
add_op(name, [=](auto&&... xs) {
return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...);
}, transpose);
add_op(name,
[=](auto&&... xs) {
return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...);
},
transpose);
}
template <class T>
......@@ -261,11 +264,13 @@ struct tf_parser
}
template <class T>
void add_generic_op(std::string name, T x, bool transpose=true)
void add_generic_op(std::string name, T x, bool transpose = true)
{
add_op(name, [this, x](const attribute_map&, std::vector<instruction_ref> args) {
return prog.add_instruction(x, args);
}, transpose);
add_op(name,
[this, x](const attribute_map&, std::vector<instruction_ref> args) {
return prog.add_instruction(x, args);
},
transpose);
}
instruction_ref
......@@ -307,15 +312,7 @@ struct tf_parser
const std::vector<instruction_ref>&)
{
literal v = parse_tensor(attributes.at("value").tensor());
auto l0 = prog.add_literal(v);
size_t num_axes = l0->get_shape().lens().size();
if(num_axes >= 4)
{
std::vector<int64_t> transpose_axes = get_axes(num_axes);
reorder_data(transpose_axes);
l0 = prog.add_instruction(op::transpose{transpose_axes}, l0);
}
return l0;
return prog.add_literal(v);
}
instruction_ref
......@@ -369,7 +366,7 @@ struct tf_parser
op.dilation[0] = dilation[2];
op.dilation[1] = dilation[3];
}
return prog.add_instruction(op, {to_nchw(args[0]), to_kcxy(to_nchw(args[1]))});
return prog.add_instruction(op, {args[0], to_kcxy(args[1])});
}
instruction_ref parse_depthwiseconv(const std::string&,
......@@ -487,7 +484,8 @@ struct tf_parser
args.end(),
std::back_inserter(unsqueezed_args),
[&](instruction_ref arg) { return prog.add_instruction(op::unsqueeze{{axis}}, arg); });
return to_nhwc(prog.add_instruction(op::concat{static_cast<size_t>(axis)}, unsqueezed_args));
return to_nhwc(
prog.add_instruction(op::concat{static_cast<size_t>(axis)}, unsqueezed_args));
}
instruction_ref
......@@ -514,7 +512,7 @@ struct tf_parser
pads[i + ndims] = pad_per_dim[i].second;
}
op.pads = pads;
return prog.add_instruction(op, args.front());
return to_nhwc(prog.add_instruction(op, args.front()));
}
instruction_ref parse_pooling(const std::string& name,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment