Commit fdd4e403 authored by Khalique

fix shape constructor and data types

parent 77ef0c1d
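
This commit replaces the explicit scalar form shape{type, {1}, {0}} with the single-argument shape{type} constructor, and maps each TensorFlow dtype to its matching MIGraphX type instead of widening several of them to int32_type. A minimal sketch of what the scalar-constructor change means, assuming the public migraphx/shape.hpp header and its elements() accessor (not shown in this diff):

#include <cassert>
#include <migraphx/shape.hpp>   // assumed header path for migraphx::shape

int main()
{
    // Old spelling removed by this commit: a scalar written out as one element with stride 0.
    migraphx::shape old_style{migraphx::shape::float_type, {1}, {0}};

    // New spelling used throughout the commit: the single-argument constructor.
    migraphx::shape new_style{migraphx::shape::float_type};

    // Both describe a single-element shape; the commit relies on the short form
    // being usable wherever a scalar literal is built.
    assert(old_style.elements() == 1);
    assert(new_style.elements() == 1);
    return 0;
}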
@@ -58,7 +58,7 @@ struct squeeze
         if(new_lens.empty())
         {
-            return shape{type, {1}, {0}};
+            return shape{type};
         }
         else
         {
...
@@ -1432,7 +1432,7 @@ struct onnx_parser
     {
         // in case of scalar constants in onnx file, use dims=1 to fill initializer data
         if(dims.empty())
-            return literal{{shape_type, {1}, {0}}, data};
+            return literal{{shape_type}, data};
         return literal{{shape_type, dims}, data};
     }
@@ -1440,7 +1440,7 @@ struct onnx_parser
     static literal create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, T data)
     {
         if(dims.empty())
-            return literal{{shape_type, {1}, {0}}, data.begin(), data.end()};
+            return literal{{shape_type}, data.begin(), data.end()};
         return literal{{shape_type, dims}, data.begin(), data.end()};
     }
...
@@ -751,17 +751,17 @@ struct tf_parser
         case tensorflow::DataType::DT_FLOAT:
             return literal{{shape::float_type, dims}, s.data()};
         case tensorflow::DataType::DT_UINT8: throw std::runtime_error("");
-        case tensorflow::DataType::DT_INT8: return literal{{shape::int32_type, dims}, s.data()};
+        case tensorflow::DataType::DT_INT8: return literal{{shape::int8_type, dims}, s.data()};
         case tensorflow::DataType::DT_UINT16:
-            return literal{{shape::int32_type, dims}, s.data()};
+            return literal{{shape::uint16_type, dims}, s.data()};
         case tensorflow::DataType::DT_INT16:
-            return literal{{shape::int32_type, dims}, s.data()};
+            return literal{{shape::int16_type, dims}, s.data()};
         case tensorflow::DataType::DT_INT32:
             return literal{{shape::int32_type, dims}, s.data()};
         case tensorflow::DataType::DT_INT64:
             return literal{{shape::int64_type, dims}, s.data()};
         case tensorflow::DataType::DT_STRING: throw std::runtime_error("");
-        case tensorflow::DataType::DT_BOOL: return literal{{shape::int32_type, dims}, s.data()};
+        case tensorflow::DataType::DT_BOOL: return literal{{shape::int8_type, dims}, s.data()};
         case tensorflow::DataType::DT_HALF: return literal{{shape::half_type, dims}, s.data()};
         case tensorflow::DataType::DT_DOUBLE:
             return literal{{shape::double_type, dims}, s.data()};
@@ -815,11 +815,11 @@ struct tf_parser
                 shape::float_type, dims, get_data_vals(t.float_val(), shape_size));
         case tensorflow::DataType::DT_UINT8: throw std::runtime_error("");
         case tensorflow::DataType::DT_INT8:
-            return create_literal(shape::int32_type, dims, get_data_vals(t.int_val(), shape_size));
+            return create_literal(shape::int8_type, dims, get_data_vals(t.int_val(), shape_size));
         case tensorflow::DataType::DT_UINT16:
-            return create_literal(shape::int32_type, dims, get_data_vals(t.int_val(), shape_size));
+            return create_literal(shape::uint16_type, dims, get_data_vals(t.int_val(), shape_size));
         case tensorflow::DataType::DT_INT16:
-            return create_literal(shape::int32_type, dims, get_data_vals(t.int_val(), shape_size));
+            return create_literal(shape::int16_type, dims, get_data_vals(t.int_val(), shape_size));
         case tensorflow::DataType::DT_INT32:
             return create_literal(shape::int32_type, dims, get_data_vals(t.int_val(), shape_size));
         case tensorflow::DataType::DT_INT64:
@@ -916,7 +916,7 @@ struct tf_parser
     {
         // assume if explicit value is mentioned in protobuf and dim size <= 1, treat as scalar
         if(dims.empty() or (dims.size() == 1 and dims.front() == 1))
-            return literal{{shape_type, {1}, {0}}, data};
+            return literal{{shape_type}, data};
         return literal{{shape_type, dims}, data};
     }
 };
...
@@ -80,7 +80,7 @@ TEST_CASE(concat_test)
     int axis = 1;
     // tf uses axis as the third input, and it is in int32 format
     // add the literal using a vector in order to set stride to 1 (like in tf parser)
-    p.add_literal(migraphx::shape{migraphx::shape::int32_type, {1}, {0}}, std::vector<int>{axis});
+    p.add_literal(migraphx::shape{migraphx::shape::int32_type}, std::vector<int>{axis});
     p.add_instruction(migraphx::op::concat{static_cast<std::size_t>(axis)}, l0, l1);
     auto prog = migraphx::parse_tf("concat_test.pb", false);
@@ -91,7 +91,7 @@ TEST_CASE(concat_test)
 TEST_CASE(const_test)
 {
     migraphx::program p;
-    p.add_literal(migraphx::shape{migraphx::shape::float_type, {1}, {0}}, std::vector<float>{1.0f});
+    p.add_literal(migraphx::shape{migraphx::shape::float_type}, std::vector<float>{1.0f});
     auto prog = migraphx::parse_tf("constant_test.pb", false);
     EXPECT(p == prog);
...