Commit 0b4f2f8d authored by Khalique's avatar Khalique
Browse files

formatting

parent d88bc5fe
...@@ -29,9 +29,9 @@ struct unsqueeze ...@@ -29,9 +29,9 @@ struct unsqueeze
std::string name() const { return "unsqueeze"; } std::string name() const { return "unsqueeze"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
auto input_shape = inputs[0]; auto input_shape = inputs[0];
auto type = input_shape.type(); auto type = input_shape.type();
auto old_lens = input_shape.lens(); auto old_lens = input_shape.lens();
if(input_shape.scalar()) if(input_shape.scalar())
return shape{type, old_lens}; return shape{type, old_lens};
......
...@@ -1375,14 +1375,17 @@ struct onnx_parser ...@@ -1375,14 +1375,17 @@ struct onnx_parser
case onnx::TensorProto::FLOAT: return create_literal(shape::float_type, dims, s.data()); case onnx::TensorProto::FLOAT: return create_literal(shape::float_type, dims, s.data());
case onnx::TensorProto::UINT8: throw std::runtime_error(""); case onnx::TensorProto::UINT8: throw std::runtime_error("");
case onnx::TensorProto::INT8: return create_literal(shape::int32_type, dims, s.data()); case onnx::TensorProto::INT8: return create_literal(shape::int32_type, dims, s.data());
case onnx::TensorProto::UINT16: return create_literal(shape::int32_type, dims, s.data()); case onnx::TensorProto::UINT16:
return create_literal(shape::int32_type, dims, s.data());
case onnx::TensorProto::INT16: return create_literal(shape::int32_type, dims, s.data()); case onnx::TensorProto::INT16: return create_literal(shape::int32_type, dims, s.data());
case onnx::TensorProto::INT32: return create_literal(shape::int32_type, dims, s.data()); case onnx::TensorProto::INT32: return create_literal(shape::int32_type, dims, s.data());
case onnx::TensorProto::INT64: return create_literal(shape::int64_type, dims, s.data()); case onnx::TensorProto::INT64: return create_literal(shape::int64_type, dims, s.data());
case onnx::TensorProto::STRING: throw std::runtime_error(""); case onnx::TensorProto::STRING: throw std::runtime_error("");
case onnx::TensorProto::BOOL: return create_literal(shape::int32_type, dims, s.data()); case onnx::TensorProto::BOOL: return create_literal(shape::int32_type, dims, s.data());
case onnx::TensorProto::FLOAT16: return create_literal(shape::half_type, dims, s.data()); case onnx::TensorProto::FLOAT16:
case onnx::TensorProto::DOUBLE: return create_literal(shape::double_type, dims, s.data()); return create_literal(shape::half_type, dims, s.data());
case onnx::TensorProto::DOUBLE:
return create_literal(shape::double_type, dims, s.data());
case onnx::TensorProto::UINT32: throw std::runtime_error(""); case onnx::TensorProto::UINT32: throw std::runtime_error("");
case onnx::TensorProto::UINT64: throw std::runtime_error(""); case onnx::TensorProto::UINT64: throw std::runtime_error("");
case onnx::TensorProto::COMPLEX64: throw std::runtime_error(""); case onnx::TensorProto::COMPLEX64: throw std::runtime_error("");
...@@ -1394,21 +1397,28 @@ struct onnx_parser ...@@ -1394,21 +1397,28 @@ struct onnx_parser
{ {
case onnx::TensorProto::UNDEFINED: throw std::runtime_error(""); case onnx::TensorProto::UNDEFINED: throw std::runtime_error("");
case onnx::TensorProto::FLOAT: case onnx::TensorProto::FLOAT:
return create_literal(shape::float_type, dims, t.float_data().begin(), t.float_data().end()); return create_literal(
shape::float_type, dims, t.float_data().begin(), t.float_data().end());
case onnx::TensorProto::UINT8: throw std::runtime_error(""); case onnx::TensorProto::UINT8: throw std::runtime_error("");
case onnx::TensorProto::INT8: case onnx::TensorProto::INT8:
return create_literal(shape::int32_type, dims, t.int32_data().begin(), t.int32_data().end()); return create_literal(
shape::int32_type, dims, t.int32_data().begin(), t.int32_data().end());
case onnx::TensorProto::UINT16: case onnx::TensorProto::UINT16:
return create_literal(shape::int32_type, dims, t.int32_data().begin(), t.int32_data().end()); return create_literal(
shape::int32_type, dims, t.int32_data().begin(), t.int32_data().end());
case onnx::TensorProto::INT16: case onnx::TensorProto::INT16:
return create_literal(shape::int32_type, dims, t.int32_data().begin(), t.int32_data().end()); return create_literal(
shape::int32_type, dims, t.int32_data().begin(), t.int32_data().end());
case onnx::TensorProto::INT32: case onnx::TensorProto::INT32:
return create_literal(shape::int32_type, dims, t.int32_data().begin(), t.int32_data().end()); return create_literal(
shape::int32_type, dims, t.int32_data().begin(), t.int32_data().end());
case onnx::TensorProto::INT64: case onnx::TensorProto::INT64:
return create_literal(shape::int64_type, dims, t.int64_data().begin(), t.int64_data().end()); return create_literal(
shape::int64_type, dims, t.int64_data().begin(), t.int64_data().end());
case onnx::TensorProto::STRING: throw std::runtime_error(""); case onnx::TensorProto::STRING: throw std::runtime_error("");
case onnx::TensorProto::BOOL: case onnx::TensorProto::BOOL:
return create_literal(shape::int32_type, dims, t.int32_data().begin(), t.int32_data().end()); return create_literal(
shape::int32_type, dims, t.int32_data().begin(), t.int32_data().end());
case onnx::TensorProto::FLOAT16: case onnx::TensorProto::FLOAT16:
{ {
std::vector<uint16_t> data_uint16(t.int32_data().begin(), t.int32_data().end()); std::vector<uint16_t> data_uint16(t.int32_data().begin(), t.int32_data().end());
...@@ -1430,7 +1440,8 @@ struct onnx_parser ...@@ -1430,7 +1440,8 @@ struct onnx_parser
MIGRAPHX_THROW("Invalid tensor type"); MIGRAPHX_THROW("Invalid tensor type");
} }
static literal create_literal(shape::type_t shape_type, std::vector<size_t> dims, const char* data) static literal
create_literal(shape::type_t shape_type, std::vector<size_t> dims, const char* data)
{ {
if(dims.empty()) if(dims.empty())
return literal{{shape_type, {1}, {0}}, data}; return literal{{shape_type, {1}, {0}}, data};
...@@ -1438,14 +1449,14 @@ struct onnx_parser ...@@ -1438,14 +1449,14 @@ struct onnx_parser
} }
template <class Iterator> template <class Iterator>
static literal create_literal(shape::type_t shape_type, std::vector<size_t> dims, Iterator start, Iterator end) static literal
create_literal(shape::type_t shape_type, std::vector<size_t> dims, Iterator start, Iterator end)
{ {
if(dims.empty()) if(dims.empty())
return literal{{shape_type, {1}, {0}}, start, end}; return literal{{shape_type, {1}, {0}}, start, end};
return literal{{shape_type, dims}, start, end}; return literal{{shape_type, dims}, start, end};
} }
static shape parse_type(const onnx::TypeProto& t) static shape parse_type(const onnx::TypeProto& t)
{ {
shape::type_t shape_type{}; shape::type_t shape_type{};
......
...@@ -736,7 +736,8 @@ struct tf_parser ...@@ -736,7 +736,8 @@ struct tf_parser
{ {
case tensorflow::DataType::DT_INVALID: throw std::runtime_error(""); case tensorflow::DataType::DT_INVALID: throw std::runtime_error("");
case tensorflow::DataType::DT_FLOAT: case tensorflow::DataType::DT_FLOAT:
return create_literal(shape::float_type, dims, get_data_vals(t.float_val(), shape_size)); return create_literal(
shape::float_type, dims, get_data_vals(t.float_val(), shape_size));
case tensorflow::DataType::DT_UINT8: throw std::runtime_error(""); case tensorflow::DataType::DT_UINT8: throw std::runtime_error("");
case tensorflow::DataType::DT_INT8: case tensorflow::DataType::DT_INT8:
return create_literal(shape::int32_type, dims, get_data_vals(t.int_val(), shape_size)); return create_literal(shape::int32_type, dims, get_data_vals(t.int_val(), shape_size));
...@@ -747,7 +748,8 @@ struct tf_parser ...@@ -747,7 +748,8 @@ struct tf_parser
case tensorflow::DataType::DT_INT32: case tensorflow::DataType::DT_INT32:
return create_literal(shape::int32_type, dims, get_data_vals(t.int_val(), shape_size)); return create_literal(shape::int32_type, dims, get_data_vals(t.int_val(), shape_size));
case tensorflow::DataType::DT_INT64: case tensorflow::DataType::DT_INT64:
return create_literal(shape::int64_type, dims, get_data_vals(t.int64_val(), shape_size)); return create_literal(
shape::int64_type, dims, get_data_vals(t.int64_val(), shape_size));
case tensorflow::DataType::DT_STRING: throw std::runtime_error(""); case tensorflow::DataType::DT_STRING: throw std::runtime_error("");
case tensorflow::DataType::DT_BOOL: case tensorflow::DataType::DT_BOOL:
return create_literal(shape::int32_type, dims, get_data_vals(t.bool_val(), shape_size)); return create_literal(shape::int32_type, dims, get_data_vals(t.bool_val(), shape_size));
...@@ -834,13 +836,13 @@ struct tf_parser ...@@ -834,13 +836,13 @@ struct tf_parser
} }
template <class T> template <class T>
static literal create_literal(shape::type_t shape_type, std::vector<size_t> dims, std::vector<T> data) static literal
create_literal(shape::type_t shape_type, std::vector<size_t> dims, std::vector<T> data)
{ {
if(dims.empty() or (dims.size() == 1 and dims.front() == 1)) if(dims.empty() or (dims.size() == 1 and dims.front() == 1))
return literal{{shape_type, {1}, {0}}, data}; return literal{{shape_type, {1}, {0}}, data};
return literal{{shape_type, dims}, data}; return literal{{shape_type, dims}, data};
} }
}; };
program parse_tf(const std::string& name, bool is_nhwc) program parse_tf(const std::string& name, bool is_nhwc)
......
...@@ -699,8 +699,8 @@ TEST_CASE(add_scalar_test) ...@@ -699,8 +699,8 @@ TEST_CASE(add_scalar_test)
{ {
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = auto l1 = p.add_literal(
p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type, {1}, {0}}, {1}}); migraphx::literal{migraphx::shape{migraphx::shape::float_type, {1}, {0}}, {1}});
auto m0 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l0); auto m0 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l0);
auto m1 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1); auto m1 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
p.add_instruction(migraphx::op::add{}, m0, m1); p.add_instruction(migraphx::op::add{}, m0, m1);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment