Commit 0b4f2f8d authored by Khalique's avatar Khalique
Browse files

formatting

parent d88bc5fe
......@@ -29,9 +29,9 @@ struct unsqueeze
std::string name() const { return "unsqueeze"; }
shape compute_shape(std::vector<shape> inputs) const
{
auto input_shape = inputs[0];
auto type = input_shape.type();
auto old_lens = input_shape.lens();
auto input_shape = inputs[0];
auto type = input_shape.type();
auto old_lens = input_shape.lens();
if(input_shape.scalar())
return shape{type, old_lens};
......
......@@ -1375,14 +1375,17 @@ struct onnx_parser
case onnx::TensorProto::FLOAT: return create_literal(shape::float_type, dims, s.data());
case onnx::TensorProto::UINT8: throw std::runtime_error("");
case onnx::TensorProto::INT8: return create_literal(shape::int32_type, dims, s.data());
case onnx::TensorProto::UINT16: return create_literal(shape::int32_type, dims, s.data());
case onnx::TensorProto::UINT16:
return create_literal(shape::int32_type, dims, s.data());
case onnx::TensorProto::INT16: return create_literal(shape::int32_type, dims, s.data());
case onnx::TensorProto::INT32: return create_literal(shape::int32_type, dims, s.data());
case onnx::TensorProto::INT64: return create_literal(shape::int64_type, dims, s.data());
case onnx::TensorProto::STRING: throw std::runtime_error("");
case onnx::TensorProto::BOOL: return create_literal(shape::int32_type, dims, s.data());
case onnx::TensorProto::FLOAT16: return create_literal(shape::half_type, dims, s.data());
case onnx::TensorProto::DOUBLE: return create_literal(shape::double_type, dims, s.data());
case onnx::TensorProto::FLOAT16:
return create_literal(shape::half_type, dims, s.data());
case onnx::TensorProto::DOUBLE:
return create_literal(shape::double_type, dims, s.data());
case onnx::TensorProto::UINT32: throw std::runtime_error("");
case onnx::TensorProto::UINT64: throw std::runtime_error("");
case onnx::TensorProto::COMPLEX64: throw std::runtime_error("");
......@@ -1394,21 +1397,28 @@ struct onnx_parser
{
case onnx::TensorProto::UNDEFINED: throw std::runtime_error("");
case onnx::TensorProto::FLOAT:
return create_literal(shape::float_type, dims, t.float_data().begin(), t.float_data().end());
return create_literal(
shape::float_type, dims, t.float_data().begin(), t.float_data().end());
case onnx::TensorProto::UINT8: throw std::runtime_error("");
case onnx::TensorProto::INT8:
return create_literal(shape::int32_type, dims, t.int32_data().begin(), t.int32_data().end());
return create_literal(
shape::int32_type, dims, t.int32_data().begin(), t.int32_data().end());
case onnx::TensorProto::UINT16:
return create_literal(shape::int32_type, dims, t.int32_data().begin(), t.int32_data().end());
return create_literal(
shape::int32_type, dims, t.int32_data().begin(), t.int32_data().end());
case onnx::TensorProto::INT16:
return create_literal(shape::int32_type, dims, t.int32_data().begin(), t.int32_data().end());
return create_literal(
shape::int32_type, dims, t.int32_data().begin(), t.int32_data().end());
case onnx::TensorProto::INT32:
return create_literal(shape::int32_type, dims, t.int32_data().begin(), t.int32_data().end());
return create_literal(
shape::int32_type, dims, t.int32_data().begin(), t.int32_data().end());
case onnx::TensorProto::INT64:
return create_literal(shape::int64_type, dims, t.int64_data().begin(), t.int64_data().end());
return create_literal(
shape::int64_type, dims, t.int64_data().begin(), t.int64_data().end());
case onnx::TensorProto::STRING: throw std::runtime_error("");
case onnx::TensorProto::BOOL:
return create_literal(shape::int32_type, dims, t.int32_data().begin(), t.int32_data().end());
return create_literal(
shape::int32_type, dims, t.int32_data().begin(), t.int32_data().end());
case onnx::TensorProto::FLOAT16:
{
std::vector<uint16_t> data_uint16(t.int32_data().begin(), t.int32_data().end());
......@@ -1430,7 +1440,8 @@ struct onnx_parser
MIGRAPHX_THROW("Invalid tensor type");
}
static literal create_literal(shape::type_t shape_type, std::vector<size_t> dims, const char* data)
static literal
create_literal(shape::type_t shape_type, std::vector<size_t> dims, const char* data)
{
if(dims.empty())
return literal{{shape_type, {1}, {0}}, data};
......@@ -1438,14 +1449,14 @@ struct onnx_parser
}
template <class Iterator>
/// Build a literal of the given element type from an iterator range.
/// An empty `dims` denotes a scalar: it is encoded as a 1-element shape
/// with stride 0 so the value broadcasts like a scalar.
/// Note: the pre-formatting duplicate of the signature (diff residue)
/// has been removed — it would be a redefinition error in real code.
static literal
create_literal(shape::type_t shape_type, std::vector<size_t> dims, Iterator start, Iterator end)
{
    if(dims.empty())
        return literal{{shape_type, {1}, {0}}, start, end};
    return literal{{shape_type, dims}, start, end};
}
static shape parse_type(const onnx::TypeProto& t)
{
shape::type_t shape_type{};
......
......@@ -736,7 +736,8 @@ struct tf_parser
{
case tensorflow::DataType::DT_INVALID: throw std::runtime_error("");
case tensorflow::DataType::DT_FLOAT:
return create_literal(shape::float_type, dims, get_data_vals(t.float_val(), shape_size));
return create_literal(
shape::float_type, dims, get_data_vals(t.float_val(), shape_size));
case tensorflow::DataType::DT_UINT8: throw std::runtime_error("");
case tensorflow::DataType::DT_INT8:
return create_literal(shape::int32_type, dims, get_data_vals(t.int_val(), shape_size));
......@@ -747,7 +748,8 @@ struct tf_parser
case tensorflow::DataType::DT_INT32:
return create_literal(shape::int32_type, dims, get_data_vals(t.int_val(), shape_size));
case tensorflow::DataType::DT_INT64:
return create_literal(shape::int64_type, dims, get_data_vals(t.int64_val(), shape_size));
return create_literal(
shape::int64_type, dims, get_data_vals(t.int64_val(), shape_size));
case tensorflow::DataType::DT_STRING: throw std::runtime_error("");
case tensorflow::DataType::DT_BOOL:
return create_literal(shape::int32_type, dims, get_data_vals(t.bool_val(), shape_size));
......@@ -834,13 +836,13 @@ struct tf_parser
}
template <class T>
/// Build a literal of the given element type from a data vector.
/// Empty `dims` — or a single dimension of length 1 — is treated as a
/// scalar and encoded as a 1-element shape with stride 0 so it
/// broadcasts like a scalar.
/// Note: the pre-formatting duplicate of the signature (diff residue)
/// has been removed — it would be a redefinition error in real code.
static literal
create_literal(shape::type_t shape_type, std::vector<size_t> dims, std::vector<T> data)
{
    if(dims.empty() or (dims.size() == 1 and dims.front() == 1))
        return literal{{shape_type, {1}, {0}}, data};
    return literal{{shape_type, dims}, data};
}
};
program parse_tf(const std::string& name, bool is_nhwc)
......
......@@ -699,8 +699,8 @@ TEST_CASE(add_scalar_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 =
p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type, {1}, {0}}, {1}});
auto l1 = p.add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::float_type, {1}, {0}}, {1}});
auto m0 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l0);
auto m1 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
p.add_instruction(migraphx::op::add{}, m0, m1);
......
Markdown is supported
0% — Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.