Commit fdd4e403 authored by Khalique's avatar Khalique
Browse files

fix shape constructor and data types

parent 77ef0c1d
......@@ -58,7 +58,7 @@ struct squeeze
if(new_lens.empty())
{
return shape{type, {1}, {0}};
return shape{type};
}
else
{
......
......@@ -1432,7 +1432,7 @@ struct onnx_parser
{
// in case of scalar constants in onnx file, use dims=1 to fill initializer data
if(dims.empty())
return literal{{shape_type, {1}, {0}}, data};
return literal{{shape_type}, data};
return literal{{shape_type, dims}, data};
}
......@@ -1440,7 +1440,7 @@ struct onnx_parser
static literal create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, T data)
{
if(dims.empty())
return literal{{shape_type, {1}, {0}}, data.begin(), data.end()};
return literal{{shape_type}, data.begin(), data.end()};
return literal{{shape_type, dims}, data.begin(), data.end()};
}
......
......@@ -751,17 +751,17 @@ struct tf_parser
case tensorflow::DataType::DT_FLOAT:
return literal{{shape::float_type, dims}, s.data()};
case tensorflow::DataType::DT_UINT8: throw std::runtime_error("");
case tensorflow::DataType::DT_INT8: return literal{{shape::int32_type, dims}, s.data()};
case tensorflow::DataType::DT_INT8: return literal{{shape::int8_type, dims}, s.data()};
case tensorflow::DataType::DT_UINT16:
return literal{{shape::int32_type, dims}, s.data()};
return literal{{shape::uint16_type, dims}, s.data()};
case tensorflow::DataType::DT_INT16:
return literal{{shape::int32_type, dims}, s.data()};
return literal{{shape::int16_type, dims}, s.data()};
case tensorflow::DataType::DT_INT32:
return literal{{shape::int32_type, dims}, s.data()};
case tensorflow::DataType::DT_INT64:
return literal{{shape::int64_type, dims}, s.data()};
case tensorflow::DataType::DT_STRING: throw std::runtime_error("");
case tensorflow::DataType::DT_BOOL: return literal{{shape::int32_type, dims}, s.data()};
case tensorflow::DataType::DT_BOOL: return literal{{shape::int8_type, dims}, s.data()};
case tensorflow::DataType::DT_HALF: return literal{{shape::half_type, dims}, s.data()};
case tensorflow::DataType::DT_DOUBLE:
return literal{{shape::double_type, dims}, s.data()};
......@@ -815,11 +815,11 @@ struct tf_parser
shape::float_type, dims, get_data_vals(t.float_val(), shape_size));
case tensorflow::DataType::DT_UINT8: throw std::runtime_error("");
case tensorflow::DataType::DT_INT8:
return create_literal(shape::int32_type, dims, get_data_vals(t.int_val(), shape_size));
return create_literal(shape::int8_type, dims, get_data_vals(t.int_val(), shape_size));
case tensorflow::DataType::DT_UINT16:
return create_literal(shape::int32_type, dims, get_data_vals(t.int_val(), shape_size));
return create_literal(shape::uint16_type, dims, get_data_vals(t.int_val(), shape_size));
case tensorflow::DataType::DT_INT16:
return create_literal(shape::int32_type, dims, get_data_vals(t.int_val(), shape_size));
return create_literal(shape::int16_type, dims, get_data_vals(t.int_val(), shape_size));
case tensorflow::DataType::DT_INT32:
return create_literal(shape::int32_type, dims, get_data_vals(t.int_val(), shape_size));
case tensorflow::DataType::DT_INT64:
......@@ -916,7 +916,7 @@ struct tf_parser
{
// assume if explicit value is mentioned in protobuf and dim size <= 1, treat as scalar
if(dims.empty() or (dims.size() == 1 and dims.front() == 1))
return literal{{shape_type, {1}, {0}}, data};
return literal{{shape_type}, data};
return literal{{shape_type, dims}, data};
}
};
......
......@@ -80,7 +80,7 @@ TEST_CASE(concat_test)
int axis = 1;
// tf uses axis as the third input, and it is in int32 format
// add the literal using a vector in order to set stride to 1 (like in tf parser)
p.add_literal(migraphx::shape{migraphx::shape::int32_type, {1}, {0}}, std::vector<int>{axis});
p.add_literal(migraphx::shape{migraphx::shape::int32_type}, std::vector<int>{axis});
p.add_instruction(migraphx::op::concat{static_cast<std::size_t>(axis)}, l0, l1);
auto prog = migraphx::parse_tf("concat_test.pb", false);
......@@ -91,7 +91,7 @@ TEST_CASE(concat_test)
TEST_CASE(const_test)
{
migraphx::program p;
p.add_literal(migraphx::shape{migraphx::shape::float_type, {1}, {0}}, std::vector<float>{1.0f});
p.add_literal(migraphx::shape{migraphx::shape::float_type}, std::vector<float>{1.0f});
auto prog = migraphx::parse_tf("constant_test.pb", false);
EXPECT(p == prog);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment