Commit 557aafa0 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

merge changes from develop branch

parents 8de5bf9c d20bd8a1
...@@ -82,7 +82,6 @@ std::vector<T> generate_tensor_data(const migraphx::shape& s, unsigned long seed ...@@ -82,7 +82,6 @@ std::vector<T> generate_tensor_data(const migraphx::shape& s, unsigned long seed
{ {
std::vector<T> result(s.elements()); std::vector<T> result(s.elements());
std::generate(result.begin(), result.end(), xorshf96_generator<T>{seed}); std::generate(result.begin(), result.end(), xorshf96_generator<T>{seed});
// std::generate(result.begin(), result.end(), [&]{ return seed % 7; }); // std::generate(result.begin(), result.end(), [&]{ return seed % 7; });
// std::generate(result.begin(), result.end(), []{ return 1; }); // std::generate(result.begin(), result.end(), []{ return 1; });
return result; return result;
......
...@@ -48,7 +48,6 @@ struct dot ...@@ -48,7 +48,6 @@ struct dot
"} x {" + to_string_range(b.lens()) + "}"); "} x {" + to_string_range(b.lens()) + "}");
} }
// dims for batch should be standard
std::size_t dim_0 = a.lens().size() - 2; std::size_t dim_0 = a.lens().size() - 2;
std::size_t dim_1 = a.lens().size() - 1; std::size_t dim_1 = a.lens().size() - 1;
if(a.lens()[dim_1] != b.lens()[dim_0]) if(a.lens()[dim_1] != b.lens()[dim_0])
......
...@@ -66,8 +66,8 @@ struct onnx_parser ...@@ -66,8 +66,8 @@ struct onnx_parser
add_variadic_op("Max", op::max{}); add_variadic_op("Max", op::max{});
add_variadic_op("Min", op::min{}); add_variadic_op("Min", op::min{});
add_mem_op("ArgMax", &onnx_parser::parse_argmax); add_mem_op("ArgMax", &onnx_parser::parse_arg_op<op::argmax>);
add_mem_op("ArgMin", &onnx_parser::parse_argmin); add_mem_op("ArgMin", &onnx_parser::parse_arg_op<op::argmin>);
add_mem_op("Cast", &onnx_parser::parse_cast); add_mem_op("Cast", &onnx_parser::parse_cast);
add_mem_op("Clip", &onnx_parser::parse_clip); add_mem_op("Clip", &onnx_parser::parse_clip);
add_mem_op("LRN", &onnx_parser::parse_lrn); add_mem_op("LRN", &onnx_parser::parse_lrn);
...@@ -275,7 +275,8 @@ struct onnx_parser ...@@ -275,7 +275,8 @@ struct onnx_parser
return prog.add_instruction(Op{axis}, std::move(args)); return prog.add_instruction(Op{axis}, std::move(args));
} }
instruction_ref parse_argmax(const std::string&, template<class Op>
instruction_ref parse_arg_op(const std::string&,
const attribute_map& attributes, const attribute_map& attributes,
std::vector<instruction_ref> args) std::vector<instruction_ref> args)
{ {
...@@ -293,39 +294,12 @@ struct onnx_parser ...@@ -293,39 +294,12 @@ struct onnx_parser
if(keep_dims == 0) if(keep_dims == 0)
{ {
auto ins = prog.add_instruction(op::argmax{axis}, std::move(args)); auto ins = prog.add_instruction(Op{axis}, std::move(args));
return prog.add_instruction(op::squeeze{{axis}}, ins); return prog.add_instruction(op::squeeze{{axis}}, ins);
} }
else else
{ {
return prog.add_instruction(op::argmax{axis}, std::move(args)); return prog.add_instruction(Op{axis}, std::move(args));
}
}
instruction_ref parse_argmin(const std::string&,
const attribute_map& attributes,
std::vector<instruction_ref> args)
{
int64_t axis = 0;
if(contains(attributes, "axis"))
{
axis = static_cast<int64_t>(parse_value(attributes.at("axis")).at<int>());
}
int keep_dims = 1;
if(contains(attributes, "keepdims"))
{
keep_dims = parse_value(attributes.at("keepdims")).at<int>();
}
if(keep_dims == 0)
{
auto ins = prog.add_instruction(op::argmin{axis}, std::move(args));
return prog.add_instruction(op::squeeze{{axis}}, ins);
}
else
{
return prog.add_instruction(op::argmin{axis}, std::move(args));
} }
} }
......
...@@ -177,7 +177,7 @@ struct tf_parser ...@@ -177,7 +177,7 @@ struct tf_parser
add_mem_op("Pack", &tf_parser::parse_pack, false); add_mem_op("Pack", &tf_parser::parse_pack, false);
add_mem_op("Pad", &tf_parser::parse_pad); add_mem_op("Pad", &tf_parser::parse_pad);
add_mem_op("Reshape", &tf_parser::parse_reshape, false); add_mem_op("Reshape", &tf_parser::parse_reshape, false);
add_mem_op("Softmax", &tf_parser::parse_softmax); add_mem_op("Softmax", &tf_parser::parse_softmax<op::softmax>);
add_mem_op("Squeeze", &tf_parser::parse_squeeze, false); add_mem_op("Squeeze", &tf_parser::parse_squeeze, false);
add_mem_op("StridedSlice", &tf_parser::parse_stridedslice); add_mem_op("StridedSlice", &tf_parser::parse_stridedslice);
add_mem_op("Transpose", &tf_parser::parse_transpose, false); add_mem_op("Transpose", &tf_parser::parse_transpose, false);
...@@ -705,6 +705,8 @@ struct tf_parser ...@@ -705,6 +705,8 @@ struct tf_parser
} }
} }
// template to facilitate the logsoftmax later
template <class Op>
instruction_ref parse_softmax(const std::string&, instruction_ref parse_softmax(const std::string&,
const attribute_map& attributes, const attribute_map& attributes,
std::vector<instruction_ref> args) std::vector<instruction_ref> args)
...@@ -715,7 +717,7 @@ struct tf_parser ...@@ -715,7 +717,7 @@ struct tf_parser
axis = static_cast<int>(attributes.at("axis").i()); axis = static_cast<int>(attributes.at("axis").i());
} }
return prog.add_instruction(op::softmax{axis}, std::move(args)); return prog.add_instruction(Op{axis}, std::move(args));
} }
instruction_ref parse_squeeze(const std::string&, instruction_ref parse_squeeze(const std::string&,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment