Commit 78db8596 authored by Khalique's avatar Khalique
Browse files

manual merge

parents f5886e64 abf1b8e4
...@@ -66,8 +66,8 @@ struct onnx_parser ...@@ -66,8 +66,8 @@ struct onnx_parser
add_variadic_op("Max", op::max{}); add_variadic_op("Max", op::max{});
add_variadic_op("Min", op::min{}); add_variadic_op("Min", op::min{});
add_mem_op("ArgMax", &onnx_parser::parse_argmax); add_mem_op("ArgMax", &onnx_parser::parse_arg_op<op::argmax>);
add_mem_op("ArgMin", &onnx_parser::parse_argmin); add_mem_op("ArgMin", &onnx_parser::parse_arg_op<op::argmin>);
add_mem_op("Cast", &onnx_parser::parse_cast); add_mem_op("Cast", &onnx_parser::parse_cast);
add_mem_op("Clip", &onnx_parser::parse_clip); add_mem_op("Clip", &onnx_parser::parse_clip);
add_mem_op("LRN", &onnx_parser::parse_lrn); add_mem_op("LRN", &onnx_parser::parse_lrn);
...@@ -86,8 +86,8 @@ struct onnx_parser ...@@ -86,8 +86,8 @@ struct onnx_parser
add_mem_op("Gemm", &onnx_parser::parse_gemm); add_mem_op("Gemm", &onnx_parser::parse_gemm);
add_mem_op("MatMul", &onnx_parser::parse_matmul); add_mem_op("MatMul", &onnx_parser::parse_matmul);
add_mem_op("BatchNormalization", &onnx_parser::parse_batchnorm); add_mem_op("BatchNormalization", &onnx_parser::parse_batchnorm);
add_mem_op("Softmax", &onnx_parser::parse_softmax); add_mem_op("Softmax", &onnx_parser::parse_softmax<op::softmax>);
add_mem_op("LogSoftmax", &onnx_parser::parse_logsoftmax); add_mem_op("LogSoftmax", &onnx_parser::parse_softmax<op::logsoftmax>);
add_mem_op("Squeeze", &onnx_parser::parse_squeeze); add_mem_op("Squeeze", &onnx_parser::parse_squeeze);
add_mem_op("Unsqueeze", &onnx_parser::parse_unsqueeze); add_mem_op("Unsqueeze", &onnx_parser::parse_unsqueeze);
add_mem_op("Slice", &onnx_parser::parse_slice); add_mem_op("Slice", &onnx_parser::parse_slice);
...@@ -261,19 +261,10 @@ struct onnx_parser ...@@ -261,19 +261,10 @@ struct onnx_parser
return prog.add_instruction(op, std::move(args)); return prog.add_instruction(op, std::move(args));
} }
instruction_ref template <class Op>
parse_softmax(const std::string&, const attribute_map&, std::vector<instruction_ref> args) instruction_ref parse_softmax(const std::string&,
{ const attribute_map& attributes,
auto dims = args.front()->get_shape().lens(); std::vector<instruction_ref> args)
auto r =
prog.add_instruction(op::reshape{{long(dims[0]), long(dims[1]), 1, 1}}, args.front());
auto s = prog.add_instruction(op::softmax{}, r);
return prog.add_instruction(op::reshape{{long(dims[0]), long(dims[1])}}, s);
}
instruction_ref parse_logsoftmax(const std::string&,
const attribute_map& attributes,
std::vector<instruction_ref> args)
{ {
int axis = 1; int axis = 1;
if(contains(attributes, "axis")) if(contains(attributes, "axis"))
...@@ -281,10 +272,11 @@ struct onnx_parser ...@@ -281,10 +272,11 @@ struct onnx_parser
axis = parse_value(attributes.at("axis")).at<int>(); axis = parse_value(attributes.at("axis")).at<int>();
} }
return prog.add_instruction(op::logsoftmax{axis}, std::move(args)); return prog.add_instruction(Op{axis}, std::move(args));
} }
instruction_ref parse_argmax(const std::string&, template <class Op>
instruction_ref parse_arg_op(const std::string&,
const attribute_map& attributes, const attribute_map& attributes,
std::vector<instruction_ref> args) std::vector<instruction_ref> args)
{ {
...@@ -302,39 +294,12 @@ struct onnx_parser ...@@ -302,39 +294,12 @@ struct onnx_parser
if(keep_dims == 0) if(keep_dims == 0)
{ {
auto ins = prog.add_instruction(op::argmax{axis}, std::move(args)); auto ins = prog.add_instruction(Op{axis}, std::move(args));
return prog.add_instruction(op::squeeze{{axis}}, ins); return prog.add_instruction(op::squeeze{{axis}}, ins);
} }
else else
{ {
return prog.add_instruction(op::argmax{axis}, std::move(args)); return prog.add_instruction(Op{axis}, std::move(args));
}
}
instruction_ref parse_argmin(const std::string&,
const attribute_map& attributes,
std::vector<instruction_ref> args)
{
int64_t axis = 0;
if(contains(attributes, "axis"))
{
axis = static_cast<int64_t>(parse_value(attributes.at("axis")).at<int>());
}
int keep_dims = 1;
if(contains(attributes, "keepdims"))
{
keep_dims = parse_value(attributes.at("keepdims")).at<int>();
}
if(keep_dims == 0)
{
auto ins = prog.add_instruction(op::argmin{axis}, std::move(args));
return prog.add_instruction(op::squeeze{{axis}}, ins);
}
else
{
return prog.add_instruction(op::argmin{axis}, std::move(args));
} }
} }
...@@ -461,16 +426,8 @@ struct onnx_parser ...@@ -461,16 +426,8 @@ struct onnx_parser
op::reshape op; op::reshape op;
if(args.size() == 1) if(args.size() == 1)
{ {
if(contains(attributes, "shape")) literal s = parse_value(attributes.at("shape"));
{ s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
literal s = parse_value(attributes.at("shape"));
s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
}
else
{
MIGRAPHX_THROW(
"Parse_reshape: shape attribute is needed when only one argument is provided!");
}
} }
if(args.size() == 2) if(args.size() == 2)
{ {
...@@ -863,7 +820,7 @@ struct onnx_parser ...@@ -863,7 +820,7 @@ struct onnx_parser
{ {
dtype = parse_value(attributes.at("dtype")).at<int>(); dtype = parse_value(attributes.at("dtype")).at<int>();
} }
migraphx::shape::type_t type = get_type(dtype); shape::type_t type = get_type(dtype);
if(contains(attributes, "input_as_shape")) if(contains(attributes, "input_as_shape"))
{ {
......
...@@ -178,7 +178,7 @@ void miopen_gemm::batch_not_transposed(const std::vector<std::size_t>& strides) ...@@ -178,7 +178,7 @@ void miopen_gemm::batch_not_transposed(const std::vector<std::size_t>& strides)
return (i < j or i < matrix_size or j < matrix_size); return (i < j or i < matrix_size or j < matrix_size);
}) != batch.end()) }) != batch.end())
{ {
MIGRAPHX_THROW("DOT: batch size of a {" + to_string_range(strides) + "} is transposed!"); MIGRAPHX_THROW("DOT: batch size {" + to_string_range(strides) + "} is transposed!");
} }
} }
......
...@@ -423,9 +423,7 @@ TEST_CASE(softmax_test) ...@@ -423,9 +423,7 @@ TEST_CASE(softmax_test)
{ {
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3}});
auto r = p.add_instruction(migraphx::op::reshape{{1, 3, 1, 1}}, l0); p.add_instruction(migraphx::op::softmax{1}, l0);
auto s = p.add_instruction(migraphx::op::softmax{}, r);
p.add_instruction(migraphx::op::reshape{{1, 3}}, s);
auto prog = migraphx::parse_onnx("softmax_test.onnx"); auto prog = migraphx::parse_onnx("softmax_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
......
...@@ -435,11 +435,8 @@ TEST_CASE(slice_test) ...@@ -435,11 +435,8 @@ TEST_CASE(slice_test)
TEST_CASE(softmax_test) TEST_CASE(softmax_test)
{ {
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3}});
auto dims = l0->get_shape().lens(); p.add_instruction(migraphx::op::softmax{1}, l0);
auto r = p.add_instruction(migraphx::op::reshape{{long(dims[0]), long(dims[1]), 1, 1}}, l0);
auto s = p.add_instruction(migraphx::op::softmax{}, r);
p.add_instruction(migraphx::op::reshape{{long(dims[0]), long(dims[1])}}, s);
auto prog = optimize_tf("softmax_test.pb", false); auto prog = optimize_tf("softmax_test.pb", false);
EXPECT(p == prog); EXPECT(p == prog);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment