Commit 8819f4dc authored by Khalique

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into bert_ops

parents 708c0401 7f28f161
@@ -81,6 +81,7 @@ rocm_enable_clang_tidy(
 -modernize-use-override
 -modernize-pass-by-value
 -modernize-use-default-member-init
+-modernize-use-trailing-return-type
 -modernize-use-transparent-functors
 -performance-type-promotion-in-math-fn
 -readability-braces-around-statements
......
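For reference, the check disabled by the new entry (modernize-use-trailing-return-type) would otherwise ask every function to use a trailing return type; a minimal illustration of the two equivalent spellings (plain example, not taken from the repository):

// Classic signature; this is what modernize-use-trailing-return-type flags.
int add(int a, int b) { return a + b; }

// Trailing-return-type form the check would suggest; disabling the check keeps the classic style.
auto add_trailing(int a, int b) -> int { return a + b; }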
@@ -20,6 +20,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
clang-format-5.0 \
clang-tidy-5.0 \
cmake \
+comgr \
curl \
doxygen \
g++-7 \
@@ -32,14 +33,16 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
libncurses5-dev \
libnuma-dev \
libpthread-stubs0-dev \
+libssl-dev \
python \
python-dev \
python-pip \
+rocm-device-libs \
rocm-opencl \
rocm-opencl-dev \
+rocminfo \
software-properties-common \
-wget && \
+wget \
+zlib1g-dev && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
@@ -50,7 +53,7 @@ RUN pip install cget
RUN pip install https://github.com/pfultz2/rclone/archive/master.tar.gz
# Install hcc
-RUN rclone -b roc-2.3.x -c fd93baed7dcc4fe8019b5fdc90213bfe7c298245 https://github.com/RadeonOpenCompute/hcc.git /hcc
+RUN rclone -b roc-2.6.x -c 0f4c96b7851af2663a7f3ac16ecfb76c7c78a5bf https://github.com/RadeonOpenCompute/hcc.git /hcc
RUN cget -p $PREFIX install hcc,/hcc
# Use hcc
......
pfultz2/rocm-recipes
danmar/cppcheck@8aa68ee297c2d9ebadf5bcfd00c66ea8d9291e35 -DHAVE_RULES=1
-ROCm-Developer-Tools/HIP@e21df3058728ad8e73708bc99b82e0bdd3509c97
+ROCm-Developer-Tools/HIP@2490e42baa7d90458f0632fd9fbead2d395f41b9
python/cpython@v3.6.6 -X autotools -H sha256:92aa914572c695c0aeb01b0a214813f414da4b51a371234df514a74761f2bb36
-f requirements.txt
google/protobuf@v3.8.0 -DCMAKE_POSITION_INDEPENDENT_CODE=On -X subdir -Dprotobuf_BUILD_TESTS=Off
RadeonOpenCompute/rocm-cmake@42f6740 --build
-ROCmSoftwarePlatform/rocBLAS@30a992ae02fda568688bcd190edd5e277d6674d9
+ROCmSoftwarePlatform/rocBLAS@7197df74e5a1ba64ff967065872e5f86a3516637
-ROCmSoftwarePlatform/MIOpen@1.7.0
+ROCmSoftwarePlatform/MIOpen@2.0.0
blaze,https://bitbucket.org/blaze-lib/blaze/get/f0755dea0e03.tar.gz -X header -DHEADER_DIR=blaze
half,https://github.com/pfultz2/half/archive/1.12.0.tar.gz -X header -H sha256:0a08660b68abb176ebc2a0cdf8de46e3182a7f46c66443bb80dbfaaec98cf969
pybind/pybind11@v2.2.4 -DPYBIND11_TEST=Off --build
@@ -68,11 +68,13 @@ struct onnx_parser
add_mem_op("ArgMax", &onnx_parser::parse_argmax);
add_mem_op("ArgMin", &onnx_parser::parse_argmin);
+add_mem_op("Cast", &onnx_parser::parse_cast);
add_mem_op("Clip", &onnx_parser::parse_clip);
add_mem_op("LRN", &onnx_parser::parse_lrn);
add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler);
add_mem_op("LeakyRelu", &onnx_parser::parse_leaky_relu);
add_mem_op("Elu", &onnx_parser::parse_elu);
+add_mem_op("Expand", &onnx_parser::parse_expand);
add_mem_op("Constant", &onnx_parser::parse_constant);
add_mem_op("Conv", &onnx_parser::parse_conv);
add_mem_op("MaxPool", &onnx_parser::parse_pooling);
@@ -93,6 +95,7 @@ struct onnx_parser
add_mem_op("Gather", &onnx_parser::parse_gather);
add_mem_op("Shape", &onnx_parser::parse_shape);
add_mem_op("ConstantFill", &onnx_parser::parse_constant_fill);
+add_mem_op("ConstantOfShape", &onnx_parser::parse_constant_of_shape);
add_mem_op("Transpose", &onnx_parser::parse_transpose);
add_mem_op("RNN", &onnx_parser::parse_rnn);
add_mem_op("GRU", &onnx_parser::parse_gru);
@@ -464,8 +467,7 @@ struct onnx_parser
if(args.size() == 2)
{
auto s = args[1]->eval();
-if(s.empty())
-MIGRAPHX_THROW("Dynamic shape is not supported.");
+check_arg_empty(s, "Reshape: dynamic shape is not supported");
s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
}
return prog.add_instruction(op, args[0]);
@@ -544,7 +546,13 @@ struct onnx_parser
attribute_map attributes,
const std::vector<instruction_ref>&)
{
literal v = parse_value(attributes.at("value"));
+// return empty literal
+if(v.get_shape().elements() == 0)
+{
+return prog.add_literal(literal{});
+}
auto dim_size = attributes.at("value").t().dims_size();
// if dim_size is 0, it is a scalar
if(dim_size == 0)
@@ -872,10 +880,7 @@ struct onnx_parser
}
migraphx::argument in = args[0]->eval();
-if(in.empty())
-{
-MIGRAPHX_THROW("ConstantFill: cannot handle dynamic shape as input");
-}
+check_arg_empty(in, "ConstantFill: dynamic shape is not supported");
std::vector<std::size_t> dims;
in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
@@ -903,6 +908,74 @@ struct onnx_parser
}
}
instruction_ref parse_constant_of_shape(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args)
{
literal l_val{};
if(contains(attributes, "value"))
{
l_val = parse_value(attributes.at("value"));
if(l_val.get_shape().elements() != 1)
{
MIGRAPHX_THROW("ConstantOfShape: attribute value can contain only 1 elements!");
}
}
else
{
l_val = literal({shape::float_type, {1}, {0}}, {0.0f});
}
// input is empty, output is a scalar
auto type = l_val.get_shape().type();
if(args.empty())
{
MIGRAPHX_THROW("ConstantOfShape : must have 1 input!");
}
else
{
migraphx::shape s;
// empty input tensor, output is a scalar
if(args[0]->get_shape().elements() == 0)
{
s = migraphx::shape{type, {1}, {0}};
}
else
{
migraphx::argument in = args[0]->eval();
check_arg_empty(in, "ConstantOfShape: dynamic shape is not supported");
std::vector<std::size_t> dims;
in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
s = migraphx::shape{type, dims};
}
literal l_out{};
l_val.visit([&](auto val) {
using val_type = std::remove_cv_t<typename decltype(val)::value_type>;
// l_val contains only one element
std::vector<val_type> out_vec(s.elements(), *val.begin());
l_out = literal(s, out_vec);
});
return prog.add_literal(l_out);
}
}
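As a concrete illustration of what parse_constant_of_shape produces (this mirrors the const_of_shape1.onnx test further down): a ConstantOfShape node whose input evaluates to the shape tensor {2, 3, 4} and whose value attribute holds the single float 10.0 becomes one literal of shape {2, 3, 4} filled with 10.0:

// Sketch using the same API as the tests below (assumes the migraphx headers are available).
migraphx::program p;
migraphx::shape ss(migraphx::shape::int32_type, {3});
p.add_literal(migraphx::literal(ss, {2, 3, 4}));           // the constant shape tensor feeding the node
migraphx::shape s(migraphx::shape::float_type, {2, 3, 4}); // output shape read from that tensor
std::vector<float> vec(s.elements(), 10.0f);               // the single attribute value, repeated
p.add_literal(migraphx::literal(s, vec));                  // literal emitted by parse_constant_of_shape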
instruction_ref
parse_expand(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{
auto in_lens = args[0]->get_shape().lens();
migraphx::argument arg_s = args[1]->eval();
check_arg_empty(arg_s, "Expand: dynamic shape is not supported");
std::vector<std::size_t> dims;
arg_s.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
auto out_lens = compute_broadcasted_lens(in_lens, dims);
return prog.add_instruction(op::multibroadcast{out_lens}, args[0]);
}
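compute_broadcasted_lens itself is defined elsewhere and is not part of this diff; the sketch below is a hypothetical reconstruction that assumes standard ONNX/NumPy broadcasting (dimensions aligned from the right, each pair must match or contain a 1):

// Hypothetical sketch of compute_broadcasted_lens; the real helper in the repository may differ.
std::vector<std::size_t> compute_broadcasted_lens_sketch(std::vector<std::size_t> s0,
                                                         std::vector<std::size_t> s1)
{
    if(s0.size() > s1.size())
        s0.swap(s1); // make s0 the shorter list
    auto offset = s1.size() - s0.size();
    std::vector<std::size_t> out(s1);
    for(std::size_t i = 0; i < s0.size(); ++i)
    {
        if(s0[i] != s1[i + offset] and s0[i] != 1 and s1[i + offset] != 1)
            MIGRAPHX_THROW("Expand: incompatible dimensions");
        out[i + offset] = (s0[i] == 1) ? s1[i + offset] : s0[i];
    }
    return out;
}
// With the expand_test inputs below, {3, 1, 1} against {2, 3, 4, 5} yields {2, 3, 4, 5}.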
std::vector<instruction_ref>
parse_rnn(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
@@ -1325,6 +1398,19 @@ struct onnx_parser
}
}
instruction_ref
parse_cast(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
if(!contains(attributes, "to"))
{
MIGRAPHX_THROW("PARSE_CAST: missing to type attribute!");
}
int to_type = parse_value(attributes.at("to")).at<int>();
shape::type_t type = get_type(to_type);
return prog.add_instruction(op::convert{type}, std::move(args));
}
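get_type is defined elsewhere in the parser and is not shown in this diff; the sketch below is a hypothetical reconstruction of the mapping it presumably performs, keyed by the ONNX TensorProto::DataType codes carried in the to attribute (1 = FLOAT, 6 = INT32, 7 = INT64, 10 = FLOAT16, 11 = DOUBLE):

// Hypothetical sketch of get_type; the real mapping lives outside this diff.
shape::type_t get_type_sketch(int dtype)
{
    switch(dtype)
    {
    case 1: return shape::float_type;   // FLOAT
    case 3: return shape::int8_type;    // INT8
    case 6: return shape::int32_type;   // INT32
    case 7: return shape::int64_type;   // INT64
    case 10: return shape::half_type;   // FLOAT16
    case 11: return shape::double_type; // DOUBLE
    default: MIGRAPHX_THROW("Unsupported type code: " + std::to_string(dtype));
    }
}
// In cast_test below the attribute selects FLOAT, so a half input is converted with op::convert{shape::float_type}.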
void parse_from(std::istream& is)
{
onnx::ModelProto model;
@@ -1471,16 +1557,16 @@ struct onnx_parser
{
switch(attr.type())
{
-case onnx::AttributeProto::UNDEFINED: return {};
case onnx::AttributeProto::FLOAT: return literal{attr.f()};
case onnx::AttributeProto::INT: return literal{attr.i()};
-case onnx::AttributeProto::STRING: return {};
case onnx::AttributeProto::TENSOR: return parse_tensor(attr.t());
-case onnx::AttributeProto::GRAPH: return {};
case onnx::AttributeProto::FLOATS: return from_repeated(shape::float_type, attr.floats());
case onnx::AttributeProto::INTS: return from_repeated(shape::int64_type, attr.ints());
-case onnx::AttributeProto::STRINGS: return {};
-case onnx::AttributeProto::TENSORS: return {};
+case onnx::AttributeProto::UNDEFINED:
+case onnx::AttributeProto::GRAPH:
+case onnx::AttributeProto::STRING:
+case onnx::AttributeProto::STRINGS:
+case onnx::AttributeProto::TENSORS:
case onnx::AttributeProto::GRAPHS: return {};
}
MIGRAPHX_THROW("Invalid attribute type");
@@ -1494,47 +1580,41 @@ struct onnx_parser
const std::string& s = t.raw_data();
switch(t.data_type())
{
-case onnx::TensorProto::UNDEFINED: throw std::runtime_error("");
case onnx::TensorProto::FLOAT: return create_literal(shape::float_type, dims, s.data());
-case onnx::TensorProto::UINT8: throw std::runtime_error("");
-case onnx::TensorProto::INT8: return create_literal(shape::int32_type, dims, s.data());
-case onnx::TensorProto::UINT16:
-return create_literal(shape::int32_type, dims, s.data());
-case onnx::TensorProto::INT16: return create_literal(shape::int32_type, dims, s.data());
-case onnx::TensorProto::INT32: return create_literal(shape::int32_type, dims, s.data());
-case onnx::TensorProto::INT64: return create_literal(shape::int64_type, dims, s.data());
-case onnx::TensorProto::STRING: throw std::runtime_error("");
-case onnx::TensorProto::BOOL: return create_literal(shape::int32_type, dims, s.data());
case onnx::TensorProto::FLOAT16:
return create_literal(shape::half_type, dims, s.data());
case onnx::TensorProto::DOUBLE:
return create_literal(shape::double_type, dims, s.data());
-case onnx::TensorProto::UINT32: throw std::runtime_error("");
-case onnx::TensorProto::UINT64: throw std::runtime_error("");
-case onnx::TensorProto::COMPLEX64: throw std::runtime_error("");
+case onnx::TensorProto::INT64: return create_literal(shape::int64_type, dims, s.data());
+case onnx::TensorProto::INT8:
+case onnx::TensorProto::UINT16:
+case onnx::TensorProto::INT16:
+case onnx::TensorProto::INT32:
+case onnx::TensorProto::BOOL: return create_literal(shape::int32_type, dims, s.data());
+case onnx::TensorProto::UINT8:
+case onnx::TensorProto::STRING:
+case onnx::TensorProto::UNDEFINED:
+case onnx::TensorProto::UINT32:
+case onnx::TensorProto::UINT64:
+case onnx::TensorProto::COMPLEX64:
case onnx::TensorProto::COMPLEX128: throw std::runtime_error("");
}
MIGRAPHX_THROW("Invalid tensor type");
}
switch(t.data_type())
{
-case onnx::TensorProto::UNDEFINED: throw std::runtime_error("");
-case onnx::TensorProto::FLOAT:
-return create_literal(shape::float_type, dims, t.float_data());
-case onnx::TensorProto::UINT8: throw std::runtime_error("");
case onnx::TensorProto::INT8:
-return create_literal(shape::int32_type, dims, t.int32_data());
case onnx::TensorProto::UINT16:
-return create_literal(shape::int32_type, dims, t.int32_data());
case onnx::TensorProto::INT16:
-return create_literal(shape::int32_type, dims, t.int32_data());
case onnx::TensorProto::INT32:
+case onnx::TensorProto::BOOL:
return create_literal(shape::int32_type, dims, t.int32_data());
case onnx::TensorProto::INT64:
return create_literal(shape::int64_type, dims, t.int64_data());
-case onnx::TensorProto::STRING: throw std::runtime_error("");
-case onnx::TensorProto::BOOL:
-return create_literal(shape::int32_type, dims, t.int32_data());
+case onnx::TensorProto::DOUBLE:
+return create_literal(shape::double_type, dims, t.double_data());
+case onnx::TensorProto::FLOAT:
+return create_literal(shape::float_type, dims, t.float_data());
case onnx::TensorProto::FLOAT16:
{
std::vector<uint16_t> data_uint16(t.int32_data().begin(), t.int32_data().end());
@@ -1545,11 +1625,12 @@ struct onnx_parser
[](uint16_t raw_val) { return *reinterpret_cast<half*>(&raw_val); });
return create_literal(shape::half_type, dims, data_half);
}
-case onnx::TensorProto::DOUBLE:
-return create_literal(shape::double_type, dims, t.double_data());
-case onnx::TensorProto::UINT32: throw std::runtime_error("");
-case onnx::TensorProto::UINT64: throw std::runtime_error("");
-case onnx::TensorProto::COMPLEX64: throw std::runtime_error("");
+case onnx::TensorProto::UNDEFINED:
+case onnx::TensorProto::UINT8:
+case onnx::TensorProto::STRING:
+case onnx::TensorProto::UINT32:
+case onnx::TensorProto::UINT64:
+case onnx::TensorProto::COMPLEX64:
case onnx::TensorProto::COMPLEX128: throw std::runtime_error("");
}
MIGRAPHX_THROW("Invalid tensor type");
@@ -1577,28 +1658,23 @@ struct onnx_parser
shape::type_t shape_type{};
switch(t.tensor_type().elem_type())
{
-case onnx::TensorProto::UNDEFINED:
-break; // throw std::runtime_error("Unsupported type UNDEFINED");
case onnx::TensorProto::FLOAT: shape_type = shape::float_type; break;
-case onnx::TensorProto::UINT8:
-break; // throw std::runtime_error("Unsupported type UINT8");
case onnx::TensorProto::INT8: shape_type = shape::int8_type; break;
case onnx::TensorProto::UINT16: shape_type = shape::uint16_type; break;
case onnx::TensorProto::INT16: shape_type = shape::int16_type; break;
case onnx::TensorProto::INT32: shape_type = shape::int32_type; break;
case onnx::TensorProto::INT64: shape_type = shape::int64_type; break;
-case onnx::TensorProto::STRING:
-break; // throw std::runtime_error("Unsupported type STRING");
-case onnx::TensorProto::BOOL:
-break; // throw std::runtime_error("Unsupported type BOOL");
case onnx::TensorProto::FLOAT16: shape_type = shape::half_type; break;
case onnx::TensorProto::DOUBLE: shape_type = shape::double_type; break;
case onnx::TensorProto::UINT32: shape_type = shape::uint32_type; break;
case onnx::TensorProto::UINT64: shape_type = shape::uint64_type; break;
+case onnx::TensorProto::UINT8:
+case onnx::TensorProto::STRING:
+case onnx::TensorProto::BOOL:
+case onnx::TensorProto::UNDEFINED:
case onnx::TensorProto::COMPLEX64:
-break; // throw std::runtime_error("Unsupported type COMPLEX64");
case onnx::TensorProto::COMPLEX128:
-break; // throw std::runtime_error("Unsupported type COMPLEX128");
+break; // throw std::runtime_error("Unsupported type");
}
std::vector<std::size_t> dims;
auto&& tensor_dims = t.tensor_type().shape().dim();
@@ -1637,6 +1713,14 @@ struct onnx_parser
}
}
}
void check_arg_empty(const argument& arg, const std::string& msg)
{
if(arg.empty())
{
MIGRAPHX_THROW(msg);
}
}
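check_arg_empty centralizes the guard that Reshape, ConstantFill, ConstantOfShape, and Expand now share: eval() yields an empty argument when the input cannot be constant-folded at parse time, i.e. its value is only known at run time. The before/after call pattern, taken from the ConstantFill hunk above:

// Before: every parser repeated the guard inline.
migraphx::argument in = args[0]->eval();
if(in.empty())
{
    MIGRAPHX_THROW("ConstantFill: cannot handle dynamic shape as input");
}

// After: the message stays per-operator, but the check lives in one place.
migraphx::argument in2 = args[0]->eval();
check_arg_empty(in2, "ConstantFill: dynamic shape is not supported");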
};
program parse_onnx(const std::string& name)
......
@@ -8,6 +8,7 @@ namespace device {
void reduce_sum(hipStream_t stream, const argument& result, const argument& arg)
{
reduce(stream, result, arg, sum{}, 0, id{}, id{});
}
......
@@ -1004,72 +1004,56 @@ struct tf_parser
shape::type_t shape_type{};
switch(t)
{
-case tensorflow::DataType::DT_INVALID:
-break; // throw std::runtime_error("Unsupported type UNDEFINED");
case tensorflow::DataType::DT_FLOAT: shape_type = shape::float_type; break;
case tensorflow::DataType::DT_DOUBLE: shape_type = shape::double_type; break;
case tensorflow::DataType::DT_INT32: shape_type = shape::int32_type; break;
-case tensorflow::DataType::DT_UINT8:
-break; // throw std::runtime_error("Unsupported type UINT8");
case tensorflow::DataType::DT_INT16: shape_type = shape::int16_type; break;
case tensorflow::DataType::DT_INT8: shape_type = shape::int8_type; break;
+case tensorflow::DataType::DT_INT64: shape_type = shape::int64_type; break;
+case tensorflow::DataType::DT_UINT16: shape_type = shape::uint16_type; break;
+case tensorflow::DataType::DT_HALF: shape_type = shape::half_type; break;
+case tensorflow::DataType::DT_UINT32: shape_type = shape::uint32_type; break;
+case tensorflow::DataType::DT_UINT64: shape_type = shape::uint64_type; break;
+case tensorflow::DataType::DT_INVALID:
+case tensorflow::DataType::DT_UINT8:
case tensorflow::DataType::DT_STRING:
-break; // throw std::runtime_error("Unsupported type STRING");
case tensorflow::DataType::DT_COMPLEX64:
-break; // throw std::runtime_error("Unsupported type COMPLEX64");
-case tensorflow::DataType::DT_INT64: shape_type = shape::int64_type; break;
case tensorflow::DataType::DT_BOOL:
-break; // throw std::runtime_error("Unsupported type BOOL");
case tensorflow::DataType::DT_QINT8:
-break; // throw std::runtime_error("Unsupported type QINT8");
case tensorflow::DataType::DT_QUINT8:
-break; // throw std::runtime_error("Unsupported type QUINT8");
case tensorflow::DataType::DT_QINT32:
-break; // throw std::runtime_error("Unsupported type QINT32");
case tensorflow::DataType::DT_BFLOAT16:
-break; // throw std::runtime_error("Unsupported type BFLOAT16");
case tensorflow::DataType::DT_QINT16:
-break; // throw std::runtime_error("Unsupported type QINT16");
case tensorflow::DataType::DT_QUINT16:
-break; // throw std::runtime_error("Unsupported type QUINT16");
-case tensorflow::DataType::DT_UINT16: shape_type = shape::uint16_type; break;
case tensorflow::DataType::DT_COMPLEX128:
-break; // throw std::runtime_error("Unsupported type COMPLEX128");
-case tensorflow::DataType::DT_HALF: shape_type = shape::half_type; break;
case tensorflow::DataType::DT_RESOURCE:
-break; // throw std::runtime_error("Unsupported type RESOURCE");
case tensorflow::DataType::DT_VARIANT:
-break; // throw std::runtime_error("Unsupported type VARIANT");
-case tensorflow::DataType::DT_UINT32: shape_type = shape::uint32_type; break;
-case tensorflow::DataType::DT_UINT64:
-shape_type = shape::uint64_type;
-break;
// tf pb should not use these types
-case tensorflow::DataType::DT_FLOAT_REF: break;
-case tensorflow::DataType::DT_DOUBLE_REF: break;
-case tensorflow::DataType::DT_INT32_REF: break;
-case tensorflow::DataType::DT_UINT8_REF: break;
-case tensorflow::DataType::DT_INT16_REF: break;
-case tensorflow::DataType::DT_INT8_REF: break;
-case tensorflow::DataType::DT_STRING_REF: break;
-case tensorflow::DataType::DT_COMPLEX64_REF: break;
-case tensorflow::DataType::DT_INT64_REF: break;
-case tensorflow::DataType::DT_BOOL_REF: break;
-case tensorflow::DataType::DT_QINT8_REF: break;
-case tensorflow::DataType::DT_QUINT8_REF: break;
-case tensorflow::DataType::DT_QINT32_REF: break;
-case tensorflow::DataType::DT_BFLOAT16_REF: break;
-case tensorflow::DataType::DT_QINT16_REF: break;
-case tensorflow::DataType::DT_QUINT16_REF: break;
-case tensorflow::DataType::DT_UINT16_REF: break;
-case tensorflow::DataType::DT_COMPLEX128_REF: break;
-case tensorflow::DataType::DT_HALF_REF: break;
-case tensorflow::DataType::DT_RESOURCE_REF: break;
-case tensorflow::DataType::DT_VARIANT_REF: break;
-case tensorflow::DataType::DT_UINT32_REF: break;
-case tensorflow::DataType::DT_UINT64_REF: break;
-case tensorflow::DataType::DataType_INT_MAX_SENTINEL_DO_NOT_USE_: break;
+case tensorflow::DataType::DT_FLOAT_REF:
+case tensorflow::DataType::DT_DOUBLE_REF:
+case tensorflow::DataType::DT_INT32_REF:
+case tensorflow::DataType::DT_UINT8_REF:
+case tensorflow::DataType::DT_INT16_REF:
+case tensorflow::DataType::DT_INT8_REF:
+case tensorflow::DataType::DT_STRING_REF:
+case tensorflow::DataType::DT_COMPLEX64_REF:
+case tensorflow::DataType::DT_INT64_REF:
+case tensorflow::DataType::DT_BOOL_REF:
+case tensorflow::DataType::DT_QINT8_REF:
+case tensorflow::DataType::DT_QUINT8_REF:
+case tensorflow::DataType::DT_QINT32_REF:
+case tensorflow::DataType::DT_BFLOAT16_REF:
+case tensorflow::DataType::DT_QINT16_REF:
+case tensorflow::DataType::DT_QUINT16_REF:
+case tensorflow::DataType::DT_UINT16_REF:
+case tensorflow::DataType::DT_COMPLEX128_REF:
+case tensorflow::DataType::DT_HALF_REF:
+case tensorflow::DataType::DT_RESOURCE_REF:
+case tensorflow::DataType::DT_VARIANT_REF:
+case tensorflow::DataType::DT_UINT32_REF:
+case tensorflow::DataType::DT_UINT64_REF:
+case tensorflow::DataType::DataType_INT_MAX_SENTINEL_DO_NOT_USE_:
case tensorflow::DataType::DataType_INT_MIN_SENTINEL_DO_NOT_USE_: break;
}
return shape_type;
@@ -1084,61 +1068,59 @@ struct tf_parser
const std::string& s = t.tensor_content();
switch(t.dtype())
{
-case tensorflow::DataType::DT_INVALID: throw std::runtime_error("");
case tensorflow::DataType::DT_FLOAT:
return literal{{shape::float_type, dims}, s.data()};
-case tensorflow::DataType::DT_UINT8: throw std::runtime_error("");
+case tensorflow::DataType::DT_BOOL:
case tensorflow::DataType::DT_INT8: return literal{{shape::int8_type, dims}, s.data()};
case tensorflow::DataType::DT_UINT16:
-return literal{{shape::uint16_type, dims}, s.data()};
case tensorflow::DataType::DT_INT16:
return literal{{shape::int16_type, dims}, s.data()};
case tensorflow::DataType::DT_INT32:
return literal{{shape::int32_type, dims}, s.data()};
case tensorflow::DataType::DT_INT64:
return literal{{shape::int64_type, dims}, s.data()};
-case tensorflow::DataType::DT_STRING: throw std::runtime_error("");
-case tensorflow::DataType::DT_BOOL: return literal{{shape::int8_type, dims}, s.data()};
case tensorflow::DataType::DT_HALF: return literal{{shape::half_type, dims}, s.data()};
case tensorflow::DataType::DT_DOUBLE:
return literal{{shape::double_type, dims}, s.data()};
-case tensorflow::DataType::DT_UINT32: throw std::runtime_error("");
-case tensorflow::DataType::DT_UINT64: throw std::runtime_error("");
-case tensorflow::DataType::DT_COMPLEX64: throw std::runtime_error("");
-case tensorflow::DataType::DT_COMPLEX128: throw std::runtime_error("");
-case tensorflow::DataType::DT_QINT8: throw std::runtime_error("");
-case tensorflow::DataType::DT_QUINT8: throw std::runtime_error("");
-case tensorflow::DataType::DT_QINT32: throw std::runtime_error("");
-case tensorflow::DataType::DT_BFLOAT16: throw std::runtime_error("");
-case tensorflow::DataType::DT_QINT16: throw std::runtime_error("");
-case tensorflow::DataType::DT_QUINT16: throw std::runtime_error("");
-case tensorflow::DataType::DT_RESOURCE: throw std::runtime_error("");
-case tensorflow::DataType::DT_VARIANT: throw std::runtime_error("");
-case tensorflow::DataType::DT_FLOAT_REF: throw std::runtime_error("");
-case tensorflow::DataType::DT_DOUBLE_REF: throw std::runtime_error("");
-case tensorflow::DataType::DT_INT32_REF: throw std::runtime_error("");
-case tensorflow::DataType::DT_UINT8_REF: throw std::runtime_error("");
-case tensorflow::DataType::DT_INT16_REF: throw std::runtime_error("");
-case tensorflow::DataType::DT_INT8_REF: throw std::runtime_error("");
-case tensorflow::DataType::DT_STRING_REF: throw std::runtime_error("");
-case tensorflow::DataType::DT_COMPLEX64_REF: throw std::runtime_error("");
-case tensorflow::DataType::DT_INT64_REF: throw std::runtime_error("");
-case tensorflow::DataType::DT_BOOL_REF: throw std::runtime_error("");
-case tensorflow::DataType::DT_QINT8_REF: throw std::runtime_error("");
-case tensorflow::DataType::DT_QUINT8_REF: throw std::runtime_error("");
-case tensorflow::DataType::DT_QINT32_REF: throw std::runtime_error("");
-case tensorflow::DataType::DT_BFLOAT16_REF: throw std::runtime_error("");
-case tensorflow::DataType::DT_QINT16_REF: throw std::runtime_error("");
-case tensorflow::DataType::DT_QUINT16_REF: throw std::runtime_error("");
-case tensorflow::DataType::DT_UINT16_REF: throw std::runtime_error("");
-case tensorflow::DataType::DT_COMPLEX128_REF: throw std::runtime_error("");
-case tensorflow::DataType::DT_HALF_REF: throw std::runtime_error("");
-case tensorflow::DataType::DT_RESOURCE_REF: throw std::runtime_error("");
-case tensorflow::DataType::DT_VARIANT_REF: throw std::runtime_error("");
-case tensorflow::DataType::DT_UINT32_REF: throw std::runtime_error("");
-case tensorflow::DataType::DT_UINT64_REF: throw std::runtime_error("");
+case tensorflow::DataType::DT_INVALID:
+case tensorflow::DataType::DT_UINT8:
+case tensorflow::DataType::DT_STRING:
+case tensorflow::DataType::DT_UINT32:
+case tensorflow::DataType::DT_UINT64:
+case tensorflow::DataType::DT_COMPLEX64:
+case tensorflow::DataType::DT_COMPLEX128:
+case tensorflow::DataType::DT_QINT8:
+case tensorflow::DataType::DT_QUINT8:
+case tensorflow::DataType::DT_QINT32:
+case tensorflow::DataType::DT_BFLOAT16:
+case tensorflow::DataType::DT_QINT16:
+case tensorflow::DataType::DT_QUINT16:
+case tensorflow::DataType::DT_RESOURCE:
+case tensorflow::DataType::DT_VARIANT:
+case tensorflow::DataType::DT_FLOAT_REF:
+case tensorflow::DataType::DT_DOUBLE_REF:
+case tensorflow::DataType::DT_INT32_REF:
+case tensorflow::DataType::DT_UINT8_REF:
+case tensorflow::DataType::DT_INT16_REF:
+case tensorflow::DataType::DT_INT8_REF:
+case tensorflow::DataType::DT_STRING_REF:
+case tensorflow::DataType::DT_COMPLEX64_REF:
+case tensorflow::DataType::DT_INT64_REF:
+case tensorflow::DataType::DT_BOOL_REF:
+case tensorflow::DataType::DT_QINT8_REF:
+case tensorflow::DataType::DT_QUINT8_REF:
+case tensorflow::DataType::DT_QINT32_REF:
+case tensorflow::DataType::DT_BFLOAT16_REF:
+case tensorflow::DataType::DT_QINT16_REF:
+case tensorflow::DataType::DT_QUINT16_REF:
+case tensorflow::DataType::DT_UINT16_REF:
+case tensorflow::DataType::DT_COMPLEX128_REF:
+case tensorflow::DataType::DT_HALF_REF:
+case tensorflow::DataType::DT_RESOURCE_REF:
+case tensorflow::DataType::DT_VARIANT_REF:
+case tensorflow::DataType::DT_UINT32_REF:
+case tensorflow::DataType::DT_UINT64_REF:
case tensorflow::DataType::DataType_INT_MAX_SENTINEL_DO_NOT_USE_:
-throw std::runtime_error("");
case tensorflow::DataType::DataType_INT_MIN_SENTINEL_DO_NOT_USE_:
throw std::runtime_error("");
}
@@ -1146,11 +1128,9 @@ struct tf_parser
}
switch(t.dtype())
{
-case tensorflow::DataType::DT_INVALID: throw std::runtime_error("");
case tensorflow::DataType::DT_FLOAT:
return create_literal(
shape::float_type, dims, get_data_vals(t.float_val(), shape_size));
-case tensorflow::DataType::DT_UINT8: throw std::runtime_error("");
case tensorflow::DataType::DT_INT8:
return create_literal(shape::int8_type, dims, get_data_vals(t.int_val(), shape_size));
case tensorflow::DataType::DT_UINT16:
@@ -1162,7 +1142,6 @@ struct tf_parser
case tensorflow::DataType::DT_INT64:
return create_literal(
shape::int64_type, dims, get_data_vals(t.int64_val(), shape_size));
-case tensorflow::DataType::DT_STRING: throw std::runtime_error("");
case tensorflow::DataType::DT_BOOL:
return create_literal(shape::int32_type, dims, get_data_vals(t.bool_val(), shape_size));
case tensorflow::DataType::DT_HALF:
@@ -1178,43 +1157,45 @@ struct tf_parser
}
case tensorflow::DataType::DT_DOUBLE:
return literal{{shape::double_type, dims}, get_data_vals(t.double_val(), shape_size)};
-case tensorflow::DataType::DT_UINT32: throw std::runtime_error("");
-case tensorflow::DataType::DT_UINT64: throw std::runtime_error("");
-case tensorflow::DataType::DT_COMPLEX64: throw std::runtime_error("");
-case tensorflow::DataType::DT_COMPLEX128: throw std::runtime_error("");
-case tensorflow::DataType::DT_QINT8: throw std::runtime_error("");
-case tensorflow::DataType::DT_QUINT8: throw std::runtime_error("");
-case tensorflow::DataType::DT_QINT32: throw std::runtime_error("");
-case tensorflow::DataType::DT_BFLOAT16: throw std::runtime_error("");
-case tensorflow::DataType::DT_QINT16: throw std::runtime_error("");
-case tensorflow::DataType::DT_QUINT16: throw std::runtime_error("");
-case tensorflow::DataType::DT_RESOURCE: throw std::runtime_error("");
-case tensorflow::DataType::DT_VARIANT: throw std::runtime_error("");
-case tensorflow::DataType::DT_FLOAT_REF: throw std::runtime_error("");
-case tensorflow::DataType::DT_DOUBLE_REF: throw std::runtime_error("");
-case tensorflow::DataType::DT_INT32_REF: throw std::runtime_error("");
-case tensorflow::DataType::DT_UINT8_REF: throw std::runtime_error("");
-case tensorflow::DataType::DT_INT16_REF: throw std::runtime_error("");
-case tensorflow::DataType::DT_INT8_REF: throw std::runtime_error("");
-case tensorflow::DataType::DT_STRING_REF: throw std::runtime_error("");
-case tensorflow::DataType::DT_COMPLEX64_REF: throw std::runtime_error("");
-case tensorflow::DataType::DT_INT64_REF: throw std::runtime_error("");
-case tensorflow::DataType::DT_BOOL_REF: throw std::runtime_error("");
-case tensorflow::DataType::DT_QINT8_REF: throw std::runtime_error("");
-case tensorflow::DataType::DT_QUINT8_REF: throw std::runtime_error("");
-case tensorflow::DataType::DT_QINT32_REF: throw std::runtime_error("");
-case tensorflow::DataType::DT_BFLOAT16_REF: throw std::runtime_error("");
-case tensorflow::DataType::DT_QINT16_REF: throw std::runtime_error("");
-case tensorflow::DataType::DT_QUINT16_REF: throw std::runtime_error("");
-case tensorflow::DataType::DT_UINT16_REF: throw std::runtime_error("");
-case tensorflow::DataType::DT_COMPLEX128_REF: throw std::runtime_error("");
-case tensorflow::DataType::DT_HALF_REF: throw std::runtime_error("");
-case tensorflow::DataType::DT_RESOURCE_REF: throw std::runtime_error("");
-case tensorflow::DataType::DT_VARIANT_REF: throw std::runtime_error("");
-case tensorflow::DataType::DT_UINT32_REF: throw std::runtime_error("");
-case tensorflow::DataType::DT_UINT64_REF: throw std::runtime_error("");
+case tensorflow::DataType::DT_INVALID:
+case tensorflow::DataType::DT_UINT8:
+case tensorflow::DataType::DT_STRING:
+case tensorflow::DataType::DT_UINT32:
+case tensorflow::DataType::DT_UINT64:
+case tensorflow::DataType::DT_COMPLEX64:
+case tensorflow::DataType::DT_COMPLEX128:
+case tensorflow::DataType::DT_QINT8:
+case tensorflow::DataType::DT_QUINT8:
+case tensorflow::DataType::DT_QINT32:
+case tensorflow::DataType::DT_BFLOAT16:
+case tensorflow::DataType::DT_QINT16:
+case tensorflow::DataType::DT_QUINT16:
+case tensorflow::DataType::DT_RESOURCE:
+case tensorflow::DataType::DT_VARIANT:
+case tensorflow::DataType::DT_FLOAT_REF:
+case tensorflow::DataType::DT_DOUBLE_REF:
+case tensorflow::DataType::DT_INT32_REF:
+case tensorflow::DataType::DT_UINT8_REF:
+case tensorflow::DataType::DT_INT16_REF:
+case tensorflow::DataType::DT_INT8_REF:
+case tensorflow::DataType::DT_STRING_REF:
+case tensorflow::DataType::DT_COMPLEX64_REF:
+case tensorflow::DataType::DT_INT64_REF:
+case tensorflow::DataType::DT_BOOL_REF:
+case tensorflow::DataType::DT_QINT8_REF:
+case tensorflow::DataType::DT_QUINT8_REF:
+case tensorflow::DataType::DT_QINT32_REF:
+case tensorflow::DataType::DT_BFLOAT16_REF:
+case tensorflow::DataType::DT_QINT16_REF:
+case tensorflow::DataType::DT_QUINT16_REF:
+case tensorflow::DataType::DT_UINT16_REF:
+case tensorflow::DataType::DT_COMPLEX128_REF:
+case tensorflow::DataType::DT_HALF_REF:
+case tensorflow::DataType::DT_RESOURCE_REF:
+case tensorflow::DataType::DT_VARIANT_REF:
+case tensorflow::DataType::DT_UINT32_REF:
+case tensorflow::DataType::DT_UINT64_REF:
case tensorflow::DataType::DataType_INT_MAX_SENTINEL_DO_NOT_USE_:
-throw std::runtime_error("");
case tensorflow::DataType::DataType_INT_MIN_SENTINEL_DO_NOT_USE_:
throw std::runtime_error("");
}
......
[This commit also adds binary ONNX test models (the graphs named cast-example, constant-of-shape, and expand, i.e. the cast_test, const_of_shape, and expand_test .onnx files referenced by the tests below); their protobuf content is not rendered here.]
@@ -925,4 +925,78 @@ TEST_CASE(pow_test)
EXPECT(p == prog);
}
TEST_CASE(cast_test)
{
migraphx::program p;
auto l = p.add_parameter("x", migraphx::shape{migraphx::shape::half_type, {10}});
p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, l);
auto prog = migraphx::parse_onnx("cast_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(const_of_shape_float)
{
migraphx::program p;
migraphx::shape ss(migraphx::shape::int32_type, {3});
p.add_literal(migraphx::literal(ss, {2, 3, 4}));
migraphx::shape s(migraphx::shape::float_type, {2, 3, 4});
std::vector<float> vec(s.elements(), 10.0f);
p.add_literal(migraphx::literal(s, vec));
auto prog = migraphx::parse_onnx("const_of_shape1.onnx");
EXPECT(p == prog);
}
TEST_CASE(const_of_shape_int64)
{
migraphx::program p;
migraphx::shape ss(migraphx::shape::int32_type, {3});
p.add_literal(migraphx::literal(ss, {2, 3, 4}));
migraphx::shape s(migraphx::shape::int64_type, {2, 3, 4});
std::vector<int64_t> vec(s.elements(), 10);
p.add_literal(migraphx::literal(s, vec));
auto prog = migraphx::parse_onnx("const_of_shape2.onnx");
EXPECT(p == prog);
}
TEST_CASE(const_of_shape_no_value_attr)
{
migraphx::program p;
migraphx::shape ss(migraphx::shape::int32_type, {3});
p.add_literal(migraphx::literal(ss, {2, 3, 4}));
migraphx::shape s(migraphx::shape::float_type, {2, 3, 4});
std::vector<float> vec(s.elements(), 0.0f);
p.add_literal(migraphx::literal(s, vec));
auto prog = migraphx::parse_onnx("const_of_shape3.onnx");
EXPECT(p == prog);
}
TEST_CASE(const_of_shape_empty_input)
{
migraphx::program p;
p.add_literal(migraphx::literal());
migraphx::shape s(migraphx::shape::int64_type, {1}, {0});
std::vector<int64_t> vec(s.elements(), 10);
p.add_literal(migraphx::literal(s, vec));
auto prog = migraphx::parse_onnx("const_of_shape4.onnx");
EXPECT(p == prog);
}
TEST_CASE(expand_test)
{
migraphx::program p;
migraphx::shape s(migraphx::shape::float_type, {3, 1, 1});
auto param = p.add_parameter("x", s);
migraphx::shape ss(migraphx::shape::int32_type, {4});
p.add_literal(migraphx::literal(ss, {2, 3, 4, 5}));
p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, param);
auto prog = migraphx::parse_onnx("expand_test.onnx");
EXPECT(p == prog);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }