Commit 78db8596 authored by Khalique

manual merge

parents f5886e64 abf1b8e4
@@ -66,8 +66,8 @@ struct onnx_parser
         add_variadic_op("Max", op::max{});
         add_variadic_op("Min", op::min{});
-        add_mem_op("ArgMax", &onnx_parser::parse_argmax);
-        add_mem_op("ArgMin", &onnx_parser::parse_argmin);
+        add_mem_op("ArgMax", &onnx_parser::parse_arg_op<op::argmax>);
+        add_mem_op("ArgMin", &onnx_parser::parse_arg_op<op::argmin>);
         add_mem_op("Cast", &onnx_parser::parse_cast);
         add_mem_op("Clip", &onnx_parser::parse_clip);
         add_mem_op("LRN", &onnx_parser::parse_lrn);
@@ -86,8 +86,8 @@ struct onnx_parser
         add_mem_op("Gemm", &onnx_parser::parse_gemm);
         add_mem_op("MatMul", &onnx_parser::parse_matmul);
         add_mem_op("BatchNormalization", &onnx_parser::parse_batchnorm);
-        add_mem_op("Softmax", &onnx_parser::parse_softmax);
-        add_mem_op("LogSoftmax", &onnx_parser::parse_logsoftmax);
+        add_mem_op("Softmax", &onnx_parser::parse_softmax<op::softmax>);
+        add_mem_op("LogSoftmax", &onnx_parser::parse_softmax<op::logsoftmax>);
         add_mem_op("Squeeze", &onnx_parser::parse_squeeze);
         add_mem_op("Unsqueeze", &onnx_parser::parse_unsqueeze);
         add_mem_op("Slice", &onnx_parser::parse_slice);
@@ -261,17 +261,8 @@ struct onnx_parser
         return prog.add_instruction(op, std::move(args));
     }
-    instruction_ref
-    parse_softmax(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
-    {
-        auto dims = args.front()->get_shape().lens();
-        auto r =
-            prog.add_instruction(op::reshape{{long(dims[0]), long(dims[1]), 1, 1}}, args.front());
-        auto s = prog.add_instruction(op::softmax{}, r);
-        return prog.add_instruction(op::reshape{{long(dims[0]), long(dims[1])}}, s);
-    }
-    instruction_ref parse_logsoftmax(const std::string&,
+    template <class Op>
+    instruction_ref parse_softmax(const std::string&,
                                   const attribute_map& attributes,
                                   std::vector<instruction_ref> args)
     {
@@ -281,37 +272,11 @@ struct onnx_parser
             axis = parse_value(attributes.at("axis")).at<int>();
         }
-        return prog.add_instruction(op::logsoftmax{axis}, std::move(args));
-    }
-    instruction_ref parse_argmax(const std::string&,
-                                 const attribute_map& attributes,
-                                 std::vector<instruction_ref> args)
-    {
-        int64_t axis = 0;
-        if(contains(attributes, "axis"))
-        {
-            axis = static_cast<int64_t>(parse_value(attributes.at("axis")).at<int>());
-        }
-        int keep_dims = 1;
-        if(contains(attributes, "keepdims"))
-        {
-            keep_dims = parse_value(attributes.at("keepdims")).at<int>();
-        }
-        if(keep_dims == 0)
-        {
-            auto ins = prog.add_instruction(op::argmax{axis}, std::move(args));
-            return prog.add_instruction(op::squeeze{{axis}}, ins);
-        }
-        else
-        {
-            return prog.add_instruction(op::argmax{axis}, std::move(args));
-        }
+        return prog.add_instruction(Op{axis}, std::move(args));
     }
-    instruction_ref parse_argmin(const std::string&,
+    template <class Op>
+    instruction_ref parse_arg_op(const std::string&,
                                  const attribute_map& attributes,
                                  std::vector<instruction_ref> args)
     {
@@ -329,12 +294,12 @@ struct onnx_parser
         if(keep_dims == 0)
         {
-            auto ins = prog.add_instruction(op::argmin{axis}, std::move(args));
+            auto ins = prog.add_instruction(Op{axis}, std::move(args));
             return prog.add_instruction(op::squeeze{{axis}}, ins);
         }
         else
         {
-            return prog.add_instruction(op::argmin{axis}, std::move(args));
+            return prog.add_instruction(Op{axis}, std::move(args));
         }
     }
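
For context, the `keep_dims` branch above mirrors the ONNX `keepdims` attribute (default 1): the reduction keeps the reduced axis as size 1, and `keepdims = 0` is implemented by squeezing that axis afterwards. A shape-level sketch, with example shapes assumed and using only the calls visible in the diff:

```cpp
// Input shape {2, 3, 4}, axis = 1:
//
//   keep_dims == 1 (ONNX default) -- a single reduction, axis kept:
//       prog.add_instruction(Op{1}, std::move(args));            // -> {2, 1, 4}
//
//   keep_dims == 0 -- the size-1 axis is squeezed away afterwards:
//       auto ins = prog.add_instruction(Op{1}, std::move(args)); // -> {2, 1, 4}
//       prog.add_instruction(op::squeeze{{1}}, ins);             // -> {2, 4}
```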
@@ -460,18 +425,10 @@ struct onnx_parser
     {
         op::reshape op;
         if(args.size() == 1)
         {
-            if(contains(attributes, "shape"))
-            {
             literal s = parse_value(attributes.at("shape"));
             s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
-            }
-            else
-            {
-                MIGRAPHX_THROW(
-                    "Parse_reshape: shape attribute is needed when only one argument is provided!");
-            }
         }
         if(args.size() == 2)
         {
             auto s = args[1]->eval();
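
The two branches above correspond to the two forms of ONNX `Reshape`: in early opsets the target shape is a node attribute, while from opset 5 onwards it arrives as a second (constant) input, which is why `args[1]->eval()` is used to read it. A comment sketch of the flow, with attribute and input handling as shown in the diff:

```cpp
// args.size() == 1: Reshape(data) with a "shape" attribute.
//     literal s = parse_value(attributes.at("shape")); // attribute -> literal
//     s.visit(...);                                    // copy dims into op.dims
//
// args.size() == 2: Reshape(data, shape) with the shape as an input tensor.
//     auto s = args[1]->eval();                        // evaluate the constant
//     // ...its values then fill op.dims the same way.
```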
@@ -863,7 +820,7 @@ struct onnx_parser
         {
             dtype = parse_value(attributes.at("dtype")).at<int>();
         }
-        migraphx::shape::type_t type = get_type(dtype);
+        shape::type_t type = get_type(dtype);
         if(contains(attributes, "input_as_shape"))
         {
......
@@ -178,7 +178,7 @@ void miopen_gemm::batch_not_transposed(const std::vector<std::size_t>& strides)
                return (i < j or i < matrix_size or j < matrix_size);
            }) != batch.end())
     {
-        MIGRAPHX_THROW("DOT: batch size of a {" + to_string_range(strides) + "} is transposed!");
+        MIGRAPHX_THROW("DOT: batch size {" + to_string_range(strides) + "} is transposed!");
     }
 }
......
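
The `batch_not_transposed` predicate shown here walks the batch strides (everything before the trailing two matrix dimensions) and throws if any batch stride increases or dips below the size of one matrix, i.e. if a batch axis has been transposed into the matrix data. A worked example, assuming packed row-major layouts and `matrix_size` equal to one matrix's element count:

```cpp
// lens {2, 5, 3, 4}, packed row-major strides {60, 12, 4, 1}:
//   one matrix is 3x4 = 12 elements, batch strides = {60, 12}
//   pair (i=60, j=12): 60 < 12? no; 60 < 12 (matrix_size)? no; 12 < 12? no
//   -> the predicate never fires, so this is a valid batched gemm layout
//
// the same tensor with its two batch axes swapped has strides {12, 60, 4, 1}:
//   pair (i=12, j=60) satisfies i < j, the search finds a match, and
//   MIGRAPHX_THROW reports the batch strides as transposed.
```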
@@ -423,9 +423,7 @@ TEST_CASE(softmax_test)
 {
     migraphx::program p;
     auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3}});
-    auto r = p.add_instruction(migraphx::op::reshape{{1, 3, 1, 1}}, l0);
-    auto s = p.add_instruction(migraphx::op::softmax{}, r);
-    p.add_instruction(migraphx::op::reshape{{1, 3}}, s);
+    p.add_instruction(migraphx::op::softmax{1}, l0);
     auto prog = migraphx::parse_onnx("softmax_test.onnx");
     EXPECT(p == prog);
......
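
The test change reflects the parser change above: `op::softmax` now takes the normalization axis directly (ONNX's default axis is 1), so the old lowering that reshaped to 4-D, applied the fixed 4-D softmax, and reshaped back is no longer needed. The two equivalent expected programs for a {1, 3} input, as a sketch:

```cpp
// old: {1, 3} --reshape--> {1, 3, 1, 1} --softmax--> {1, 3, 1, 1} --reshape--> {1, 3}
// new: {1, 3} --softmax{1}--> {1, 3}   // normalize directly over axis 1
```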
@@ -436,10 +436,7 @@ TEST_CASE(softmax_test)
 {
     migraphx::program p;
     auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3}});
-    auto dims = l0->get_shape().lens();
-    auto r = p.add_instruction(migraphx::op::reshape{{long(dims[0]), long(dims[1]), 1, 1}}, l0);
-    auto s = p.add_instruction(migraphx::op::softmax{}, r);
-    p.add_instruction(migraphx::op::reshape{{long(dims[0]), long(dims[1])}}, s);
+    p.add_instruction(migraphx::op::softmax{1}, l0);
     auto prog = optimize_tf("softmax_test.pb", false);
     EXPECT(p == prog);
......