Commit d88bc5fe authored by Khalique

Adjusted parsing for scalars and squeeze/unsqueeze semantics for scalars

parent af00eea8
@@ -58,7 +58,7 @@ struct squeeze
         if(new_lens.empty())
         {
-            return shape{type};
+            return shape{type, {1}, {0}};
         }
         else
         {
...
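For context, `{1}, {0}` is a lens/strides pair: squeezing away every axis now yields a one-element shape whose stride is 0, which is the encoding `shape::scalar()` tests for (the unsqueeze hunk below relies on exactly this). A minimal sketch of the distinction, using only the `migraphx::shape` API that appears in this diff; the `main` and output calls are illustrative additions:

    #include <migraphx/shape.hpp>
    #include <iostream>

    int main()
    {
        // Scalar encoding used by this commit: one element, stride 0
        migraphx::shape scalar{migraphx::shape::float_type, {1}, {0}};
        // A regular one-element tensor: one element, standard stride {1}
        migraphx::shape one_elem{migraphx::shape::float_type, {1}};

        std::cout << scalar.scalar() << "\n";   // 1: zero stride marks a scalar
        std::cout << one_elem.scalar() << "\n"; // 0: standard one-element shape
    }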
@@ -32,6 +32,10 @@ struct unsqueeze
         auto input_shape = inputs[0];
         auto type        = input_shape.type();
         auto old_lens    = input_shape.lens();
+        if(input_shape.scalar())
+            return shape{type, old_lens};
         std::size_t new_size = old_lens.size() + axes.size();
         std::vector<std::size_t> new_lens(new_size);
         std::size_t p = 0;
...
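Taken together with the squeeze change above, this makes the two ops inverses on one-element tensors: squeezing away all axes yields the `{1}/{0}` scalar encoding, and unsqueezing a scalar returns `shape{type, old_lens}`, i.e. a standard `{1}` shape, dropping the zero stride again. A round-trip sketch, assuming the op structs are aggregate-initializable and expose `compute_shape` as MIGraphX operators generally do:

    #include <migraphx/operators.hpp>
    #include <iostream>

    int main()
    {
        migraphx::shape one_elem{migraphx::shape::float_type, {1}};

        // squeeze with new_lens empty -> scalar encoding {1}/{0}
        auto squeezed = migraphx::op::squeeze{{0}}.compute_shape({one_elem});
        std::cout << squeezed.scalar() << "\n"; // 1

        // unsqueeze sees input_shape.scalar() and returns a standard {1} shape
        auto unsqueezed = migraphx::op::unsqueeze{{0}}.compute_shape({squeezed});
        std::cout << unsqueezed.scalar() << "\n"; // 0
    }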
...
@@ -1362,27 +1362,27 @@ struct onnx_parser
     {
         std::vector<std::size_t> dims(t.dims().begin(), t.dims().end());
         // in case of scalar constants in onnx file, use dims=1 to fill initializer data
-        if(dims.empty())
-        {
-            dims = {1};
-        }
+        // if(dims.empty())
+        // {
+        //     dims = {1};
+        // }
         if(t.has_raw_data())
         {
             const std::string& s = t.raw_data();
             switch(t.data_type())
             {
             case onnx::TensorProto::UNDEFINED: throw std::runtime_error("");
-            case onnx::TensorProto::FLOAT: return literal{{shape::float_type, dims}, s.data()};
+            case onnx::TensorProto::FLOAT: return create_literal(shape::float_type, dims, s.data());
             case onnx::TensorProto::UINT8: throw std::runtime_error("");
-            case onnx::TensorProto::INT8: return literal{{shape::int32_type, dims}, s.data()};
-            case onnx::TensorProto::UINT16: return literal{{shape::int32_type, dims}, s.data()};
-            case onnx::TensorProto::INT16: return literal{{shape::int32_type, dims}, s.data()};
-            case onnx::TensorProto::INT32: return literal{{shape::int32_type, dims}, s.data()};
-            case onnx::TensorProto::INT64: return literal{{shape::int64_type, dims}, s.data()};
+            case onnx::TensorProto::INT8: return create_literal(shape::int32_type, dims, s.data());
+            case onnx::TensorProto::UINT16: return create_literal(shape::int32_type, dims, s.data());
+            case onnx::TensorProto::INT16: return create_literal(shape::int32_type, dims, s.data());
+            case onnx::TensorProto::INT32: return create_literal(shape::int32_type, dims, s.data());
+            case onnx::TensorProto::INT64: return create_literal(shape::int64_type, dims, s.data());
             case onnx::TensorProto::STRING: throw std::runtime_error("");
-            case onnx::TensorProto::BOOL: return literal{{shape::int32_type, dims}, s.data()};
-            case onnx::TensorProto::FLOAT16: return literal{{shape::half_type, dims}, s.data()};
-            case onnx::TensorProto::DOUBLE: return literal{{shape::double_type, dims}, s.data()};
+            case onnx::TensorProto::BOOL: return create_literal(shape::int32_type, dims, s.data());
+            case onnx::TensorProto::FLOAT16: return create_literal(shape::half_type, dims, s.data());
+            case onnx::TensorProto::DOUBLE: return create_literal(shape::double_type, dims, s.data());
             case onnx::TensorProto::UINT32: throw std::runtime_error("");
             case onnx::TensorProto::UINT64: throw std::runtime_error("");
             case onnx::TensorProto::COMPLEX64: throw std::runtime_error("");
...
@@ -1394,21 +1394,21 @@ struct onnx_parser
             {
             case onnx::TensorProto::UNDEFINED: throw std::runtime_error("");
             case onnx::TensorProto::FLOAT:
-                return literal{{shape::float_type, dims}, t.float_data().begin(), t.float_data().end()};
+                return create_literal(shape::float_type, dims, t.float_data().begin(), t.float_data().end());
             case onnx::TensorProto::UINT8: throw std::runtime_error("");
             case onnx::TensorProto::INT8:
-                return literal{{shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
+                return create_literal(shape::int32_type, dims, t.int32_data().begin(), t.int32_data().end());
             case onnx::TensorProto::UINT16:
-                return literal{{shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
+                return create_literal(shape::int32_type, dims, t.int32_data().begin(), t.int32_data().end());
             case onnx::TensorProto::INT16:
-                return literal{{shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
+                return create_literal(shape::int32_type, dims, t.int32_data().begin(), t.int32_data().end());
             case onnx::TensorProto::INT32:
-                return literal{{shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
+                return create_literal(shape::int32_type, dims, t.int32_data().begin(), t.int32_data().end());
             case onnx::TensorProto::INT64:
-                return literal{{shape::int64_type, dims}, t.int64_data().begin(), t.int64_data().end()};
+                return create_literal(shape::int64_type, dims, t.int64_data().begin(), t.int64_data().end());
             case onnx::TensorProto::STRING: throw std::runtime_error("");
             case onnx::TensorProto::BOOL:
-                return literal{{shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
+                return create_literal(shape::int32_type, dims, t.int32_data().begin(), t.int32_data().end());
             case onnx::TensorProto::FLOAT16:
             {
                 std::vector<uint16_t> data_uint16(t.int32_data().begin(), t.int32_data().end());
...
@@ -1417,11 +1417,11 @@ struct onnx_parser
                     data_uint16.end(),
                     std::back_inserter(data_half),
                     [](uint16_t raw_val) { return *reinterpret_cast<half*>(&raw_val); });
-                return literal{{shape::half_type, dims}, data_half.begin(), data_half.end()};
+                return create_literal(shape::half_type, dims, data_half.begin(), data_half.end());
             }
             case onnx::TensorProto::DOUBLE:
-                return literal{
-                    {shape::double_type, dims}, t.double_data().begin(), t.double_data().end()};
+                return create_literal(
+                    shape::double_type, dims, t.double_data().begin(), t.double_data().end());
             case onnx::TensorProto::UINT32: throw std::runtime_error("");
             case onnx::TensorProto::UINT64: throw std::runtime_error("");
             case onnx::TensorProto::COMPLEX64: throw std::runtime_error("");
...
@@ -1430,6 +1430,22 @@ struct onnx_parser
         MIGRAPHX_THROW("Invalid tensor type");
     }

+    static literal create_literal(shape::type_t shape_type, std::vector<size_t> dims, const char* data)
+    {
+        if(dims.empty())
+            return literal{{shape_type, {1}, {0}}, data};
+        return literal{{shape_type, dims}, data};
+    }
+
+    template <class Iterator>
+    static literal
+    create_literal(shape::type_t shape_type, std::vector<size_t> dims, Iterator start, Iterator end)
+    {
+        if(dims.empty())
+            return literal{{shape_type, {1}, {0}}, start, end};
+        return literal{{shape_type, dims}, start, end};
+    }
+
     static shape parse_type(const onnx::TypeProto& t)
     {
         shape::type_t shape_type{};
...
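The net effect of these helpers: instead of rewriting empty `dims` to `{1}` up front, an ONNX initializer with no dimensions now produces a literal carrying the scalar shape `{1}/{0}`. A freestanding sketch of the raw-data overload (copied out of `onnx_parser` for illustration; the `main` usage is an added assumption):

    #include <migraphx/literal.hpp>
    #include <cassert>
    #include <vector>

    static migraphx::literal create_literal(migraphx::shape::type_t shape_type,
                                            std::vector<std::size_t> dims,
                                            const char* data)
    {
        // Empty dims means a scalar initializer in the ONNX file
        if(dims.empty())
            return migraphx::literal{{shape_type, {1}, {0}}, data};
        return migraphx::literal{{shape_type, dims}, data};
    }

    int main()
    {
        float v   = 3.0f;
        auto lit  = create_literal(
            migraphx::shape::float_type, {}, reinterpret_cast<const char*>(&v));
        assert(lit.get_shape().scalar()); // empty dims -> scalar literal
    }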
@@ -662,10 +662,10 @@ struct tf_parser
     static literal parse_tensor(const tensorflow::TensorProto& t)
     {
         std::vector<size_t> dims = parse_dims(t.tensor_shape());
-        if(dims.empty())
-        {
-            dims = {1};
-        }
+        // if(dims.empty())
+        // {
+        //     dims = {1};
+        // }
         size_t shape_size = std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<size_t>());
         if(!t.tensor_content().empty()) // has raw data
         {
...
@@ -736,21 +736,21 @@ struct tf_parser
         {
         case tensorflow::DataType::DT_INVALID: throw std::runtime_error("");
         case tensorflow::DataType::DT_FLOAT:
-            return literal{{shape::float_type, dims}, get_data_vals(t.float_val(), shape_size)};
+            return create_literal(shape::float_type, dims, get_data_vals(t.float_val(), shape_size));
         case tensorflow::DataType::DT_UINT8: throw std::runtime_error("");
         case tensorflow::DataType::DT_INT8:
-            return literal{{shape::int32_type, dims}, get_data_vals(t.int_val(), shape_size)};
+            return create_literal(shape::int32_type, dims, get_data_vals(t.int_val(), shape_size));
         case tensorflow::DataType::DT_UINT16:
-            return literal{{shape::int32_type, dims}, get_data_vals(t.int_val(), shape_size)};
+            return create_literal(shape::int32_type, dims, get_data_vals(t.int_val(), shape_size));
         case tensorflow::DataType::DT_INT16:
-            return literal{{shape::int32_type, dims}, get_data_vals(t.int_val(), shape_size)};
+            return create_literal(shape::int32_type, dims, get_data_vals(t.int_val(), shape_size));
         case tensorflow::DataType::DT_INT32:
-            return literal{{shape::int32_type, dims}, get_data_vals(t.int_val(), shape_size)};
+            return create_literal(shape::int32_type, dims, get_data_vals(t.int_val(), shape_size));
         case tensorflow::DataType::DT_INT64:
-            return literal{{shape::int64_type, dims}, get_data_vals(t.int64_val(), shape_size)};
+            return create_literal(shape::int64_type, dims, get_data_vals(t.int64_val(), shape_size));
         case tensorflow::DataType::DT_STRING: throw std::runtime_error("");
         case tensorflow::DataType::DT_BOOL:
-            return literal{{shape::int32_type, dims}, get_data_vals(t.bool_val(), shape_size)};
+            return create_literal(shape::int32_type, dims, get_data_vals(t.bool_val(), shape_size));
         case tensorflow::DataType::DT_HALF:
         {
             std::vector<int> data_int32 = get_data_vals(t.half_val(), shape_size);
...
@@ -760,7 +760,7 @@ struct tf_parser
                 data_uint16.end(),
                 std::back_inserter(data_half),
                 [](uint16_t raw_val) { return *reinterpret_cast<half*>(&raw_val); });
-            return literal{{shape::half_type, dims}, data_half};
+            return create_literal(shape::half_type, dims, data_half);
         }
         case tensorflow::DataType::DT_DOUBLE:
             return literal{{shape::double_type, dims}, get_data_vals(t.double_val(), shape_size)};
...
@@ -832,6 +832,15 @@ struct tf_parser
             [](tensorflow::TensorShapeProto_Dim dim) { return dim.size(); });
         return dims;
     }
+
+    template <class T>
+    static literal
+    create_literal(shape::type_t shape_type, std::vector<size_t> dims, std::vector<T> data)
+    {
+        if(dims.empty() or (dims.size() == 1 and dims.front() == 1))
+            return literal{{shape_type, {1}, {0}}, data};
+        return literal{{shape_type, dims}, data};
+    }
 };

 program parse_tf(const std::string& name, bool is_nhwc)
...
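Note the tf_parser variant's broader predicate: it collapses not only empty `dims` but also an explicit one-element shape `{1}` to the scalar encoding, presumably because TF graphs can encode one-element constants either way; the updated `concat_test` and `const_test` expectations below now expect scalar literals for such constants. A sketch of just that predicate, as a hypothetical standalone helper (the name `denotes_scalar` is an illustrative addition, not in the commit):

    #include <cstddef>
    #include <vector>

    // Does this dims list denote a scalar under the new tf_parser semantics?
    static bool denotes_scalar(const std::vector<std::size_t>& dims)
    {
        return dims.empty() or (dims.size() == 1 and dims.front() == 1);
    }

    // denotes_scalar({})  -> true    denotes_scalar({1}) -> true
    // denotes_scalar({2}) -> false   denotes_scalar({1, 1}) -> false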
@@ -700,7 +700,7 @@ TEST_CASE(add_scalar_test)
     migraphx::program p;
     auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
     auto l1 =
-        p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type, {1}}, {1}});
+        p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type, {1}, {0}}, {1}});
     auto m0 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l0);
     auto m1 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
     p.add_instruction(migraphx::op::add{}, m0, m1);
...
@@ -80,7 +80,7 @@ TEST_CASE(concat_test)
     int axis = 1;
     // tf uses axis as the third input, and it is in int32 format
     // add the literal using a vector in order to set stride to 1 (like in tf parser)
-    p.add_literal(migraphx::shape{migraphx::shape::int32_type, {1}}, std::vector<int>{axis});
+    p.add_literal(migraphx::shape{migraphx::shape::int32_type, {1}, {0}}, std::vector<int>{axis});
     p.add_instruction(migraphx::op::concat{static_cast<std::size_t>(axis)}, l0, l1);

     auto prog = migraphx::parse_tf("concat_test.pb", false);
...
@@ -91,7 +91,7 @@ TEST_CASE(concat_test)
 TEST_CASE(const_test)
 {
     migraphx::program p;
-    p.add_literal(migraphx::shape{migraphx::shape::float_type, {1}}, std::vector<float>{1.0f});
+    p.add_literal(migraphx::shape{migraphx::shape::float_type, {1}, {0}}, std::vector<float>{1.0f});
     auto prog = migraphx::parse_tf("constant_test.pb", false);
     EXPECT(p == prog);
...