Commit 15385fb1 authored by Shucai Xiao
Browse files

merge changes from develop branch

parents f7f02979 b606ed4f
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/simplify_reshapes.hpp> #include <migraphx/simplify_reshapes.hpp>
#include <migraphx/simplify_algebra.hpp> #include <migraphx/simplify_algebra.hpp>
#include <migraphx/constant_propagate.hpp> #include <migraphx/propagate_constant.hpp>
#include <migraphx/eliminate_contiguous.hpp> #include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/common_subexpression_elimination.hpp> #include <migraphx/common_subexpression_elimination.hpp>
#include <migraphx/fwd_conv_batchnorm_rewrite.hpp> #include <migraphx/fwd_conv_batchnorm_rewrite.hpp>
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include <migraphx/eliminate_identity.hpp> #include <migraphx/eliminate_identity.hpp>
#include <migraphx/gpu/concat_gpu_opt.hpp> #include <migraphx/gpu/concat_gpu_opt.hpp>
#include <migraphx/gpu/schedule_model.hpp> #include <migraphx/gpu/schedule_model.hpp>
#include <migraphx/gpu/adjust_allocation.hpp>
#include <migraphx/eliminate_pad.hpp> #include <migraphx/eliminate_pad.hpp>
#include <migraphx/schedule.hpp> #include <migraphx/schedule.hpp>
...@@ -47,7 +48,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const ...@@ -47,7 +48,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
//dead_code_elimination{}, //dead_code_elimination{},
simplify_algebra{}, simplify_algebra{},
dead_code_elimination{}, dead_code_elimination{},
constant_propagate{}, propagate_constant{},
dead_code_elimination{}, dead_code_elimination{},
auto_contiguous{}, auto_contiguous{},
//simplify_reshapes{}, //simplify_reshapes{},
...@@ -57,6 +58,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const ...@@ -57,6 +58,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
dead_code_elimination{}, dead_code_elimination{},
eliminate_contiguous{}, eliminate_contiguous{},
dead_code_elimination{}, dead_code_elimination{},
adjust_allocation{},
dead_code_elimination{},
fuse_ops{&ctx}, fuse_ops{&ctx},
dead_code_elimination{}, dead_code_elimination{},
write_literals{&ctx}, write_literals{&ctx},
......
...@@ -110,6 +110,7 @@ struct tf_parser ...@@ -110,6 +110,7 @@ struct tf_parser
add_generic_op("Relu", op::relu{}); add_generic_op("Relu", op::relu{});
add_binary_op("Add", op::add{}); add_binary_op("Add", op::add{});
add_binary_op("Mul", op::mul{});
add_mem_op("AvgPool", &tf_parser::parse_pooling); add_mem_op("AvgPool", &tf_parser::parse_pooling);
add_mem_op("BiasAdd", &tf_parser::parse_biasadd); add_mem_op("BiasAdd", &tf_parser::parse_biasadd);
...@@ -117,12 +118,15 @@ struct tf_parser ...@@ -117,12 +118,15 @@ struct tf_parser
add_mem_op("Const", &tf_parser::parse_constant); add_mem_op("Const", &tf_parser::parse_constant);
add_mem_op("Conv2D", &tf_parser::parse_conv); add_mem_op("Conv2D", &tf_parser::parse_conv);
add_mem_op("FusedBatchNorm", &tf_parser::parse_batchnorm); add_mem_op("FusedBatchNorm", &tf_parser::parse_batchnorm);
add_mem_op("MatMul", &tf_parser::parse_matmul);
add_mem_op("MaxPool", &tf_parser::parse_pooling); add_mem_op("MaxPool", &tf_parser::parse_pooling);
add_mem_op("Mean", &tf_parser::parse_mean); add_mem_op("Mean", &tf_parser::parse_mean);
add_mem_op("Pack", &tf_parser::parse_pack);
add_mem_op("Pad", &tf_parser::parse_pad); add_mem_op("Pad", &tf_parser::parse_pad);
add_mem_op("Reshape", &tf_parser::parse_reshape); add_mem_op("Reshape", &tf_parser::parse_reshape);
add_mem_op("Softmax", &tf_parser::parse_softmax); add_mem_op("Softmax", &tf_parser::parse_softmax);
add_mem_op("Squeeze", &tf_parser::parse_squeeze); add_mem_op("Squeeze", &tf_parser::parse_squeeze);
add_mem_op("StridedSlice", &tf_parser::parse_stridedslice);
} }
template <class F> template <class F>
...@@ -149,7 +153,7 @@ struct tf_parser ...@@ -149,7 +153,7 @@ struct tf_parser
template <class T> template <class T>
void add_binary_op(std::string name, T x) void add_binary_op(std::string name, T x)
{ {
add_op(name, [this, x](attribute_map attributes, std::vector<instruction_ref> args) { add_op(name, [this, x](const attribute_map& attributes, std::vector<instruction_ref> args) {
if(args.size() != 2) if(args.size() != 2)
MIGRAPHX_THROW("binary operators should have 2 operands"); MIGRAPHX_THROW("binary operators should have 2 operands");
auto l0 = args[1]; auto l0 = args[1];
...@@ -211,7 +215,7 @@ struct tf_parser ...@@ -211,7 +215,7 @@ struct tf_parser
template <class T> template <class T>
void add_generic_op(std::string name, T x) void add_generic_op(std::string name, T x)
{ {
add_op(name, [this, x](attribute_map, std::vector<instruction_ref> args) { add_op(name, [this, x](const attribute_map&, std::vector<instruction_ref> args) {
return prog.add_instruction(x, args); return prog.add_instruction(x, args);
}); });
} }
...@@ -234,7 +238,7 @@ struct tf_parser ...@@ -234,7 +238,7 @@ struct tf_parser
parse_biasadd(const std::string&, const attribute_map&, std::vector<instruction_ref> args) parse_biasadd(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{ {
uint64_t axis = 1; // assume output of previous layer is in NCHW (broadcast on channel) uint64_t axis = 1; // assume output of previous layer is in NCHW (broadcast on channel)
auto l0 = prog.add_instruction(op::broadcast{axis, args[0]->get_shape()}, args[1]); auto l0 = prog.add_instruction(op::broadcast{axis, args[0]->get_shape().lens()}, args[1]);
return prog.add_instruction(op::add{}, args[0], l0); return prog.add_instruction(op::add{}, args[0], l0);
} }
...@@ -335,6 +339,32 @@ struct tf_parser ...@@ -335,6 +339,32 @@ struct tf_parser
return prog.add_instruction(op, {args[0], weights}); return prog.add_instruction(op, {args[0], weights});
} }
/// Parse a TensorFlow MatMul node into an op::dot instruction.
///
/// The optional boolean attributes "transpose_a" and "transpose_b" request that
/// the last two dimensions of the first and second operand, respectively, are
/// swapped (via op::transpose) before the product is taken.
///
/// @param attributes node attributes; "transpose_a" / "transpose_b" are optional
///                   and default to false when absent
/// @param args       operands; args[0] is the left input and args[1] the right one
/// @return the instruction produced by adding op::dot to the program
instruction_ref
parse_matmul(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
    bool transa = false;
    bool transb = false;
    if(contains(attributes, "transpose_a"))
    {
        transa = attributes.at("transpose_a").b();
    }
    if(contains(attributes, "transpose_b"))
    {
        // The flag for the second operand lives under its own key; reading
        // "transpose_a" here would make both operands follow the first flag.
        transb = attributes.at("transpose_b").b();
    }

    // Identity permutation sized by the rank of args[0] ...
    // NOTE(review): the same permutation is applied to args[1], which assumes both
    // operands have the same rank (and at least two dimensions) - confirm.
    std::vector<int64_t> perm(args[0]->get_shape().lens().size());
    std::iota(perm.begin(), perm.end(), int64_t{0});
    // ... with the last two elements swapped
    std::iter_swap(perm.end() - 1, perm.end() - 2);

    auto l1 = (transa) ? prog.add_instruction(op::transpose{perm}, args[0]) : args[0];
    auto l2 = (transb) ? prog.add_instruction(op::transpose{perm}, args[1]) : args[1];

    return prog.add_instruction(op::dot{}, l1, l2);
}
instruction_ref instruction_ref
parse_mean(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) parse_mean(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{ {
...@@ -353,6 +383,33 @@ struct tf_parser ...@@ -353,6 +383,33 @@ struct tf_parser
MIGRAPHX_THROW("MIGraphX does not support mean outside of GlobalAvgPool transformation"); MIGRAPHX_THROW("MIGraphX does not support mean outside of GlobalAvgPool transformation");
} }
/// Parse a TensorFlow Pack node by unsqueezing every input along `axis` and
/// concatenating the results along that same axis.
///
/// @param attributes node attributes; "axis" is optional and defaults to 0
/// @param args       the tensors to pack; the rank is taken from the first one
/// @return the instruction produced by adding op::concat to the program
/// @throws via MIGRAPHX_THROW when axis is negative or larger than the input rank
instruction_ref parse_pack(const std::string&,
                           const attribute_map& attributes,
                           std::vector<instruction_ref> args)
{
    int64_t axis = 0;
    if(contains(attributes, "axis"))
        axis = attributes.at("axis").i();

    size_t input_size = args.front()->get_shape().lens().size();
    // Reject negative values explicitly instead of relying on the implicit
    // signed-to-unsigned conversion of a mixed int64_t / size_t comparison.
    if(axis < 0 or static_cast<size_t>(axis) > input_size)
    {
        MIGRAPHX_THROW("TF_PARSER: axis value of " + to_string(axis) +
                       " must be smaller than input size " + to_string(input_size));
    }

    // check if input arg needs axis to be converted to NCHW
    if(input_size >= 4)
        axis = parse_axis(axis);

    // reinterpret as unsqueeze with concat
    std::vector<instruction_ref> unsqueezed_args;
    unsqueezed_args.reserve(args.size());
    std::transform(
        args.begin(),
        args.end(),
        std::back_inserter(unsqueezed_args),
        [&](instruction_ref arg) { return prog.add_instruction(op::unsqueeze{{axis}}, arg); });

    return prog.add_instruction(op::concat{static_cast<size_t>(axis)}, unsqueezed_args);
}
instruction_ref instruction_ref
parse_pad(const std::string&, const attribute_map&, std::vector<instruction_ref> args) parse_pad(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{ {
...@@ -480,6 +537,46 @@ struct tf_parser ...@@ -480,6 +537,46 @@ struct tf_parser
return prog.add_instruction(op, args[0]); return prog.add_instruction(op, args[0]);
} }
/// Parse a TensorFlow StridedSlice node into an op::slice, followed by an
/// op::squeeze when "shrink_axis_mask" selects at least one axis.
///
/// Start and end indices are obtained by evaluating args[1] and args[2] as
/// int32 data; the slice covers every axis of args[0]. For inputs of rank 4 or
/// more the indices and squeeze axes are passed through reorder_data /
/// parse_axes before use.
///
/// @param attributes node attributes; "shrink_axis_mask" is optional (default 0)
/// @param args       args[0] is the data, args[1] the starts, args[2] the ends
/// @return the slice instruction, or the squeeze applied on top of it
instruction_ref parse_stridedslice(const std::string&,
                                   const attribute_map& attributes,
                                   std::vector<instruction_ref> args)
{
    op::slice op;
    auto starts     = args[1]->eval().get<int32_t>().to_vector();
    auto ends       = args[2]->eval().get<int32_t>().to_vector();
    size_t num_axes = args[0]->get_shape().lens().size();
    if(num_axes >= 4)
    {
        reorder_data(starts);
        reorder_data(ends);
    }

    op.starts = std::vector<int64_t>(starts.begin(), starts.end());
    op.ends   = std::vector<int64_t>(ends.begin(), ends.end());
    op.axes   = std::vector<int64_t>(num_axes);
    std::iota(op.axes.begin(), op.axes.end(), 0);

    uint32_t shrink_axis_mask = 0;
    uint32_t bitwise_compare  = 1;
    std::vector<int64_t> squeeze_axes;

    if(contains(attributes, "shrink_axis_mask"))
        shrink_axis_mask = static_cast<uint32_t>(attributes.at("shrink_axis_mask").i());

    // The mask only has 32 bits; shifting a uint32_t by 32 or more is undefined.
    constexpr size_t mask_bits = 32;
    for(size_t i = 0; i < num_axes and i < mask_bits; i++)
    {
        // the LSB corresponds to axis 0 when determining which axes to squeeze
        if(((shrink_axis_mask >> i) & bitwise_compare) == 1)
            squeeze_axes.push_back(i);
    }

    if(squeeze_axes.empty())
    {
        // No axis was asked to shrink, so the slice is the final result.
        // NOTE(review): op::squeeze with an empty axis list is expected to drop
        // every unit dimension (ONNX-style) rather than do nothing - confirm.
        return prog.add_instruction(op, args[0]);
    }

    if(num_axes >= 4)
    {
        squeeze_axes = parse_axes(squeeze_axes);
    }

    auto l0 = prog.add_instruction(op, args[0]);
    return prog.add_instruction(op::squeeze{squeeze_axes}, l0);
}
void parse_graph(const tensorflow::GraphDef& graph) void parse_graph(const tensorflow::GraphDef& graph)
{ {
nodes = get_nodes(graph, input_nodes); nodes = get_nodes(graph, input_nodes);
...@@ -644,10 +741,6 @@ struct tf_parser ...@@ -644,10 +741,6 @@ struct tf_parser
static literal parse_tensor(const tensorflow::TensorProto& t) static literal parse_tensor(const tensorflow::TensorProto& t)
{ {
std::vector<size_t> dims = parse_dims(t.tensor_shape()); std::vector<size_t> dims = parse_dims(t.tensor_shape());
if(dims.empty())
{
dims = {1};
}
size_t shape_size = std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<size_t>()); size_t shape_size = std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<size_t>());
if(!t.tensor_content().empty()) // has raw data if(!t.tensor_content().empty()) // has raw data
{ {
...@@ -658,17 +751,17 @@ struct tf_parser ...@@ -658,17 +751,17 @@ struct tf_parser
case tensorflow::DataType::DT_FLOAT: case tensorflow::DataType::DT_FLOAT:
return literal{{shape::float_type, dims}, s.data()}; return literal{{shape::float_type, dims}, s.data()};
case tensorflow::DataType::DT_UINT8: throw std::runtime_error(""); case tensorflow::DataType::DT_UINT8: throw std::runtime_error("");
case tensorflow::DataType::DT_INT8: return literal{{shape::int32_type, dims}, s.data()}; case tensorflow::DataType::DT_INT8: return literal{{shape::int8_type, dims}, s.data()};
case tensorflow::DataType::DT_UINT16: case tensorflow::DataType::DT_UINT16:
return literal{{shape::int32_type, dims}, s.data()}; return literal{{shape::uint16_type, dims}, s.data()};
case tensorflow::DataType::DT_INT16: case tensorflow::DataType::DT_INT16:
return literal{{shape::int32_type, dims}, s.data()}; return literal{{shape::int16_type, dims}, s.data()};
case tensorflow::DataType::DT_INT32: case tensorflow::DataType::DT_INT32:
return literal{{shape::int32_type, dims}, s.data()}; return literal{{shape::int32_type, dims}, s.data()};
case tensorflow::DataType::DT_INT64: case tensorflow::DataType::DT_INT64:
return literal{{shape::int64_type, dims}, s.data()}; return literal{{shape::int64_type, dims}, s.data()};
case tensorflow::DataType::DT_STRING: throw std::runtime_error(""); case tensorflow::DataType::DT_STRING: throw std::runtime_error("");
case tensorflow::DataType::DT_BOOL: return literal{{shape::int32_type, dims}, s.data()}; case tensorflow::DataType::DT_BOOL: return literal{{shape::int8_type, dims}, s.data()};
case tensorflow::DataType::DT_HALF: return literal{{shape::half_type, dims}, s.data()}; case tensorflow::DataType::DT_HALF: return literal{{shape::half_type, dims}, s.data()};
case tensorflow::DataType::DT_DOUBLE: case tensorflow::DataType::DT_DOUBLE:
return literal{{shape::double_type, dims}, s.data()}; return literal{{shape::double_type, dims}, s.data()};
...@@ -718,21 +811,23 @@ struct tf_parser ...@@ -718,21 +811,23 @@ struct tf_parser
{ {
case tensorflow::DataType::DT_INVALID: throw std::runtime_error(""); case tensorflow::DataType::DT_INVALID: throw std::runtime_error("");
case tensorflow::DataType::DT_FLOAT: case tensorflow::DataType::DT_FLOAT:
return literal{{shape::float_type, dims}, get_data_vals(t.float_val(), shape_size)}; return create_literal(
shape::float_type, dims, get_data_vals(t.float_val(), shape_size));
case tensorflow::DataType::DT_UINT8: throw std::runtime_error(""); case tensorflow::DataType::DT_UINT8: throw std::runtime_error("");
case tensorflow::DataType::DT_INT8: case tensorflow::DataType::DT_INT8:
return literal{{shape::int32_type, dims}, get_data_vals(t.int_val(), shape_size)}; return create_literal(shape::int8_type, dims, get_data_vals(t.int_val(), shape_size));
case tensorflow::DataType::DT_UINT16: case tensorflow::DataType::DT_UINT16:
return literal{{shape::int32_type, dims}, get_data_vals(t.int_val(), shape_size)}; return create_literal(shape::uint16_type, dims, get_data_vals(t.int_val(), shape_size));
case tensorflow::DataType::DT_INT16: case tensorflow::DataType::DT_INT16:
return literal{{shape::int32_type, dims}, get_data_vals(t.int_val(), shape_size)}; return create_literal(shape::int16_type, dims, get_data_vals(t.int_val(), shape_size));
case tensorflow::DataType::DT_INT32: case tensorflow::DataType::DT_INT32:
return literal{{shape::int32_type, dims}, get_data_vals(t.int_val(), shape_size)}; return create_literal(shape::int32_type, dims, get_data_vals(t.int_val(), shape_size));
case tensorflow::DataType::DT_INT64: case tensorflow::DataType::DT_INT64:
return literal{{shape::int64_type, dims}, get_data_vals(t.int64_val(), shape_size)}; return create_literal(
shape::int64_type, dims, get_data_vals(t.int64_val(), shape_size));
case tensorflow::DataType::DT_STRING: throw std::runtime_error(""); case tensorflow::DataType::DT_STRING: throw std::runtime_error("");
case tensorflow::DataType::DT_BOOL: case tensorflow::DataType::DT_BOOL:
return literal{{shape::int32_type, dims}, get_data_vals(t.bool_val(), shape_size)}; return create_literal(shape::int32_type, dims, get_data_vals(t.bool_val(), shape_size));
case tensorflow::DataType::DT_HALF: case tensorflow::DataType::DT_HALF:
{ {
std::vector<int> data_int32 = get_data_vals(t.half_val(), shape_size); std::vector<int> data_int32 = get_data_vals(t.half_val(), shape_size);
...@@ -742,7 +837,7 @@ struct tf_parser ...@@ -742,7 +837,7 @@ struct tf_parser
data_uint16.end(), data_uint16.end(),
std::back_inserter(data_half), std::back_inserter(data_half),
[](uint16_t raw_val) { return *reinterpret_cast<half*>(&raw_val); }); [](uint16_t raw_val) { return *reinterpret_cast<half*>(&raw_val); });
return literal{{shape::half_type, dims}, data_half}; return create_literal(shape::half_type, dims, data_half);
} }
case tensorflow::DataType::DT_DOUBLE: case tensorflow::DataType::DT_DOUBLE:
return literal{{shape::double_type, dims}, get_data_vals(t.double_val(), shape_size)}; return literal{{shape::double_type, dims}, get_data_vals(t.double_val(), shape_size)};
...@@ -811,9 +906,19 @@ struct tf_parser ...@@ -811,9 +906,19 @@ struct tf_parser
std::transform(input_dims.begin(), std::transform(input_dims.begin(),
input_dims.end(), input_dims.end(),
std::back_inserter(dims), std::back_inserter(dims),
[](tensorflow::TensorShapeProto_Dim dim) { return dim.size(); }); [](const tensorflow::TensorShapeProto_Dim& dim) { return dim.size(); });
return dims; return dims;
} }
/// Build a literal with element type `shape_type` from `data`.
///
/// When `dims` is empty, or is exactly {1}, the literal is created with a shape
/// constructed from the type alone (the value is treated as a scalar);
/// otherwise the shape is constructed from the type and `dims`.
///
/// @param shape_type element type of the resulting literal
/// @param dims       dimensions parsed from the tensor proto
/// @param data       values used to fill the literal
template <class T>
static literal
create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, std::vector<T> data)
{
    const bool single_unit_dim = (dims.size() == 1) and (dims.front() == 1);
    if(not dims.empty() and not single_unit_dim)
        return literal{{shape_type, dims}, data};

    // explicit value in the protobuf with dim size <= 1: assume a scalar
    return literal{{shape_type}, data};
}
}; };
program parse_tf(const std::string& name, bool is_nhwc) program parse_tf(const std::string& name, bool is_nhwc)
......
...@@ -60,7 +60,7 @@ TEST_CASE(after_literal_broadcast) ...@@ -60,7 +60,7 @@ TEST_CASE(after_literal_broadcast)
auto l2 = p.add_literal(get_2()); auto l2 = p.add_literal(get_2());
EXPECT(p.get_shape().standard()); EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().broadcasted()); EXPECT(not p.get_shape().broadcasted());
auto b = p.add_instruction(migraphx::op::broadcast{0, l1->get_shape()}, l2); auto b = p.add_instruction(migraphx::op::broadcast{0, l1->get_shape().lens()}, l2);
p.add_instruction(pass_op{}, b); p.add_instruction(pass_op{}, b);
EXPECT(not p.get_shape().standard()); EXPECT(not p.get_shape().standard());
EXPECT(p.get_shape().broadcasted()); EXPECT(p.get_shape().broadcasted());
...@@ -91,7 +91,7 @@ TEST_CASE(after_param_broadcast) ...@@ -91,7 +91,7 @@ TEST_CASE(after_param_broadcast)
auto l2 = p.add_parameter("2", {migraphx::shape::float_type, {2}}); auto l2 = p.add_parameter("2", {migraphx::shape::float_type, {2}});
EXPECT(p.get_shape().standard()); EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().broadcasted()); EXPECT(not p.get_shape().broadcasted());
auto b = p.add_instruction(migraphx::op::broadcast{0, l1->get_shape()}, l2); auto b = p.add_instruction(migraphx::op::broadcast{0, l1->get_shape().lens()}, l2);
p.add_instruction(pass_op{}, b); p.add_instruction(pass_op{}, b);
EXPECT(not p.get_shape().standard()); EXPECT(not p.get_shape().standard());
EXPECT(p.get_shape().broadcasted()); EXPECT(p.get_shape().broadcasted());
......
...@@ -351,7 +351,7 @@ TEST_CASE(gemm_mutli_dim1_2_3) ...@@ -351,7 +351,7 @@ TEST_CASE(gemm_mutli_dim1_2_3)
float beta = 0.41; float beta = 0.41;
auto m12_alpha = p.add_instruction(migraphx::op::dot{alpha, beta}, l1, l2); auto m12_alpha = p.add_instruction(migraphx::op::dot{alpha, beta}, l1, l2);
auto l_beta = p.add_literal(beta); auto l_beta = p.add_literal(beta);
auto b_beta = p.add_instruction(migraphx::op::scalar{m12_alpha->get_shape()}, l_beta); auto b_beta = p.add_instruction(migraphx::op::scalar{m12_alpha->get_shape().lens()}, l_beta);
auto m3_beta = p.add_instruction(migraphx::op::mul{}, b_beta, l3); auto m3_beta = p.add_instruction(migraphx::op::mul{}, b_beta, l3);
p.add_instruction(migraphx::op::add{}, m3_beta, m12_alpha); p.add_instruction(migraphx::op::add{}, m3_beta, m12_alpha);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
......
...@@ -651,7 +651,7 @@ TEST_CASE(broadcast_test) ...@@ -651,7 +651,7 @@ TEST_CASE(broadcast_test)
uint64_t axis = 0; uint64_t axis = 0;
auto l1 = p.add_literal(migraphx::literal{a_shape, a_data}); auto l1 = p.add_literal(migraphx::literal{a_shape, a_data});
auto l2 = p.add_literal(migraphx::literal{b_shape, b_data}); auto l2 = p.add_literal(migraphx::literal{b_shape, b_data});
p.add_instruction(migraphx::op::broadcast{axis, l1->get_shape()}, l2); p.add_instruction(migraphx::op::broadcast{axis, l1->get_shape().lens()}, l2);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
auto output = result.get<int32_t>(); auto output = result.get<int32_t>();
...@@ -671,7 +671,7 @@ TEST_CASE(add_broadcast_test) ...@@ -671,7 +671,7 @@ TEST_CASE(add_broadcast_test)
uint64_t axis = 0; uint64_t axis = 0;
auto l1 = p.add_literal(migraphx::literal{a_shape, a_data}); auto l1 = p.add_literal(migraphx::literal{a_shape, a_data});
auto l2 = p.add_literal(migraphx::literal{b_shape, b_data}); auto l2 = p.add_literal(migraphx::literal{b_shape, b_data});
auto l3 = p.add_instruction(migraphx::op::broadcast{axis, l1->get_shape()}, l2); auto l3 = p.add_instruction(migraphx::op::broadcast{axis, l1->get_shape().lens()}, l2);
p.add_instruction(migraphx::op::add{}, l1, l3); p.add_instruction(migraphx::op::add{}, l1, l3);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
...@@ -809,11 +809,11 @@ TEST_CASE(imagescaler_test) ...@@ -809,11 +809,11 @@ TEST_CASE(imagescaler_test)
0.35, 0.35,
0.45}}); 0.45}});
auto scale_val = p.add_literal(2.f); auto scale_val = p.add_literal(2.f);
auto scaled_tensor = p.add_instruction(migraphx::op::scalar{s}, scale_val); auto scaled_tensor = p.add_instruction(migraphx::op::scalar{s.lens()}, scale_val);
auto img_scaled = p.add_instruction(migraphx::op::mul{}, img, scaled_tensor); auto img_scaled = p.add_instruction(migraphx::op::mul{}, img, scaled_tensor);
auto bias_vals = p.add_literal( auto bias_vals = p.add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::float_type, {3}}, {0.01, 0.02, 0.03}}); migraphx::literal{migraphx::shape{migraphx::shape::float_type, {3}}, {0.01, 0.02, 0.03}});
auto bias_bcast = p.add_instruction(migraphx::op::broadcast{1, s}, bias_vals); auto bias_bcast = p.add_instruction(migraphx::op::broadcast{1, s.lens()}, bias_vals);
p.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast); p.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
......
#include <migraphx/eliminate_contiguous.hpp> #include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/sin.hpp>
#include <migraphx/op/slice.hpp>
#include <migraphx/op/transpose.hpp> #include <migraphx/op/transpose.hpp>
#include <migraphx/op/contiguous.hpp> #include <migraphx/op/contiguous.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
...@@ -36,7 +40,46 @@ TEST_CASE(non_standard_op) ...@@ -36,7 +40,46 @@ TEST_CASE(non_standard_op)
p.add_instruction(pass_op{}, c); p.add_instruction(pass_op{}, c);
auto count = std::distance(p.begin(), p.end()); auto count = std::distance(p.begin(), p.end());
p.compile(eliminate_contiguous_target{}); p.compile(eliminate_contiguous_target{});
EXPECT(std::distance(p.begin(), p.end()) == count);
}
// Builds literal -> transpose -> contiguous -> identity -> dot and checks that
// compiling with eliminate_contiguous_target leaves exactly one instruction
// fewer in the program.
TEST_CASE(transpose_gemm)
{
migraphx::program p;
auto l = p.add_literal(get_2x2());
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c = p.add_instruction(migraphx::op::contiguous{}, t);
auto ic = p.add_instruction(migraphx::op::identity{}, c);
p.add_instruction(migraphx::op::dot{}, ic, l);
// instruction count before the pass runs
auto count = std::distance(p.begin(), p.end());
p.compile(eliminate_contiguous_target{});
// NOTE(review): the two lines below repeat their text, which looks like a
// side-by-side diff rendering (old column / new column) - confirm.
EXPECT(std::distance(p.begin(), p.end()) == (count - 1)); EXPECT(std::distance(p.begin(), p.end()) == (count - 1));
} }
// literal -> transpose -> contiguous -> sin -> pass_standard_op: the number of
// instructions is expected to be the same before and after compiling with
// eliminate_contiguous_target.
TEST_CASE(transpose_standard_op)
{
    migraphx::program p;
    auto input      = p.add_literal(get_2x2());
    auto transposed = p.add_instruction(migraphx::op::transpose{{1, 0}}, input);
    auto packed     = p.add_instruction(migraphx::op::contiguous{}, transposed);
    auto sine       = p.add_instruction(migraphx::op::sin{}, packed);
    p.add_instruction(pass_standard_op{}, sine);

    // instruction count before the pass runs
    auto count = std::distance(p.begin(), p.end());
    p.compile(eliminate_contiguous_target{});
    EXPECT(std::distance(p.begin(), p.end()) == count);
}
// literal -> slice -> contiguous -> sin -> pass_standard_op: compiling with
// eliminate_contiguous_target is expected to leave exactly one instruction
// fewer in the program.
TEST_CASE(no_packed_unary_op)
{
    migraphx::program p;
    auto input  = p.add_literal(get_2x2());
    auto sliced = p.add_instruction(migraphx::op::slice{{1}, {1}, {2}}, input);
    auto packed = p.add_instruction(migraphx::op::contiguous{}, sliced);
    auto sine   = p.add_instruction(migraphx::op::sin{}, packed);
    p.add_instruction(pass_standard_op{}, sine);

    // instruction count before the pass runs
    auto count = std::distance(p.begin(), p.end());
    p.compile(eliminate_contiguous_target{});
    EXPECT(std::distance(p.begin(), p.end()) == count - 1);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include <migraphx/gpu/adjust_allocation.hpp>
#include <migraphx/gpu/target.hpp>
#include <migraphx/gpu/lowering.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/auto_contiguous.hpp>
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/transpose.hpp>
#include <migraphx/op/contiguous.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/op/tanh.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
struct lowering_target
{
std::string name() const { return "gpu::lowering"; }
std::vector<migraphx::pass> get_passes(migraphx::context& gctx) const
{
auto& ctx = migraphx::any_cast<migraphx::gpu::context>(gctx);
return {migraphx::auto_contiguous{},
migraphx::gpu::lowering{ctx},
migraphx::dead_code_elimination{},
migraphx::eliminate_contiguous{},
migraphx::dead_code_elimination{}};
}
migraphx::gpu::context get_context() const { return migraphx::gpu::context{}; }
};
// Compiles two identical programs with lowering_target, rewrites the shape of
// every "hip::allocate" in the first one by hand, and expects running
// adjust_allocation (plus dead-code elimination) on the second one to make the
// two programs compare equal again.
TEST_CASE(tanh_shape)
{
    // x -> transpose -> tanh -> add(tanh, tanh) -> contiguous
    auto create_program = [] {
        migraphx::program prog;
        migraphx::shape input_shape{migraphx::shape::float_type, {2, 3}};
        auto input      = prog.add_parameter("x", input_shape);
        auto transposed = prog.add_instruction(migraphx::op::transpose{{1, 0}}, input);
        auto activated  = prog.add_instruction(migraphx::op::tanh{}, transposed);
        auto doubled    = prog.add_instruction(migraphx::op::add{}, activated, activated);
        prog.add_instruction(migraphx::op::contiguous{}, doubled);
        return prog;
    };

    auto p1 = create_program();
    auto p2 = create_program();
    EXPECT(p1 == p2);

    p1.compile(lowering_target{});
    p2.compile(lowering_target());
    EXPECT(p1 == p2);

    for(auto ins : iterator_for(p1))
    {
        if(ins->name() != "hip::allocate")
            continue;

        // replace the allocation's shape, keeping its operator and inputs
        migraphx::shape new_s{migraphx::shape::float_type, {3, 2}, {1, 3}};
        migraphx::instruction::replace(ins, ins->get_operator(), new_s, ins->inputs());
    }
    EXPECT(p1 != p2);

    migraphx::run_passes(p2,
                         {migraphx::gpu::adjust_allocation{}, migraphx::dead_code_elimination{}});
    EXPECT(p1 == p2);
}
// Entry point of the test executable: forwards the command line to test::run.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
}
...@@ -236,8 +236,7 @@ struct test_exp : verify_program<test_exp> ...@@ -236,8 +236,7 @@ struct test_exp : verify_program<test_exp>
{ {
migraphx::program p; migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {6}}; migraphx::shape s{migraphx::shape::float_type, {6}};
std::vector<float> data{0.1f, 0.2f, 1.f, 2.f, 0.6f, 10.f}; auto x = p.add_instruction(migraphx::op::abs{}, p.add_parameter("x", s));
auto x = p.add_literal(s, data);
p.add_instruction(migraphx::op::exp{}, x); p.add_instruction(migraphx::op::exp{}, x);
return p; return p;
} }
...@@ -249,8 +248,7 @@ struct test_log : verify_program<test_log> ...@@ -249,8 +248,7 @@ struct test_log : verify_program<test_log>
{ {
migraphx::program p; migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {6}}; migraphx::shape s{migraphx::shape::float_type, {6}};
std::vector<float> data{0.1f, 0.2f, 1.f, 2.f, 0.6f, 100.f}; auto x = p.add_instruction(migraphx::op::abs{}, p.add_parameter("x", s));
auto x = p.add_literal(s, data);
p.add_instruction(migraphx::op::log{}, x); p.add_instruction(migraphx::op::log{}, x);
return p; return p;
} }
...@@ -327,6 +325,34 @@ struct test_tanh : verify_program<test_tanh> ...@@ -327,6 +325,34 @@ struct test_tanh : verify_program<test_tanh>
} }
}; };
// Verification program: a {4, 3, 3, 3} float parameter is transposed with the
// permutation {0, 1, 3, 2}, passed through tanh, added to itself, and finally
// passed through contiguous.
struct test_trans_tanh : verify_program<test_trans_tanh>
{
    migraphx::program create_program() const
    {
        migraphx::program p;
        migraphx::shape input_shape{migraphx::shape::float_type, {4, 3, 3, 3}};
        auto input      = p.add_parameter("x", input_shape);
        auto transposed = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, input);
        auto activated  = p.add_instruction(migraphx::op::tanh{}, transposed);
        auto doubled    = p.add_instruction(migraphx::op::add{}, activated, activated);
        p.add_instruction(migraphx::op::contiguous{}, doubled);
        return p;
    }
};
// Verification program: a {2, 2} float parameter is sliced with
// op::slice{{1}, {1}, {2}} and the slice is fed directly into sin.
struct test_slice_sin : verify_program<test_slice_sin>
{
    migraphx::program create_program() const
    {
        migraphx::program p;
        migraphx::shape input_shape{migraphx::shape::float_type, {2, 2}};
        auto input  = p.add_parameter("x", input_shape);
        auto sliced = p.add_instruction(migraphx::op::slice{{1}, {1}, {2}}, input);
        p.add_instruction(migraphx::op::sin{}, sliced);
        return p;
    }
};
struct test_asin : verify_program<test_asin> struct test_asin : verify_program<test_asin>
{ {
migraphx::program create_program() const migraphx::program create_program() const
...@@ -371,7 +397,7 @@ struct test_scale : verify_program<test_scale> ...@@ -371,7 +397,7 @@ struct test_scale : verify_program<test_scale>
migraphx::shape s{migraphx::shape::float_type, {3}}; migraphx::shape s{migraphx::shape::float_type, {3}};
auto x = p.add_parameter("x", s); auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", migraphx::shape::float_type); auto y = p.add_parameter("y", migraphx::shape::float_type);
auto scale = p.add_instruction(migraphx::op::scalar{s}, y); auto scale = p.add_instruction(migraphx::op::scalar{s.lens()}, y);
p.add_instruction(migraphx::op::mul{}, x, scale); p.add_instruction(migraphx::op::mul{}, x, scale);
return p; return p;
} }
...@@ -417,7 +443,7 @@ struct test_triadd2 : verify_program<test_triadd2> ...@@ -417,7 +443,7 @@ struct test_triadd2 : verify_program<test_triadd2>
auto x = p.add_parameter("x", s); auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", s); auto y = p.add_parameter("y", s);
auto z = p.add_parameter("z", b); auto z = p.add_parameter("z", b);
auto zb = p.add_instruction(migraphx::op::broadcast{1, s}, z); auto zb = p.add_instruction(migraphx::op::broadcast{1, s.lens()}, z);
auto sum = p.add_instruction(migraphx::op::add{}, x, y); auto sum = p.add_instruction(migraphx::op::add{}, x, y);
p.add_instruction(migraphx::op::add{}, sum, zb); p.add_instruction(migraphx::op::add{}, sum, zb);
return p; return p;
...@@ -432,7 +458,7 @@ struct test_add_broadcast : verify_program<test_add_broadcast> ...@@ -432,7 +458,7 @@ struct test_add_broadcast : verify_program<test_add_broadcast>
migraphx::shape s{migraphx::shape::float_type, {3}}; migraphx::shape s{migraphx::shape::float_type, {3}};
auto x = p.add_parameter("x", {migraphx::shape::float_type, {2, 2, 3}}); auto x = p.add_parameter("x", {migraphx::shape::float_type, {2, 2, 3}});
auto y = p.add_parameter("y", {migraphx::shape::float_type, {2, 2}}); auto y = p.add_parameter("y", {migraphx::shape::float_type, {2, 2}});
auto by = p.add_instruction(migraphx::op::broadcast{0, x->get_shape()}, y); auto by = p.add_instruction(migraphx::op::broadcast{0, x->get_shape().lens()}, y);
p.add_instruction(migraphx::op::add{}, x, by); p.add_instruction(migraphx::op::add{}, x, by);
return p; return p;
} }
...@@ -446,7 +472,7 @@ struct test_add_broadcast2 : verify_program<test_add_broadcast2> ...@@ -446,7 +472,7 @@ struct test_add_broadcast2 : verify_program<test_add_broadcast2>
migraphx::shape s{migraphx::shape::float_type, {3}}; migraphx::shape s{migraphx::shape::float_type, {3}};
auto x = p.add_parameter("x", {migraphx::shape::float_type, {2, 3, 4}}); auto x = p.add_parameter("x", {migraphx::shape::float_type, {2, 3, 4}});
auto y = p.add_parameter("y", {migraphx::shape::float_type, {3}}); auto y = p.add_parameter("y", {migraphx::shape::float_type, {3}});
auto by = p.add_instruction(migraphx::op::broadcast{1, x->get_shape()}, y); auto by = p.add_instruction(migraphx::op::broadcast{1, x->get_shape().lens()}, y);
p.add_instruction(migraphx::op::add{}, x, by); p.add_instruction(migraphx::op::add{}, x, by);
return p; return p;
} }
...@@ -460,7 +486,7 @@ struct test_add_broadcast3 : verify_program<test_add_broadcast3> ...@@ -460,7 +486,7 @@ struct test_add_broadcast3 : verify_program<test_add_broadcast3>
migraphx::shape s{migraphx::shape::float_type, {3}}; migraphx::shape s{migraphx::shape::float_type, {3}};
auto x = p.add_parameter("x", {migraphx::shape::float_type, {2, 4, 5}}); auto x = p.add_parameter("x", {migraphx::shape::float_type, {2, 4, 5}});
auto y = p.add_parameter("y", {migraphx::shape::float_type, {4}}); auto y = p.add_parameter("y", {migraphx::shape::float_type, {4}});
auto by = p.add_instruction(migraphx::op::broadcast{1, x->get_shape()}, y); auto by = p.add_instruction(migraphx::op::broadcast{1, x->get_shape().lens()}, y);
p.add_instruction(migraphx::op::add{}, x, by); p.add_instruction(migraphx::op::add{}, x, by);
return p; return p;
} }
...@@ -474,7 +500,7 @@ struct test_add_broadcast4 : verify_program<test_add_broadcast4> ...@@ -474,7 +500,7 @@ struct test_add_broadcast4 : verify_program<test_add_broadcast4>
migraphx::shape s{migraphx::shape::float_type, {3}}; migraphx::shape s{migraphx::shape::float_type, {3}};
auto x = p.add_parameter("x", {migraphx::shape::float_type, {2, 3, 5}}); auto x = p.add_parameter("x", {migraphx::shape::float_type, {2, 3, 5}});
auto y = p.add_parameter("y", {migraphx::shape::float_type, {3}}); auto y = p.add_parameter("y", {migraphx::shape::float_type, {3}});
auto by = p.add_instruction(migraphx::op::broadcast{1, x->get_shape()}, y); auto by = p.add_instruction(migraphx::op::broadcast{1, x->get_shape().lens()}, y);
p.add_instruction(migraphx::op::add{}, x, by); p.add_instruction(migraphx::op::add{}, x, by);
return p; return p;
} }
...@@ -488,7 +514,7 @@ struct test_add_broadcast5 : verify_program<test_add_broadcast5> ...@@ -488,7 +514,7 @@ struct test_add_broadcast5 : verify_program<test_add_broadcast5>
migraphx::shape s{migraphx::shape::float_type, {3}}; migraphx::shape s{migraphx::shape::float_type, {3}};
auto x = p.add_parameter("x", {migraphx::shape::float_type, {2, 4, 8}}); auto x = p.add_parameter("x", {migraphx::shape::float_type, {2, 4, 8}});
auto y = p.add_parameter("y", {migraphx::shape::float_type, {4}}); auto y = p.add_parameter("y", {migraphx::shape::float_type, {4}});
auto by = p.add_instruction(migraphx::op::broadcast{1, x->get_shape()}, y); auto by = p.add_instruction(migraphx::op::broadcast{1, x->get_shape().lens()}, y);
p.add_instruction(migraphx::op::add{}, x, by); p.add_instruction(migraphx::op::add{}, x, by);
return p; return p;
} }
...@@ -503,7 +529,7 @@ struct test_triadd_broadcast : verify_program<test_triadd_broadcast> ...@@ -503,7 +529,7 @@ struct test_triadd_broadcast : verify_program<test_triadd_broadcast>
auto x = p.add_parameter("x", {migraphx::shape::float_type, {2, 2, 3}}); auto x = p.add_parameter("x", {migraphx::shape::float_type, {2, 2, 3}});
auto y = p.add_parameter("y", {migraphx::shape::float_type, {2, 2}}); auto y = p.add_parameter("y", {migraphx::shape::float_type, {2, 2}});
auto z = p.add_parameter("z", {migraphx::shape::float_type, {2, 2, 3}}); auto z = p.add_parameter("z", {migraphx::shape::float_type, {2, 2, 3}});
auto by = p.add_instruction(migraphx::op::broadcast{0, x->get_shape()}, y); auto by = p.add_instruction(migraphx::op::broadcast{0, x->get_shape().lens()}, y);
auto sum = p.add_instruction(migraphx::op::add{}, x, by); auto sum = p.add_instruction(migraphx::op::add{}, x, by);
p.add_instruction(migraphx::op::add{}, sum, z); p.add_instruction(migraphx::op::add{}, sum, z);
return p; return p;
...@@ -535,7 +561,7 @@ struct test_sub2 : verify_program<test_sub2> ...@@ -535,7 +561,7 @@ struct test_sub2 : verify_program<test_sub2>
auto x = p.add_parameter("x", s); auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", s); auto y = p.add_parameter("y", s);
auto z = p.add_parameter("z", b); auto z = p.add_parameter("z", b);
auto zb = p.add_instruction(migraphx::op::broadcast{1, s}, z); auto zb = p.add_instruction(migraphx::op::broadcast{1, s.lens()}, z);
auto diff = p.add_instruction(migraphx::op::sub{}, x, y); auto diff = p.add_instruction(migraphx::op::sub{}, x, y);
p.add_instruction(migraphx::op::sub{}, diff, zb); p.add_instruction(migraphx::op::sub{}, diff, zb);
return p; return p;
...@@ -674,6 +700,21 @@ struct test_abs : verify_program<test_abs> ...@@ -674,6 +700,21 @@ struct test_abs : verify_program<test_abs>
} }
}; };
// Verification program: abs applied to a transposed parameter, the result
// added to itself, and the sum passed through contiguous.
struct test_trans_abs : verify_program<test_trans_abs>
{
    // Builds the program instruction by instruction; insertion order below
    // is the order of the original graph.
    migraphx::program create_program() const
    {
        migraphx::program prog;
        // 4-D float parameter named "x".
        const migraphx::shape input_shape{migraphx::shape::float_type, {4, 3, 3, 3}};
        auto input = prog.add_parameter("x", input_shape);
        // Permutation {0, 1, 3, 2}: the last two axes are exchanged.
        auto transposed = prog.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, input);
        auto magnitude  = prog.add_instruction(migraphx::op::abs{}, transposed);
        // Both operands of the add are the same abs result.
        auto doubled = prog.add_instruction(migraphx::op::add{}, magnitude, magnitude);
        prog.add_instruction(migraphx::op::contiguous{}, doubled);
        return prog;
    }
};
struct test_leaky_relu : verify_program<test_leaky_relu> struct test_leaky_relu : verify_program<test_leaky_relu>
{ {
migraphx::program create_program() const migraphx::program create_program() const
......
...@@ -154,7 +154,7 @@ TEST_CASE(rnn_test_one_direction) ...@@ -154,7 +154,7 @@ TEST_CASE(rnn_test_one_direction)
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::rnn{hs, p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}}, {migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::reverse, migraphx::op::rnn_direction::forward,
clip}, clip},
seq, seq,
w, w,
...@@ -339,7 +339,7 @@ TEST_CASE(gru_test_args) ...@@ -339,7 +339,7 @@ TEST_CASE(gru_test_args)
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::gru{hs, p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}}, {migraphx::op::relu{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::reverse, migraphx::op::rnn_direction::reverse,
clip}, clip},
seq, seq,
...@@ -373,7 +373,10 @@ TEST_CASE(gru_test_args) ...@@ -373,7 +373,10 @@ TEST_CASE(gru_test_args)
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::gru{hs, p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}}, {migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::relu{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip}, clip},
seq, seq,
...@@ -414,14 +417,20 @@ TEST_CASE(gru_test_actv_funcs) ...@@ -414,14 +417,20 @@ TEST_CASE(gru_test_actv_funcs)
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction( auto out_hs =
migraphx::op::gru{hs, {}, migraphx::op::rnn_direction::bidirectional, clip}, p.add_instruction(migraphx::op::gru{hs,
seq, {migraphx::op::sigmoid{},
w, migraphx::op::tanh{},
r, migraphx::op::sigmoid{},
bias, migraphx::op::tanh{}},
seq_len, migraphx::op::rnn_direction::bidirectional,
ih); clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_0.onnx"); auto prog = migraphx::parse_onnx("onnx_gru_bi_0.onnx");
...@@ -445,15 +454,20 @@ TEST_CASE(gru_test_actv_funcs) ...@@ -445,15 +454,20 @@ TEST_CASE(gru_test_actv_funcs)
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction( auto out_hs =
migraphx::op::gru{ p.add_instruction(migraphx::op::gru{hs,
hs, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip}, {migraphx::op::sigmoid{},
seq, migraphx::op::sigmoid{},
w, migraphx::op::sigmoid{},
r, migraphx::op::sigmoid{}},
bias, migraphx::op::rnn_direction::bidirectional,
seq_len, clip},
ih); seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_1.onnx"); auto prog = migraphx::parse_onnx("onnx_gru_bi_1.onnx");
...@@ -479,7 +493,10 @@ TEST_CASE(gru_test_actv_funcs) ...@@ -479,7 +493,10 @@ TEST_CASE(gru_test_actv_funcs)
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::gru{hs, p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}}, {migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip}, clip},
seq, seq,
...@@ -511,17 +528,20 @@ TEST_CASE(gru_test_actv_funcs) ...@@ -511,17 +528,20 @@ TEST_CASE(gru_test_actv_funcs)
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction( auto out_hs =
migraphx::op::gru{hs, p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::tanh{},
migraphx::op::rnn_direction::bidirectional, migraphx::op::sigmoid{},
clip}, migraphx::op::tanh{},
seq, migraphx::op::tanh{}},
w, migraphx::op::rnn_direction::bidirectional,
r, clip},
bias, seq,
seq_len, w,
ih); r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_3.onnx"); auto prog = migraphx::parse_onnx("onnx_gru_bi_3.onnx");
...@@ -546,7 +566,10 @@ TEST_CASE(gru_test_actv_funcs) ...@@ -546,7 +566,10 @@ TEST_CASE(gru_test_actv_funcs)
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::gru{hs, {}, migraphx::op::rnn_direction::forward, clip}, p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
seq, seq,
w, w,
r, r,
...@@ -576,15 +599,17 @@ TEST_CASE(gru_test_actv_funcs) ...@@ -576,15 +599,17 @@ TEST_CASE(gru_test_actv_funcs)
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction( auto out_hs =
migraphx::op::gru{ p.add_instruction(migraphx::op::gru{hs,
hs, {migraphx::op::relu{}}, migraphx::op::rnn_direction::reverse, clip}, {migraphx::op::relu{}, migraphx::op::relu{}},
seq, migraphx::op::rnn_direction::reverse,
w, clip},
r, seq,
bias, w,
seq_len, r,
ih); bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_reverse_1.onnx"); auto prog = migraphx::parse_onnx("onnx_gru_reverse_1.onnx");
...@@ -826,7 +851,12 @@ TEST_CASE(lstm_forward_actv_func) ...@@ -826,7 +851,12 @@ TEST_CASE(lstm_forward_actv_func)
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction( auto out_hs = p.add_instruction(
migraphx::op::lstm{hs, {}, migraphx::op::rnn_direction::forward, clip, input_forget}, migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip,
input_forget},
seq, seq,
w, w,
r, r,
...@@ -851,19 +881,21 @@ TEST_CASE(lstm_forward_actv_func) ...@@ -851,19 +881,21 @@ TEST_CASE(lstm_forward_actv_func)
auto bias = p.add_parameter("bias", bias_shape); auto bias = p.add_parameter("bias", bias_shape);
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(migraphx::op::lstm{hs, auto out_hs = p.add_instruction(
{migraphx::op::sigmoid{}}, migraphx::op::lstm{
migraphx::op::rnn_direction::forward, hs,
clip, {migraphx::op::sigmoid{}, migraphx::op::sigmoid{}, migraphx::op::sigmoid{}},
input_forget}, migraphx::op::rnn_direction::forward,
seq, clip,
w, input_forget},
r, seq,
bias, w,
und, r,
und, bias,
und, und,
und); und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_f1af.onnx"); auto prog = migraphx::parse_onnx("onnx_lstm_f1af.onnx");
...@@ -881,20 +913,21 @@ TEST_CASE(lstm_forward_actv_func) ...@@ -881,20 +913,21 @@ TEST_CASE(lstm_forward_actv_func)
auto seq_len = p.add_parameter("seq_len", sl_shape); auto seq_len = p.add_parameter("seq_len", sl_shape);
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = auto out_hs = p.add_instruction(
p.add_instruction(migraphx::op::lstm{hs, migraphx::op::lstm{
{migraphx::op::tanh{}, migraphx::op::sigmoid{}}, hs,
migraphx::op::rnn_direction::forward, {migraphx::op::tanh{}, migraphx::op::sigmoid{}, migraphx::op::sigmoid{}},
clip, migraphx::op::rnn_direction::forward,
input_forget}, clip,
seq, input_forget},
w, seq,
r, w,
bias, r,
seq_len, bias,
und, seq_len,
und, und,
und); und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_f2af.onnx"); auto prog = migraphx::parse_onnx("onnx_lstm_f2af.onnx");
...@@ -993,7 +1026,12 @@ TEST_CASE(lstm_reverse) ...@@ -993,7 +1026,12 @@ TEST_CASE(lstm_reverse)
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction( auto out_hs = p.add_instruction(
migraphx::op::lstm{hs, {}, migraphx::op::rnn_direction::forward, clip, input_forget}, migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::reverse,
clip,
input_forget},
seq, seq,
w, w,
r, r,
...@@ -1037,21 +1075,25 @@ TEST_CASE(lstm_bidirectional) ...@@ -1037,21 +1075,25 @@ TEST_CASE(lstm_bidirectional)
auto ic = p.add_parameter("c0", ih_shape); auto ic = p.add_parameter("c0", ih_shape);
auto pph = p.add_parameter("pph", pph_shape); auto pph = p.add_parameter("pph", pph_shape);
auto out_hs = p.add_instruction( auto out_hs =
migraphx::op::lstm{ p.add_instruction(migraphx::op::lstm{hs,
hs, {migraphx::op::sigmoid{},
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, migraphx::op::tanh{},
migraphx::op::rnn_direction::bidirectional, migraphx::op::tanh{},
clip, migraphx::op::sigmoid{},
input_forget}, migraphx::op::tanh{},
seq, migraphx::op::tanh{}},
w, migraphx::op::rnn_direction::bidirectional,
r, clip,
bias, input_forget},
seq_len, seq,
ih, w,
ic, r,
pph); bias,
seq_len,
ih,
ic,
pph);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_bi.onnx"); auto prog = migraphx::parse_onnx("onnx_lstm_bi.onnx");
...@@ -1067,21 +1109,25 @@ TEST_CASE(lstm_bidirectional) ...@@ -1067,21 +1109,25 @@ TEST_CASE(lstm_bidirectional)
auto r = p.add_parameter("r", r_shape); auto r = p.add_parameter("r", r_shape);
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction( auto out_hs =
migraphx::op::lstm{ p.add_instruction(migraphx::op::lstm{hs,
hs, {migraphx::op::sigmoid{},
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, migraphx::op::tanh{},
migraphx::op::rnn_direction::bidirectional, migraphx::op::tanh{},
clip, migraphx::op::sigmoid{},
input_forget}, migraphx::op::tanh{},
seq, migraphx::op::tanh{}},
w, migraphx::op::rnn_direction::bidirectional,
r, clip,
und, input_forget},
und, seq,
und, w,
und, r,
und); und,
und,
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_bi3args.onnx"); auto prog = migraphx::parse_onnx("onnx_lstm_bi3args.onnx");
...@@ -1098,21 +1144,25 @@ TEST_CASE(lstm_bidirectional) ...@@ -1098,21 +1144,25 @@ TEST_CASE(lstm_bidirectional)
auto bias = p.add_parameter("bias", bias_shape); auto bias = p.add_parameter("bias", bias_shape);
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction( auto out_hs =
migraphx::op::lstm{ p.add_instruction(migraphx::op::lstm{hs,
hs, {migraphx::op::sigmoid{},
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, migraphx::op::tanh{},
migraphx::op::rnn_direction::bidirectional, migraphx::op::tanh{},
clip, migraphx::op::sigmoid{},
input_forget}, migraphx::op::tanh{},
seq, migraphx::op::tanh{}},
w, migraphx::op::rnn_direction::bidirectional,
r, clip,
bias, input_forget},
und, seq,
und, w,
und, r,
und); bias,
und,
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_bi4args.onnx"); auto prog = migraphx::parse_onnx("onnx_lstm_bi4args.onnx");
...@@ -1130,21 +1180,25 @@ TEST_CASE(lstm_bidirectional) ...@@ -1130,21 +1180,25 @@ TEST_CASE(lstm_bidirectional)
auto seq_len = p.add_parameter("seq_len", sl_shape); auto seq_len = p.add_parameter("seq_len", sl_shape);
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction( auto out_hs =
migraphx::op::lstm{ p.add_instruction(migraphx::op::lstm{hs,
hs, {migraphx::op::sigmoid{},
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, migraphx::op::tanh{},
migraphx::op::rnn_direction::bidirectional, migraphx::op::tanh{},
clip, migraphx::op::sigmoid{},
input_forget}, migraphx::op::tanh{},
seq, migraphx::op::tanh{}},
w, migraphx::op::rnn_direction::bidirectional,
r, clip,
bias, input_forget},
seq_len, seq,
und, w,
und, r,
und); bias,
seq_len,
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_bi5args.onnx"); auto prog = migraphx::parse_onnx("onnx_lstm_bi5args.onnx");
...@@ -1163,21 +1217,25 @@ TEST_CASE(lstm_bidirectional) ...@@ -1163,21 +1217,25 @@ TEST_CASE(lstm_bidirectional)
auto ih = p.add_parameter("h0", ih_shape); auto ih = p.add_parameter("h0", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction( auto out_hs =
migraphx::op::lstm{ p.add_instruction(migraphx::op::lstm{hs,
hs, {migraphx::op::sigmoid{},
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, migraphx::op::tanh{},
migraphx::op::rnn_direction::bidirectional, migraphx::op::tanh{},
clip, migraphx::op::sigmoid{},
input_forget}, migraphx::op::tanh{},
seq, migraphx::op::tanh{}},
w, migraphx::op::rnn_direction::bidirectional,
r, clip,
bias, input_forget},
seq_len, seq,
ih, w,
und, r,
und); bias,
seq_len,
ih,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_bi6args.onnx"); auto prog = migraphx::parse_onnx("onnx_lstm_bi6args.onnx");
...@@ -1197,21 +1255,25 @@ TEST_CASE(lstm_bidirectional) ...@@ -1197,21 +1255,25 @@ TEST_CASE(lstm_bidirectional)
auto ic = p.add_parameter("c0", ih_shape); auto ic = p.add_parameter("c0", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction( auto out_hs =
migraphx::op::lstm{ p.add_instruction(migraphx::op::lstm{hs,
hs, {migraphx::op::sigmoid{},
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, migraphx::op::tanh{},
migraphx::op::rnn_direction::bidirectional, migraphx::op::tanh{},
clip, migraphx::op::sigmoid{},
input_forget}, migraphx::op::tanh{},
seq, migraphx::op::tanh{}},
w, migraphx::op::rnn_direction::bidirectional,
r, clip,
bias, input_forget},
seq_len, seq,
ih, w,
ic, r,
und); bias,
seq_len,
ih,
ic,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_bi7args.onnx"); auto prog = migraphx::parse_onnx("onnx_lstm_bi7args.onnx");
...@@ -1244,17 +1306,25 @@ TEST_CASE(lstm_bi_actv_funcs) ...@@ -1244,17 +1306,25 @@ TEST_CASE(lstm_bi_actv_funcs)
auto r = p.add_parameter("r", r_shape); auto r = p.add_parameter("r", r_shape);
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction( auto out_hs =
migraphx::op::lstm{ p.add_instruction(migraphx::op::lstm{hs,
hs, {}, migraphx::op::rnn_direction::bidirectional, clip, input_forget}, {migraphx::op::sigmoid{},
seq, migraphx::op::tanh{},
w, migraphx::op::tanh{},
r, migraphx::op::sigmoid{},
und, migraphx::op::tanh{},
und, migraphx::op::tanh{}},
und, migraphx::op::rnn_direction::bidirectional,
und, clip,
und); input_forget},
seq,
w,
r,
und,
und,
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_bi0af.onnx"); auto prog = migraphx::parse_onnx("onnx_lstm_bi0af.onnx");
...@@ -1273,7 +1343,12 @@ TEST_CASE(lstm_bi_actv_funcs) ...@@ -1273,7 +1343,12 @@ TEST_CASE(lstm_bi_actv_funcs)
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::lstm{hs, p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{}}, {migraphx::op::sigmoid{},
migraphx::op::sigmoid{},
migraphx::op::sigmoid{},
migraphx::op::sigmoid{},
migraphx::op::sigmoid{},
migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip, clip,
input_forget}, input_forget},
...@@ -1304,7 +1379,12 @@ TEST_CASE(lstm_bi_actv_funcs) ...@@ -1304,7 +1379,12 @@ TEST_CASE(lstm_bi_actv_funcs)
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::lstm{hs, p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip, clip,
input_forget}, input_forget},
...@@ -1337,6 +1417,8 @@ TEST_CASE(lstm_bi_actv_funcs) ...@@ -1337,6 +1417,8 @@ TEST_CASE(lstm_bi_actv_funcs)
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::lstm{hs, p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{}, {migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::tanh{}, migraphx::op::tanh{},
migraphx::op::tanh{}, migraphx::op::tanh{},
migraphx::op::tanh{}}, migraphx::op::tanh{}},
...@@ -1376,6 +1458,7 @@ TEST_CASE(lstm_bi_actv_funcs) ...@@ -1376,6 +1458,7 @@ TEST_CASE(lstm_bi_actv_funcs)
migraphx::op::tanh{}, migraphx::op::tanh{},
migraphx::op::tanh{}, migraphx::op::tanh{},
migraphx::op::tanh{}, migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::sigmoid{}}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip, clip,
......
...@@ -15,7 +15,7 @@ TEST_CASE(pytorch_conv_bias_test) ...@@ -15,7 +15,7 @@ TEST_CASE(pytorch_conv_bias_test)
auto l2 = p.add_parameter("2", {migraphx::shape::float_type, {1}}); auto l2 = p.add_parameter("2", {migraphx::shape::float_type, {1}});
uint64_t axis = 1; uint64_t axis = 1;
auto l3 = p.add_instruction(migraphx::op::convolution{}, l0, l1); auto l3 = p.add_instruction(migraphx::op::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape()}, l2); auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape().lens()}, l2);
p.add_instruction(migraphx::op::add{}, l3, l4); p.add_instruction(migraphx::op::add{}, l3, l4);
auto prog = migraphx::parse_onnx("conv.onnx"); auto prog = migraphx::parse_onnx("conv.onnx");
...@@ -30,7 +30,7 @@ TEST_CASE(pytorch_conv_relu_maxpool) ...@@ -30,7 +30,7 @@ TEST_CASE(pytorch_conv_relu_maxpool)
auto l2 = p.add_parameter("2", {migraphx::shape::float_type, {1}}); auto l2 = p.add_parameter("2", {migraphx::shape::float_type, {1}});
uint64_t axis = 1; uint64_t axis = 1;
auto l3 = p.add_instruction(migraphx::op::convolution{}, l0, l1); auto l3 = p.add_instruction(migraphx::op::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape()}, l2); auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape().lens()}, l2);
auto l5 = p.add_instruction(migraphx::op::add{}, l3, l4); auto l5 = p.add_instruction(migraphx::op::add{}, l3, l4);
auto l6 = p.add_instruction(migraphx::op::relu{}, l5); auto l6 = p.add_instruction(migraphx::op::relu{}, l5);
p.add_instruction(migraphx::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l6); p.add_instruction(migraphx::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l6);
...@@ -52,7 +52,7 @@ TEST_CASE(pytorch_conv_bn_relu_maxpool) ...@@ -52,7 +52,7 @@ TEST_CASE(pytorch_conv_bn_relu_maxpool)
auto p6 = p.add_parameter("6", {migraphx::shape::float_type, {1}}); auto p6 = p.add_parameter("6", {migraphx::shape::float_type, {1}});
uint64_t axis = 1; uint64_t axis = 1;
auto l3 = p.add_instruction(migraphx::op::convolution{}, l0, l1); auto l3 = p.add_instruction(migraphx::op::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape()}, l2); auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape().lens()}, l2);
auto l5 = p.add_instruction(migraphx::op::add{}, l3, l4); auto l5 = p.add_instruction(migraphx::op::add{}, l3, l4);
auto l6 = p.add_instruction(migraphx::op::batch_norm_inference{1.0e-5f}, l5, p3, p4, p5, p6); auto l6 = p.add_instruction(migraphx::op::batch_norm_inference{1.0e-5f}, l5, p3, p4, p5, p6);
auto l7 = p.add_instruction(migraphx::op::relu{}, l6); auto l7 = p.add_instruction(migraphx::op::relu{}, l6);
...@@ -70,7 +70,7 @@ TEST_CASE(pytorch_conv_relu_maxpool_x2) ...@@ -70,7 +70,7 @@ TEST_CASE(pytorch_conv_relu_maxpool_x2)
auto l2 = p.add_parameter("2", {migraphx::shape::float_type, {5}}); auto l2 = p.add_parameter("2", {migraphx::shape::float_type, {5}});
uint64_t axis = 1; uint64_t axis = 1;
auto l3 = p.add_instruction(migraphx::op::convolution{}, l0, l1); auto l3 = p.add_instruction(migraphx::op::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape()}, l2); auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape().lens()}, l2);
auto l5 = p.add_instruction(migraphx::op::add{}, l3, l4); auto l5 = p.add_instruction(migraphx::op::add{}, l3, l4);
auto l6 = p.add_instruction(migraphx::op::relu{}, l5); auto l6 = p.add_instruction(migraphx::op::relu{}, l5);
auto l7 = p.add_instruction(migraphx::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l6); auto l7 = p.add_instruction(migraphx::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l6);
...@@ -78,7 +78,7 @@ TEST_CASE(pytorch_conv_relu_maxpool_x2) ...@@ -78,7 +78,7 @@ TEST_CASE(pytorch_conv_relu_maxpool_x2)
auto l8 = p.add_parameter("3", {migraphx::shape::float_type, {1, 5, 5, 5}}); auto l8 = p.add_parameter("3", {migraphx::shape::float_type, {1, 5, 5, 5}});
auto l9 = p.add_parameter("4", {migraphx::shape::float_type, {1}}); auto l9 = p.add_parameter("4", {migraphx::shape::float_type, {1}});
auto l10 = p.add_instruction(migraphx::op::convolution{}, l7, l8); auto l10 = p.add_instruction(migraphx::op::convolution{}, l7, l8);
auto l11 = p.add_instruction(migraphx::op::broadcast{axis, l10->get_shape()}, l9); auto l11 = p.add_instruction(migraphx::op::broadcast{axis, l10->get_shape().lens()}, l9);
auto l12 = p.add_instruction(migraphx::op::add{}, l10, l11); auto l12 = p.add_instruction(migraphx::op::add{}, l10, l11);
auto l13 = p.add_instruction(migraphx::op::relu{}, l12); auto l13 = p.add_instruction(migraphx::op::relu{}, l12);
p.add_instruction(migraphx::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l13); p.add_instruction(migraphx::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l13);
...@@ -108,9 +108,9 @@ TEST_CASE(imagescaler_test) ...@@ -108,9 +108,9 @@ TEST_CASE(imagescaler_test)
auto scale_val = p.add_literal(0.5f); auto scale_val = p.add_literal(0.5f);
auto bias_vals = p.add_literal( auto bias_vals = p.add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::float_type, {3}}, {0.01, 0.02, 0.03}}); migraphx::literal{migraphx::shape{migraphx::shape::float_type, {3}}, {0.01, 0.02, 0.03}});
auto scaled_tensor = p.add_instruction(migraphx::op::scalar{s}, scale_val); auto scaled_tensor = p.add_instruction(migraphx::op::scalar{s.lens()}, scale_val);
auto img_scaled = p.add_instruction(migraphx::op::mul{}, l0, scaled_tensor); auto img_scaled = p.add_instruction(migraphx::op::mul{}, l0, scaled_tensor);
auto bias_bcast = p.add_instruction(migraphx::op::broadcast{1, s}, bias_vals); auto bias_bcast = p.add_instruction(migraphx::op::broadcast{1, s.lens()}, bias_vals);
p.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast); p.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast);
auto prog = migraphx::parse_onnx("imagescaler_test.onnx"); auto prog = migraphx::parse_onnx("imagescaler_test.onnx");
...@@ -338,7 +338,7 @@ TEST_CASE(add_bcast_test) ...@@ -338,7 +338,7 @@ TEST_CASE(add_bcast_test)
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}}); auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}});
auto l2 = p.add_instruction(migraphx::op::broadcast{1, l0->get_shape()}, l1); auto l2 = p.add_instruction(migraphx::op::broadcast{1, l0->get_shape().lens()}, l1);
p.add_instruction(migraphx::op::add{}, l0, l2); p.add_instruction(migraphx::op::add{}, l0, l2);
auto prog = migraphx::parse_onnx("add_bcast_test.onnx"); auto prog = migraphx::parse_onnx("add_bcast_test.onnx");
...@@ -365,7 +365,7 @@ TEST_CASE(sub_bcast_test) ...@@ -365,7 +365,7 @@ TEST_CASE(sub_bcast_test)
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}}); auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}});
auto l2 = p.add_instruction(migraphx::op::broadcast{1, l0->get_shape()}, l1); auto l2 = p.add_instruction(migraphx::op::broadcast{1, l0->get_shape().lens()}, l1);
p.add_instruction(migraphx::op::sub{}, l0, l2); p.add_instruction(migraphx::op::sub{}, l0, l2);
auto prog = migraphx::parse_onnx("sub_bcast_test.onnx"); auto prog = migraphx::parse_onnx("sub_bcast_test.onnx");
...@@ -699,8 +699,7 @@ TEST_CASE(add_scalar_test) ...@@ -699,8 +699,7 @@ TEST_CASE(add_scalar_test)
{ {
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = auto l1 = p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {1}});
p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type, {1}}, {1}});
auto m0 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l0); auto m0 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l0);
auto m1 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1); auto m1 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
p.add_instruction(migraphx::op::add{}, m0, m1); p.add_instruction(migraphx::op::add{}, m0, m1);
......
...@@ -229,6 +229,36 @@ TEST_CASE(multibroadcast) ...@@ -229,6 +229,36 @@ TEST_CASE(multibroadcast)
} }
} }
// Shape checks for migraphx::op::broadcast: two configurations that must
// produce a specific output shape and two that must be rejected.
TEST_CASE(broadcast)
{
    // First argument 0, lens {1, 1}: expects lens {1, 1} with strides {0, 0}.
    {
        const migraphx::shape in_shape{migraphx::shape::float_type, {4, 1, 3}};
        const std::vector<std::size_t> out_lens{1, 1};
        const migraphx::shape expected{migraphx::shape::float_type, {1, 1}, {0, 0}};
        expect_shape(expected, migraphx::op::broadcast{0, out_lens}, in_shape);
    }

    // Same input and lens, but first argument 1: shape computation must throw.
    {
        const migraphx::shape in_shape{migraphx::shape::float_type, {4, 1, 3}};
        const std::vector<std::size_t> out_lens{1, 1};
        throws_shape(migraphx::op::broadcast{1, out_lens}, in_shape);
    }

    // {4, 3} input, first argument 2, lens {3, 2, 4, 3}: expects strides {0, 0, 3, 1}.
    {
        const migraphx::shape in_shape{migraphx::shape::float_type, {4, 3}};
        const std::vector<std::size_t> out_lens{3, 2, 4, 3};
        const migraphx::shape expected{
            migraphx::shape::float_type, {3, 2, 4, 3}, {0, 0, 3, 1}};
        expect_shape(expected, migraphx::op::broadcast{2, out_lens}, in_shape);
    }

    // {4, 4} input with the same op configuration: shape computation must throw.
    {
        const migraphx::shape in_shape{migraphx::shape::float_type, {4, 4}};
        const std::vector<std::size_t> out_lens{3, 2, 4, 3};
        throws_shape(migraphx::op::broadcast{2, out_lens}, in_shape);
    }
}
TEST_CASE(gather) TEST_CASE(gather)
{ {
{ {
......
#include <migraphx/program.hpp>
#include <migraphx/ranges.hpp>
#include <sstream>
#include "test.hpp"
#include <basic_ops.hpp>
// Builds the small program used by the tests: sum_op(sum_op(x, y), 1), where
// "x" and "y" are int64 parameters and 1 is a literal.
// The insertion order below is kept deliberately: parameters, first sum,
// literal, second sum.
migraphx::program create_program()
{
    migraphx::program prog;

    auto lhs = prog.add_parameter("x", {migraphx::shape::int64_type});
    auto rhs = prog.add_parameter("y", {migraphx::shape::int64_type});

    auto partial = prog.add_instruction(sum_op{}, lhs, rhs);
    auto unit    = prog.add_literal(1);
    prog.add_instruction(sum_op{}, partial, unit);

    return prog;
}
// Prints the program from create_program() with print_graph and checks that
// the emitted text contains the expected header, node, edge and label fragments.
TEST_CASE(basic_graph_test)
{
    migraphx::program prog = create_program();

    std::stringstream out;
    prog.print_graph(out);
    const std::string dot = out.str();

    // Graph header.
    EXPECT(migraphx::contains(dot, "digraph"));
    EXPECT(migraphx::contains(dot, "rankdir=LR"));

    // Node declarations.
    EXPECT(migraphx::contains(dot, "\"@0\"[label=\"@literal\"]"));
    EXPECT(migraphx::contains(dot, "\"y\"[label=\"@param:y\"]"));
    EXPECT(migraphx::contains(dot, "\"x\"[label=\"@param:x\"]"));
    EXPECT(migraphx::contains(dot, "\"@3\"[label=\"sum\"]"));
    EXPECT(migraphx::contains(dot, "\"@4\"[label=\"sum\"]"));

    // Edges between nodes.
    EXPECT(migraphx::contains(dot, "\"x\" -> \"@3\""));
    EXPECT(migraphx::contains(dot, "\"y\" -> \"@3\""));
    EXPECT(migraphx::contains(dot, "\"@3\" -> \"@4\""));
    EXPECT(migraphx::contains(dot, "\"@0\" -> \"@4\""));

    // Shape label text.
    EXPECT(migraphx::contains(dot, "[label=\"int64_type, {1}, {0}\"]"));
}
// Test binary entry point: passes the command-line arguments to test::run.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
    return 0;
}
...@@ -2,6 +2,10 @@ ...@@ -2,6 +2,10 @@
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/cpu/target.hpp>
#include <sstream> #include <sstream>
#include "test.hpp" #include "test.hpp"
#include <basic_ops.hpp> #include <basic_ops.hpp>
...@@ -27,4 +31,78 @@ TEST_CASE(program_equality) ...@@ -27,4 +31,78 @@ TEST_CASE(program_equality)
EXPECT(x == y); EXPECT(x == y);
} }
// Exercises copy construction and copy assignment of migraphx::program and
// checks equality against the source before and after compiling for the cpu target.
TEST_CASE(program_copy)
{
    // NOTE(review): std::iota below lives in <numeric>; confirm this file includes it.
    // Builds (literal + outline) * parameter "x", all over a float {3, 4, 5} shape.
    auto make_program = [] {
        migraphx::program prog;
        migraphx::shape shp{migraphx::shape::float_type, {3, 4, 5}};
        std::vector<float> values(3 * 4 * 5);
        std::iota(values.begin(), values.end(), 1.0f);

        auto lit     = prog.add_literal(migraphx::literal(shp, values));
        auto param   = prog.add_parameter("x", shp);
        auto outline = prog.add_outline(shp);
        auto added   = prog.add_instruction(migraphx::op::add{}, lit, outline);
        prog.add_instruction(migraphx::op::mul{}, added, param);
        return prog;
    };

    // Copy-assign, then compile only the copy: the two differ until the
    // source is compiled as well.
    {
        auto source = make_program();
        migraphx::program copy{};
        copy = source;

        copy.compile(migraphx::cpu::target{});
        EXPECT(source != copy);

        source.compile(migraphx::cpu::target{});
        EXPECT(source == copy);
    }

    // Copy-construct, compile only the source, then re-assign the compiled
    // source onto the copy.
    {
        auto source = make_program();
        migraphx::program copy(source);
        EXPECT(source == copy);

        source.compile(migraphx::cpu::target{});
        EXPECT(source != copy);

        copy = source;
        EXPECT(source == copy);
    }

    // Assign over a program that started out different, then compile both.
    {
        auto source = make_program();
        auto other  = create_program();
        EXPECT(source != other);

        other = source;
        EXPECT(source == other);

        source.compile(migraphx::cpu::target{});
        other.compile(migraphx::cpu::target{});
        EXPECT(source == other);
    }

    // Program holding a three-input dot{0.31f, 0.28f}: the copy matches
    // before and after compilation.
    {
        migraphx::program source;
        migraphx::shape shape_a{migraphx::shape::float_type, {2, 3}};
        migraphx::shape shape_b{migraphx::shape::float_type, {3, 6}};
        migraphx::shape shape_c{migraphx::shape::float_type, {2, 6}};

        auto in_a = source.add_parameter("m1", shape_a);
        auto in_b = source.add_parameter("m2", shape_b);
        auto in_c = source.add_parameter("m3", shape_c);
        source.add_instruction(migraphx::op::dot{0.31f, 0.28f}, in_a, in_b, in_c);

        migraphx::program copy{};
        copy = source;
        EXPECT(copy == source);

        source.compile(migraphx::cpu::target{});
        copy.compile(migraphx::cpu::target{});
        EXPECT(copy == source);
    }
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include <migraphx/constant_propagate.hpp> #include <migraphx/propagate_constant.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/op/add.hpp> #include <migraphx/op/add.hpp>
#include <migraphx/op/scalar.hpp>
#include <migraphx/op/mul.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
#include <test.hpp> #include <test.hpp>
...@@ -9,12 +11,12 @@ struct const_prop_target ...@@ -9,12 +11,12 @@ struct const_prop_target
std::string name() const { return "const_prop"; } std::string name() const { return "const_prop"; }
std::vector<migraphx::pass> get_passes(migraphx::context&) const std::vector<migraphx::pass> get_passes(migraphx::context&) const
{ {
return {migraphx::constant_propagate{}, migraphx::dead_code_elimination{}}; return {migraphx::propagate_constant{}, migraphx::dead_code_elimination{}};
} }
migraphx::context get_context() const { return {}; } migraphx::context get_context() const { return {}; }
}; };
TEST_CASE(const_add1) TEST_CASE(const_add)
{ {
migraphx::program p1; migraphx::program p1;
auto one = p1.add_literal(1); auto one = p1.add_literal(1);
...@@ -29,7 +31,7 @@ TEST_CASE(const_add1) ...@@ -29,7 +31,7 @@ TEST_CASE(const_add1)
EXPECT(p1 == p2); EXPECT(p1 == p2);
} }
TEST_CASE(const_add2) TEST_CASE(const_add_parameter)
{ {
migraphx::program p1; migraphx::program p1;
auto one = p1.add_parameter("one", {migraphx::shape::int32_type, {1}}); auto one = p1.add_parameter("one", {migraphx::shape::int32_type, {1}});
...@@ -44,7 +46,7 @@ TEST_CASE(const_add2) ...@@ -44,7 +46,7 @@ TEST_CASE(const_add2)
EXPECT(p1 != p2); EXPECT(p1 != p2);
} }
TEST_CASE(const_add3) TEST_CASE(const_multiadd)
{ {
migraphx::program p1; migraphx::program p1;
auto one = p1.add_literal(1); auto one = p1.add_literal(1);
...@@ -60,4 +62,54 @@ TEST_CASE(const_add3) ...@@ -60,4 +62,54 @@ TEST_CASE(const_add3)
EXPECT(p1 == p2); EXPECT(p1 == p2);
} }
// After compiling with const_prop_target, the all-literal expression
// (1 + 2 * 2) + 2 must equal a program holding the single literal 7.
TEST_CASE(const_add_mul)
{
    migraphx::program actual;
    auto lit_one  = actual.add_literal(1);
    auto lit_two  = actual.add_literal(2);
    auto product  = actual.add_instruction(migraphx::op::mul{}, lit_two, lit_two);
    auto partial  = actual.add_instruction(migraphx::op::add{}, lit_one, product);
    auto combined = actual.add_instruction(migraphx::op::add{}, partial, lit_two);
    actual.add_instruction(pass_op{}, combined);
    actual.compile(const_prop_target{});

    migraphx::program expected;
    auto folded = expected.add_literal(7);
    expected.add_instruction(pass_op{}, folded);

    EXPECT(actual == expected);
}
// Adds two literals that were each passed through op::scalar{{2, 2}} and
// expects the compiled program to equal one holding a {2, 2} int32 literal of 3s.
TEST_CASE(const_add_scalar)
{
    migraphx::program actual;
    auto lit_one = actual.add_literal(1);
    auto one     = actual.add_instruction(migraphx::op::scalar{{2, 2}}, lit_one);
    auto lit_two = actual.add_literal(2);
    auto two     = actual.add_instruction(migraphx::op::scalar{{2, 2}}, lit_two);
    auto added   = actual.add_instruction(migraphx::op::add{}, one, two);
    actual.add_instruction(pass_op{}, added);
    actual.compile(const_prop_target{});

    migraphx::program expected;
    auto folded = expected.add_literal(
        migraphx::literal{{migraphx::shape::int32_type, {2, 2}}, {3, 3, 3, 3}});
    expected.add_instruction(pass_op{}, folded);

    EXPECT(actual == expected);
}
// A single op::scalar{{2, 2}} over a literal, fed to pass_op: the program
// compiled with const_prop_target must equal an identical uncompiled program.
TEST_CASE(const_scalar)
{
    // Builds: literal 1 -> scalar{{2, 2}} -> pass_op.
    auto build = [] {
        migraphx::program prog;
        auto lit    = prog.add_literal(1);
        auto spread = prog.add_instruction(migraphx::op::scalar{{2, 2}}, lit);
        prog.add_instruction(pass_op{}, spread);
        return prog;
    };

    migraphx::program actual = build();
    actual.compile(const_prop_target{});

    migraphx::program expected = build();

    EXPECT(actual == expected);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
2
0 Placeholder*
dtype0*
shape
:
2
1 Placeholder*
dtype0*
shape
:
F
matmul1MatMul01*
T0*
transpose_a(*
transpose_b("
\ No newline at end of file
:
0 Placeholder*
shape:*
dtype0
:
1 Placeholder*
dtype0*
shape:

mul1Mul01*
T0"
\ No newline at end of file
.
0 Placeholder*
shape:*
dtype0
.
1 Placeholder*
dtype0*
shape:
.
2 Placeholder*
dtype0*
shape:
4
pack1Pack012*
T0*
axis*
N"
\ No newline at end of file
:
0 Placeholder*
dtype0*
shape:
:
1 Placeholder*
dtype0*
shape:
:
2 Placeholder*
dtype0*
shape:
4
pack1Pack012*
T0*
axis*
N"
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment