Unverified Commit f006b0a9 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Merge branch 'develop' into eval-check

parents b29c3283 24d68767
...@@ -29,9 +29,13 @@ struct unsqueeze ...@@ -29,9 +29,13 @@ struct unsqueeze
std::string name() const { return "unsqueeze"; } std::string name() const { return "unsqueeze"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
auto input_shape = inputs[0]; auto input_shape = inputs[0];
auto type = input_shape.type(); auto type = input_shape.type();
auto old_lens = input_shape.lens(); auto old_lens = input_shape.lens();
if(input_shape.scalar())
return shape{type, old_lens};
std::size_t new_size = old_lens.size() + axes.size(); std::size_t new_size = old_lens.size() + axes.size();
std::vector<std::size_t> new_lens(new_size); std::vector<std::size_t> new_lens(new_size);
std::size_t p = 0; std::size_t p = 0;
......
...@@ -30,8 +30,16 @@ const operation& get_operation(instruction_ref ins); ...@@ -30,8 +30,16 @@ const operation& get_operation(instruction_ref ins);
struct program struct program
{ {
program(); program();
// move constructor
program(program&&) noexcept; program(program&&) noexcept;
program& operator=(program&&) noexcept;
// copy constructor
program(const program&);
// copy assignment operator
program& operator=(program);
~program() noexcept; ~program() noexcept;
using parameter_map = std::unordered_map<std::string, argument>; using parameter_map = std::unordered_map<std::string, argument>;
...@@ -118,6 +126,9 @@ struct program ...@@ -118,6 +126,9 @@ struct program
friend bool operator==(const program& x, const program& y); friend bool operator==(const program& x, const program& y);
friend bool operator!=(const program& x, const program& y) { return !(x == y); } friend bool operator!=(const program& x, const program& y) { return !(x == y); }
private:
void assign(const program& p);
private: private:
std::unique_ptr<program_impl> impl; std::unique_ptr<program_impl> impl;
}; };
......
...@@ -1361,28 +1361,26 @@ struct onnx_parser ...@@ -1361,28 +1361,26 @@ struct onnx_parser
static literal parse_tensor(const onnx::TensorProto& t) static literal parse_tensor(const onnx::TensorProto& t)
{ {
std::vector<std::size_t> dims(t.dims().begin(), t.dims().end()); std::vector<std::size_t> dims(t.dims().begin(), t.dims().end());
// in case of scalar constants in onnx file, use dims=1 to fill initializer data
if(dims.empty())
{
dims = {1};
}
if(t.has_raw_data()) if(t.has_raw_data())
{ {
const std::string& s = t.raw_data(); const std::string& s = t.raw_data();
switch(t.data_type()) switch(t.data_type())
{ {
case onnx::TensorProto::UNDEFINED: throw std::runtime_error(""); case onnx::TensorProto::UNDEFINED: throw std::runtime_error("");
case onnx::TensorProto::FLOAT: return literal{{shape::float_type, dims}, s.data()}; case onnx::TensorProto::FLOAT: return create_literal(shape::float_type, dims, s.data());
case onnx::TensorProto::UINT8: throw std::runtime_error(""); case onnx::TensorProto::UINT8: throw std::runtime_error("");
case onnx::TensorProto::INT8: return literal{{shape::int32_type, dims}, s.data()}; case onnx::TensorProto::INT8: return create_literal(shape::int32_type, dims, s.data());
case onnx::TensorProto::UINT16: return literal{{shape::int32_type, dims}, s.data()}; case onnx::TensorProto::UINT16:
case onnx::TensorProto::INT16: return literal{{shape::int32_type, dims}, s.data()}; return create_literal(shape::int32_type, dims, s.data());
case onnx::TensorProto::INT32: return literal{{shape::int32_type, dims}, s.data()}; case onnx::TensorProto::INT16: return create_literal(shape::int32_type, dims, s.data());
case onnx::TensorProto::INT64: return literal{{shape::int64_type, dims}, s.data()}; case onnx::TensorProto::INT32: return create_literal(shape::int32_type, dims, s.data());
case onnx::TensorProto::INT64: return create_literal(shape::int64_type, dims, s.data());
case onnx::TensorProto::STRING: throw std::runtime_error(""); case onnx::TensorProto::STRING: throw std::runtime_error("");
case onnx::TensorProto::BOOL: return literal{{shape::int32_type, dims}, s.data()}; case onnx::TensorProto::BOOL: return create_literal(shape::int32_type, dims, s.data());
case onnx::TensorProto::FLOAT16: return literal{{shape::half_type, dims}, s.data()}; case onnx::TensorProto::FLOAT16:
case onnx::TensorProto::DOUBLE: return literal{{shape::double_type, dims}, s.data()}; return create_literal(shape::half_type, dims, s.data());
case onnx::TensorProto::DOUBLE:
return create_literal(shape::double_type, dims, s.data());
case onnx::TensorProto::UINT32: throw std::runtime_error(""); case onnx::TensorProto::UINT32: throw std::runtime_error("");
case onnx::TensorProto::UINT64: throw std::runtime_error(""); case onnx::TensorProto::UINT64: throw std::runtime_error("");
case onnx::TensorProto::COMPLEX64: throw std::runtime_error(""); case onnx::TensorProto::COMPLEX64: throw std::runtime_error("");
...@@ -1394,21 +1392,21 @@ struct onnx_parser ...@@ -1394,21 +1392,21 @@ struct onnx_parser
{ {
case onnx::TensorProto::UNDEFINED: throw std::runtime_error(""); case onnx::TensorProto::UNDEFINED: throw std::runtime_error("");
case onnx::TensorProto::FLOAT: case onnx::TensorProto::FLOAT:
return literal{{shape::float_type, dims}, t.float_data().begin(), t.float_data().end()}; return create_literal(shape::float_type, dims, t.float_data());
case onnx::TensorProto::UINT8: throw std::runtime_error(""); case onnx::TensorProto::UINT8: throw std::runtime_error("");
case onnx::TensorProto::INT8: case onnx::TensorProto::INT8:
return literal{{shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()}; return create_literal(shape::int32_type, dims, t.int32_data());
case onnx::TensorProto::UINT16: case onnx::TensorProto::UINT16:
return literal{{shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()}; return create_literal(shape::int32_type, dims, t.int32_data());
case onnx::TensorProto::INT16: case onnx::TensorProto::INT16:
return literal{{shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()}; return create_literal(shape::int32_type, dims, t.int32_data());
case onnx::TensorProto::INT32: case onnx::TensorProto::INT32:
return literal{{shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()}; return create_literal(shape::int32_type, dims, t.int32_data());
case onnx::TensorProto::INT64: case onnx::TensorProto::INT64:
return literal{{shape::int64_type, dims}, t.int64_data().begin(), t.int64_data().end()}; return create_literal(shape::int64_type, dims, t.int64_data());
case onnx::TensorProto::STRING: throw std::runtime_error(""); case onnx::TensorProto::STRING: throw std::runtime_error("");
case onnx::TensorProto::BOOL: case onnx::TensorProto::BOOL:
return literal{{shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()}; return create_literal(shape::int32_type, dims, t.int32_data());
case onnx::TensorProto::FLOAT16: case onnx::TensorProto::FLOAT16:
{ {
std::vector<uint16_t> data_uint16(t.int32_data().begin(), t.int32_data().end()); std::vector<uint16_t> data_uint16(t.int32_data().begin(), t.int32_data().end());
...@@ -1417,11 +1415,10 @@ struct onnx_parser ...@@ -1417,11 +1415,10 @@ struct onnx_parser
data_uint16.end(), data_uint16.end(),
std::back_inserter(data_half), std::back_inserter(data_half),
[](uint16_t raw_val) { return *reinterpret_cast<half*>(&raw_val); }); [](uint16_t raw_val) { return *reinterpret_cast<half*>(&raw_val); });
return literal{{shape::half_type, dims}, data_half.begin(), data_half.end()}; return create_literal(shape::half_type, dims, data_half);
} }
case onnx::TensorProto::DOUBLE: case onnx::TensorProto::DOUBLE:
return literal{ return create_literal(shape::double_type, dims, t.double_data());
{shape::double_type, dims}, t.double_data().begin(), t.double_data().end()};
case onnx::TensorProto::UINT32: throw std::runtime_error(""); case onnx::TensorProto::UINT32: throw std::runtime_error("");
case onnx::TensorProto::UINT64: throw std::runtime_error(""); case onnx::TensorProto::UINT64: throw std::runtime_error("");
case onnx::TensorProto::COMPLEX64: throw std::runtime_error(""); case onnx::TensorProto::COMPLEX64: throw std::runtime_error("");
...@@ -1430,6 +1427,23 @@ struct onnx_parser ...@@ -1430,6 +1427,23 @@ struct onnx_parser
MIGRAPHX_THROW("Invalid tensor type"); MIGRAPHX_THROW("Invalid tensor type");
} }
static literal
create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, const char* data)
{
// in case of scalar constants in onnx file, use dims=1 to fill initializer data
if(dims.empty())
return literal{{shape_type}, data};
return literal{{shape_type, dims}, data};
}
template <class T, MIGRAPHX_REQUIRES(not std::is_pointer<T>{})>
static literal create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, T data)
{
if(dims.empty())
return literal{{shape_type}, data.begin(), data.end()};
return literal{{shape_type, dims}, data.begin(), data.end()};
}
static shape parse_type(const onnx::TypeProto& t) static shape parse_type(const onnx::TypeProto& t)
{ {
shape::type_t shape_type{}; shape::type_t shape_type{};
......
...@@ -86,8 +86,70 @@ static void print_program(const program& p, F print_func) ...@@ -86,8 +86,70 @@ static void print_program(const program& p, F print_func)
program::program() : impl(std::make_unique<program_impl>()) {} program::program() : impl(std::make_unique<program_impl>()) {}
program::program(program&&) noexcept = default; program::program(program&&) noexcept = default;
program& program::operator=(program&&) noexcept = default; program::~program() noexcept = default;
program::~program() noexcept = default;
// copy constructor
program::program(const program& p) { assign(p); }
// copy assignment operator
program& program::operator=(program p)
{
std::swap(p.impl, this->impl);
return *this;
}
void program::assign(const program& p)
{
// clean the current program
if(!impl)
{
impl = std::make_unique<program_impl>();
}
else if(!impl->instructions.empty())
{
impl->instructions.clear();
}
impl->ctx = p.impl->ctx;
std::unordered_map<instruction_ref, instruction_ref> ins_map;
for(auto ins : iterator_for(p))
{
instruction_ref copy_ins{};
if(ins->name() == "@literal")
{
auto l = ins->get_literal();
copy_ins = impl->instructions.insert(impl->instructions.end(), instruction{l});
}
else if(ins->name() == "@param")
{
auto&& name = any_cast<builtin::param>(ins->get_operator()).parameter;
auto s = ins->get_shape();
copy_ins = impl->instructions.insert(impl->instructions.end(),
{builtin::param{name}, std::move(s), {}});
}
else if(ins->name() == "@outline")
{
auto s = ins->get_shape();
copy_ins =
impl->instructions.insert(impl->instructions.end(), {builtin::outline{s}, s, {}});
}
else
{
// retrieve its mapped input
auto inputs = ins->inputs();
// ensure all inputs have its corresponding copy instructions
assert(std::all_of(
inputs.begin(), inputs.end(), [&](auto i) { return ins_map.count(i) > 0; }));
std::vector<instruction_ref> copy_inputs(inputs.size());
std::transform(inputs.begin(), inputs.end(), copy_inputs.begin(), [&](auto i) {
return ins_map[i];
});
copy_ins = add_instruction(ins->get_operator(), copy_inputs);
}
ins_map[ins] = copy_ins;
}
}
instruction_ref program::add_instruction(const operation& op, std::vector<instruction_ref> args) instruction_ref program::add_instruction(const operation& op, std::vector<instruction_ref> args)
{ {
......
...@@ -772,7 +772,7 @@ template <typename Op> ...@@ -772,7 +772,7 @@ template <typename Op>
struct cpu_binary struct cpu_binary
{ {
Op op; Op op;
std::string name() const { return op.name(); } std::string name() const { return "cpu::" + op.name(); }
shape compute_shape(const std::vector<shape>& inputs) const { return inputs.front(); } shape compute_shape(const std::vector<shape>& inputs) const { return inputs.front(); }
argument compute(context&, const shape& output_shape, std::vector<argument> args) const argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{ {
......
...@@ -741,10 +741,6 @@ struct tf_parser ...@@ -741,10 +741,6 @@ struct tf_parser
static literal parse_tensor(const tensorflow::TensorProto& t) static literal parse_tensor(const tensorflow::TensorProto& t)
{ {
std::vector<size_t> dims = parse_dims(t.tensor_shape()); std::vector<size_t> dims = parse_dims(t.tensor_shape());
if(dims.empty())
{
dims = {1};
}
size_t shape_size = std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<size_t>()); size_t shape_size = std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<size_t>());
if(!t.tensor_content().empty()) // has raw data if(!t.tensor_content().empty()) // has raw data
{ {
...@@ -755,17 +751,17 @@ struct tf_parser ...@@ -755,17 +751,17 @@ struct tf_parser
case tensorflow::DataType::DT_FLOAT: case tensorflow::DataType::DT_FLOAT:
return literal{{shape::float_type, dims}, s.data()}; return literal{{shape::float_type, dims}, s.data()};
case tensorflow::DataType::DT_UINT8: throw std::runtime_error(""); case tensorflow::DataType::DT_UINT8: throw std::runtime_error("");
case tensorflow::DataType::DT_INT8: return literal{{shape::int32_type, dims}, s.data()}; case tensorflow::DataType::DT_INT8: return literal{{shape::int8_type, dims}, s.data()};
case tensorflow::DataType::DT_UINT16: case tensorflow::DataType::DT_UINT16:
return literal{{shape::int32_type, dims}, s.data()}; return literal{{shape::uint16_type, dims}, s.data()};
case tensorflow::DataType::DT_INT16: case tensorflow::DataType::DT_INT16:
return literal{{shape::int32_type, dims}, s.data()}; return literal{{shape::int16_type, dims}, s.data()};
case tensorflow::DataType::DT_INT32: case tensorflow::DataType::DT_INT32:
return literal{{shape::int32_type, dims}, s.data()}; return literal{{shape::int32_type, dims}, s.data()};
case tensorflow::DataType::DT_INT64: case tensorflow::DataType::DT_INT64:
return literal{{shape::int64_type, dims}, s.data()}; return literal{{shape::int64_type, dims}, s.data()};
case tensorflow::DataType::DT_STRING: throw std::runtime_error(""); case tensorflow::DataType::DT_STRING: throw std::runtime_error("");
case tensorflow::DataType::DT_BOOL: return literal{{shape::int32_type, dims}, s.data()}; case tensorflow::DataType::DT_BOOL: return literal{{shape::int8_type, dims}, s.data()};
case tensorflow::DataType::DT_HALF: return literal{{shape::half_type, dims}, s.data()}; case tensorflow::DataType::DT_HALF: return literal{{shape::half_type, dims}, s.data()};
case tensorflow::DataType::DT_DOUBLE: case tensorflow::DataType::DT_DOUBLE:
return literal{{shape::double_type, dims}, s.data()}; return literal{{shape::double_type, dims}, s.data()};
...@@ -815,21 +811,23 @@ struct tf_parser ...@@ -815,21 +811,23 @@ struct tf_parser
{ {
case tensorflow::DataType::DT_INVALID: throw std::runtime_error(""); case tensorflow::DataType::DT_INVALID: throw std::runtime_error("");
case tensorflow::DataType::DT_FLOAT: case tensorflow::DataType::DT_FLOAT:
return literal{{shape::float_type, dims}, get_data_vals(t.float_val(), shape_size)}; return create_literal(
shape::float_type, dims, get_data_vals(t.float_val(), shape_size));
case tensorflow::DataType::DT_UINT8: throw std::runtime_error(""); case tensorflow::DataType::DT_UINT8: throw std::runtime_error("");
case tensorflow::DataType::DT_INT8: case tensorflow::DataType::DT_INT8:
return literal{{shape::int32_type, dims}, get_data_vals(t.int_val(), shape_size)}; return create_literal(shape::int8_type, dims, get_data_vals(t.int_val(), shape_size));
case tensorflow::DataType::DT_UINT16: case tensorflow::DataType::DT_UINT16:
return literal{{shape::int32_type, dims}, get_data_vals(t.int_val(), shape_size)}; return create_literal(shape::uint16_type, dims, get_data_vals(t.int_val(), shape_size));
case tensorflow::DataType::DT_INT16: case tensorflow::DataType::DT_INT16:
return literal{{shape::int32_type, dims}, get_data_vals(t.int_val(), shape_size)}; return create_literal(shape::int16_type, dims, get_data_vals(t.int_val(), shape_size));
case tensorflow::DataType::DT_INT32: case tensorflow::DataType::DT_INT32:
return literal{{shape::int32_type, dims}, get_data_vals(t.int_val(), shape_size)}; return create_literal(shape::int32_type, dims, get_data_vals(t.int_val(), shape_size));
case tensorflow::DataType::DT_INT64: case tensorflow::DataType::DT_INT64:
return literal{{shape::int64_type, dims}, get_data_vals(t.int64_val(), shape_size)}; return create_literal(
shape::int64_type, dims, get_data_vals(t.int64_val(), shape_size));
case tensorflow::DataType::DT_STRING: throw std::runtime_error(""); case tensorflow::DataType::DT_STRING: throw std::runtime_error("");
case tensorflow::DataType::DT_BOOL: case tensorflow::DataType::DT_BOOL:
return literal{{shape::int32_type, dims}, get_data_vals(t.bool_val(), shape_size)}; return create_literal(shape::int32_type, dims, get_data_vals(t.bool_val(), shape_size));
case tensorflow::DataType::DT_HALF: case tensorflow::DataType::DT_HALF:
{ {
std::vector<int> data_int32 = get_data_vals(t.half_val(), shape_size); std::vector<int> data_int32 = get_data_vals(t.half_val(), shape_size);
...@@ -839,7 +837,7 @@ struct tf_parser ...@@ -839,7 +837,7 @@ struct tf_parser
data_uint16.end(), data_uint16.end(),
std::back_inserter(data_half), std::back_inserter(data_half),
[](uint16_t raw_val) { return *reinterpret_cast<half*>(&raw_val); }); [](uint16_t raw_val) { return *reinterpret_cast<half*>(&raw_val); });
return literal{{shape::half_type, dims}, data_half}; return create_literal(shape::half_type, dims, data_half);
} }
case tensorflow::DataType::DT_DOUBLE: case tensorflow::DataType::DT_DOUBLE:
return literal{{shape::double_type, dims}, get_data_vals(t.double_val(), shape_size)}; return literal{{shape::double_type, dims}, get_data_vals(t.double_val(), shape_size)};
...@@ -911,6 +909,16 @@ struct tf_parser ...@@ -911,6 +909,16 @@ struct tf_parser
[](tensorflow::TensorShapeProto_Dim dim) { return dim.size(); }); [](tensorflow::TensorShapeProto_Dim dim) { return dim.size(); });
return dims; return dims;
} }
template <class T>
static literal
create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, std::vector<T> data)
{
// assume if explicit value is mentioned in protobuf and dim size <= 1, treat as scalar
if(dims.empty() or (dims.size() == 1 and dims.front() == 1))
return literal{{shape_type}, data};
return literal{{shape_type, dims}, data};
}
}; };
program parse_tf(const std::string& name, bool is_nhwc) program parse_tf(const std::string& name, bool is_nhwc)
......
...@@ -699,8 +699,7 @@ TEST_CASE(add_scalar_test) ...@@ -699,8 +699,7 @@ TEST_CASE(add_scalar_test)
{ {
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = auto l1 = p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {1}});
p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type, {1}}, {1}});
auto m0 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l0); auto m0 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l0);
auto m1 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1); auto m1 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
p.add_instruction(migraphx::op::add{}, m0, m1); p.add_instruction(migraphx::op::add{}, m0, m1);
......
...@@ -2,6 +2,10 @@ ...@@ -2,6 +2,10 @@
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/cpu/target.hpp>
#include <sstream> #include <sstream>
#include "test.hpp" #include "test.hpp"
#include <basic_ops.hpp> #include <basic_ops.hpp>
...@@ -27,4 +31,78 @@ TEST_CASE(program_equality) ...@@ -27,4 +31,78 @@ TEST_CASE(program_equality)
EXPECT(x == y); EXPECT(x == y);
} }
TEST_CASE(program_copy)
{
auto create_program_1 = [] {
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 4, 5}};
std::vector<float> data(3 * 4 * 5);
std::iota(data.begin(), data.end(), 1.0f);
auto l2 = p.add_literal(migraphx::literal(s, data));
auto p1 = p.add_parameter("x", s);
auto po = p.add_outline(s);
auto sum = p.add_instruction(migraphx::op::add{}, l2, po);
p.add_instruction(migraphx::op::mul{}, sum, p1);
return p;
};
{
auto p1 = create_program_1();
migraphx::program p2{};
p2 = p1;
p2.compile(migraphx::cpu::target{});
EXPECT(p1 != p2);
p1.compile(migraphx::cpu::target{});
EXPECT(p1 == p2);
}
{
auto p1 = create_program_1();
auto p2(p1);
EXPECT(p1 == p2);
p1.compile(migraphx::cpu::target{});
EXPECT(p1 != p2);
p2 = p1;
EXPECT(p1 == p2);
}
{
auto p1 = create_program_1();
auto p2 = create_program();
EXPECT(p1 != p2);
p2 = p1;
EXPECT(p1 == p2);
p1.compile(migraphx::cpu::target{});
p2.compile(migraphx::cpu::target{});
EXPECT(p1 == p2);
}
{
migraphx::program p1;
migraphx::shape s1{migraphx::shape::float_type, {2, 3}};
migraphx::shape s2{migraphx::shape::float_type, {3, 6}};
migraphx::shape s3{migraphx::shape::float_type, {2, 6}};
auto para1 = p1.add_parameter("m1", s1);
auto para2 = p1.add_parameter("m2", s2);
auto para3 = p1.add_parameter("m3", s3);
p1.add_instruction(migraphx::op::dot{0.31f, 0.28f}, para1, para2, para3);
migraphx::program p2{};
p2 = p1;
EXPECT(p2 == p1);
p1.compile(migraphx::cpu::target{});
p2.compile(migraphx::cpu::target{});
EXPECT(p2 == p1);
}
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -80,7 +80,7 @@ TEST_CASE(concat_test) ...@@ -80,7 +80,7 @@ TEST_CASE(concat_test)
int axis = 1; int axis = 1;
// tf uses axis as the third input, and it is in int32 format // tf uses axis as the third input, and it is in int32 format
// add the literal using a vector in order to set stride to 1 (like in tf parser) // add the literal using a vector in order to set stride to 1 (like in tf parser)
p.add_literal(migraphx::shape{migraphx::shape::int32_type, {1}}, std::vector<int>{axis}); p.add_literal(migraphx::shape{migraphx::shape::int32_type}, std::vector<int>{axis});
p.add_instruction(migraphx::op::concat{static_cast<std::size_t>(axis)}, l0, l1); p.add_instruction(migraphx::op::concat{static_cast<std::size_t>(axis)}, l0, l1);
auto prog = migraphx::parse_tf("concat_test.pb", false); auto prog = migraphx::parse_tf("concat_test.pb", false);
...@@ -91,7 +91,7 @@ TEST_CASE(concat_test) ...@@ -91,7 +91,7 @@ TEST_CASE(concat_test)
TEST_CASE(const_test) TEST_CASE(const_test)
{ {
migraphx::program p; migraphx::program p;
p.add_literal(migraphx::shape{migraphx::shape::float_type, {1}}, std::vector<float>{1.0f}); p.add_literal(migraphx::shape{migraphx::shape::float_type}, std::vector<float>{1.0f});
auto prog = migraphx::parse_tf("constant_test.pb", false); auto prog = migraphx::parse_tf("constant_test.pb", false);
EXPECT(p == prog); EXPECT(p == prog);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment