"vscode:/vscode.git/clone" did not exist on "1656115917ea5d65fe1f6a572f313b3afc3e8b6c"
Commit cb2cd068 authored by charlie's avatar charlie
Browse files

Draft partial fix

parent 2fac6fa5
...@@ -45,6 +45,11 @@ struct literal : raw_data<literal> ...@@ -45,6 +45,11 @@ struct literal : raw_data<literal>
{ {
literal() {} literal() {}
/*!
* Empty literal with a specific shape type
*/
explicit literal(shape::type_t shape_type) : m_shape(shape_type, {}) {}
template <class U, class T = deduce<U>, shape::type_t ShapeType = shape::get_type<T>{}> template <class U, class T = deduce<U>, shape::type_t ShapeType = shape::get_type<T>{}>
literal(U x) : buffer(make_shared_array<char>(sizeof(T))), m_shape(ShapeType) literal(U x) : buffer(make_shared_array<char>(sizeof(T))), m_shape(ShapeType)
{ {
......
...@@ -60,7 +60,7 @@ create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, const ...@@ -60,7 +60,7 @@ create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, const
std::accumulate(dims.begin(), dims.end(), std::size_t(1), std::multiplies<std::size_t>()); std::accumulate(dims.begin(), dims.end(), std::size_t(1), std::multiplies<std::size_t>());
if(elem_num == 0) if(elem_num == 0)
{ {
return {}; return literal{shape_type};
} }
// in case of scalar constants in onnx file, use dims=1 to fill initializer data // in case of scalar constants in onnx file, use dims=1 to fill initializer data
...@@ -77,7 +77,7 @@ static literal create_literal(shape::type_t shape_type, const std::vector<size_t ...@@ -77,7 +77,7 @@ static literal create_literal(shape::type_t shape_type, const std::vector<size_t
std::accumulate(dims.begin(), dims.end(), std::size_t(1), std::multiplies<std::size_t>()); std::accumulate(dims.begin(), dims.end(), std::size_t(1), std::multiplies<std::size_t>());
if(elem_num == 0) if(elem_num == 0)
{ {
return {}; return literal{shape_type};
} }
// scalar input // scalar input
......
...@@ -43,7 +43,7 @@ struct parse_constant : op_parser<parse_constant> ...@@ -43,7 +43,7 @@ struct parse_constant : op_parser<parse_constant>
// return empty literal // return empty literal
if(v.get_shape().elements() == 0) if(v.get_shape().elements() == 0)
{ {
return info.add_literal(literal{}); return info.add_literal(literal{v.get_shape().type()});
} }
auto dim_size = info.attributes.at("value").t().dims_size(); auto dim_size = info.attributes.at("value").t().dims_size();
......
...@@ -640,7 +640,7 @@ TEST_CASE(constant_scalar_test2) ...@@ -640,7 +640,7 @@ TEST_CASE(constant_scalar_test2)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int64_type, {}}, {0}}); mm->add_literal(migraphx::literal{migraphx::shape::int64_type});
auto prog = optimize_onnx("constant_scalar_test2.onnx"); auto prog = optimize_onnx("constant_scalar_test2.onnx");
EXPECT(p == prog); EXPECT(p == prog);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment