Commit cb2cd068 authored by charlie's avatar charlie
Browse files

Draft partial fix

parent 2fac6fa5
......@@ -45,6 +45,11 @@ struct literal : raw_data<literal>
{
literal() {}
/*!
* Empty literal with a specific shape type
*/
explicit literal(shape::type_t shape_type) : m_shape(shape_type, {}) {}
template <class U, class T = deduce<U>, shape::type_t ShapeType = shape::get_type<T>{}>
literal(U x) : buffer(make_shared_array<char>(sizeof(T))), m_shape(ShapeType)
{
......
......@@ -60,7 +60,7 @@ create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, const
std::accumulate(dims.begin(), dims.end(), std::size_t(1), std::multiplies<std::size_t>());
if(elem_num == 0)
{
return {};
return literal{shape_type};
}
// in case of scalar constants in onnx file, use dims=1 to fill initializer data
......@@ -77,7 +77,7 @@ static literal create_literal(shape::type_t shape_type, const std::vector<size_t
std::accumulate(dims.begin(), dims.end(), std::size_t(1), std::multiplies<std::size_t>());
if(elem_num == 0)
{
return {};
return literal{shape_type};
}
// scalar input
......
......@@ -43,7 +43,7 @@ struct parse_constant : op_parser<parse_constant>
// return empty literal
if(v.get_shape().elements() == 0)
{
return info.add_literal(literal{});
return info.add_literal(literal{v.get_shape().type()});
}
auto dim_size = info.attributes.at("value").t().dims_size();
......
......@@ -640,7 +640,7 @@ TEST_CASE(constant_scalar_test2)
{
migraphx::program p;
auto* mm = p.get_main_module();
mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int64_type, {}}, {0}});
mm->add_literal(migraphx::literal{migraphx::shape::int64_type});
auto prog = optimize_onnx("constant_scalar_test2.onnx");
EXPECT(p == prog);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment