Commit 15385fb1 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

merge changes from develop branch

parents f7f02979 b606ed4f
......@@ -11,7 +11,7 @@
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/constant_propagate.hpp>
#include <migraphx/propagate_constant.hpp>
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/common_subexpression_elimination.hpp>
#include <migraphx/fwd_conv_batchnorm_rewrite.hpp>
......@@ -20,6 +20,7 @@
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/gpu/concat_gpu_opt.hpp>
#include <migraphx/gpu/schedule_model.hpp>
#include <migraphx/gpu/adjust_allocation.hpp>
#include <migraphx/eliminate_pad.hpp>
#include <migraphx/schedule.hpp>
......@@ -47,7 +48,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
//dead_code_elimination{},
simplify_algebra{},
dead_code_elimination{},
constant_propagate{},
propagate_constant{},
dead_code_elimination{},
auto_contiguous{},
//simplify_reshapes{},
......@@ -57,6 +58,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
dead_code_elimination{},
eliminate_contiguous{},
dead_code_elimination{},
adjust_allocation{},
dead_code_elimination{},
fuse_ops{&ctx},
dead_code_elimination{},
write_literals{&ctx},
......
......@@ -110,6 +110,7 @@ struct tf_parser
add_generic_op("Relu", op::relu{});
add_binary_op("Add", op::add{});
add_binary_op("Mul", op::mul{});
add_mem_op("AvgPool", &tf_parser::parse_pooling);
add_mem_op("BiasAdd", &tf_parser::parse_biasadd);
......@@ -117,12 +118,15 @@ struct tf_parser
add_mem_op("Const", &tf_parser::parse_constant);
add_mem_op("Conv2D", &tf_parser::parse_conv);
add_mem_op("FusedBatchNorm", &tf_parser::parse_batchnorm);
add_mem_op("MatMul", &tf_parser::parse_matmul);
add_mem_op("MaxPool", &tf_parser::parse_pooling);
add_mem_op("Mean", &tf_parser::parse_mean);
add_mem_op("Pack", &tf_parser::parse_pack);
add_mem_op("Pad", &tf_parser::parse_pad);
add_mem_op("Reshape", &tf_parser::parse_reshape);
add_mem_op("Softmax", &tf_parser::parse_softmax);
add_mem_op("Squeeze", &tf_parser::parse_squeeze);
add_mem_op("StridedSlice", &tf_parser::parse_stridedslice);
}
template <class F>
......@@ -149,7 +153,7 @@ struct tf_parser
template <class T>
void add_binary_op(std::string name, T x)
{
add_op(name, [this, x](attribute_map attributes, std::vector<instruction_ref> args) {
add_op(name, [this, x](const attribute_map& attributes, std::vector<instruction_ref> args) {
if(args.size() != 2)
MIGRAPHX_THROW("binary operators should have 2 operands");
auto l0 = args[1];
......@@ -211,7 +215,7 @@ struct tf_parser
template <class T>
void add_generic_op(std::string name, T x)
{
add_op(name, [this, x](attribute_map, std::vector<instruction_ref> args) {
add_op(name, [this, x](const attribute_map&, std::vector<instruction_ref> args) {
return prog.add_instruction(x, args);
});
}
......@@ -234,7 +238,7 @@ struct tf_parser
parse_biasadd(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{
uint64_t axis = 1; // assume output of previous layer is in NCHW (broadcast on channel)
auto l0 = prog.add_instruction(op::broadcast{axis, args[0]->get_shape()}, args[1]);
auto l0 = prog.add_instruction(op::broadcast{axis, args[0]->get_shape().lens()}, args[1]);
return prog.add_instruction(op::add{}, args[0], l0);
}
......@@ -335,6 +339,32 @@ struct tf_parser
return prog.add_instruction(op, {args[0], weights});
}
instruction_ref
parse_matmul(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
bool transa = false;
bool transb = false;
if(contains(attributes, "transpose_a"))
{
transa = attributes.at("transpose_a").b();
}
if(contains(attributes, "transpose_b"))
{
transb = attributes.at("transpose_a").b();
}
std::vector<int64_t> perm(args[0]->get_shape().lens().size());
std::iota(perm.begin(), perm.end(), int64_t{0});
// swap the last two elements
std::iter_swap(perm.end() - 1, perm.end() - 2);
auto l1 = (transa) ? prog.add_instruction(op::transpose{perm}, args[0]) : args[0];
auto l2 = (transb) ? prog.add_instruction(op::transpose{perm}, args[1]) : args[1];
return prog.add_instruction(op::dot{}, l1, l2);
}
instruction_ref
parse_mean(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
......@@ -353,6 +383,33 @@ struct tf_parser
MIGRAPHX_THROW("MIGraphX does not support mean outside of GlobalAvgPool transformation");
}
instruction_ref parse_pack(const std::string&,
const attribute_map& attributes,
std::vector<instruction_ref> args)
{
// reinterpret as unsqueeze with concat
std::vector<instruction_ref> unsqueezed_args;
int64_t axis = 0;
if(contains(attributes, "axis"))
axis = attributes.at("axis").i();
size_t input_size = args.front()->get_shape().lens().size();
if(axis > input_size)
{
MIGRAPHX_THROW("TF_PARSER: axis value of " + to_string(axis) +
" must be smaller than input size " + to_string(input_size));
}
// check if input arg needs axis to be converted to NCHW
if(input_size >= 4)
axis = parse_axis(axis);
std::transform(
args.begin(),
args.end(),
std::back_inserter(unsqueezed_args),
[&](instruction_ref arg) { return prog.add_instruction(op::unsqueeze{{axis}}, arg); });
return prog.add_instruction(op::concat{static_cast<size_t>(axis)}, unsqueezed_args);
}
instruction_ref
parse_pad(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{
......@@ -480,6 +537,46 @@ struct tf_parser
return prog.add_instruction(op, args[0]);
}
instruction_ref parse_stridedslice(const std::string&,
const attribute_map& attributes,
std::vector<instruction_ref> args)
{
op::slice op;
auto starts = args[1]->eval().get<int32_t>().to_vector();
auto ends = args[2]->eval().get<int32_t>().to_vector();
size_t num_axes = args[0]->get_shape().lens().size();
if(num_axes >= 4)
{
reorder_data(starts);
reorder_data(ends);
}
op.starts = std::vector<int64_t>(starts.begin(), starts.end());
op.ends = std::vector<int64_t>(ends.begin(), ends.end());
op.axes = std::vector<int64_t>(num_axes);
std::iota(op.axes.begin(), op.axes.end(), 0);
uint32_t shrink_axis_mask = 0;
uint32_t bitwise_compare = 1;
std::vector<int64_t> squeeze_axes;
if(contains(attributes, "shrink_axis_mask"))
shrink_axis_mask = static_cast<uint32_t>(attributes.at("shrink_axis_mask").i());
for(size_t i = 0; i < num_axes; i++)
{
// the LSB corresponds to axis 0 when determining which axes to squeeze
if(((shrink_axis_mask >> i) & bitwise_compare) == 1)
squeeze_axes.push_back(i);
}
if(num_axes >= 4)
{
squeeze_axes = parse_axes(squeeze_axes);
}
auto l0 = prog.add_instruction(op, args[0]);
return prog.add_instruction(op::squeeze{squeeze_axes}, l0);
}
void parse_graph(const tensorflow::GraphDef& graph)
{
nodes = get_nodes(graph, input_nodes);
......@@ -644,10 +741,6 @@ struct tf_parser
static literal parse_tensor(const tensorflow::TensorProto& t)
{
std::vector<size_t> dims = parse_dims(t.tensor_shape());
if(dims.empty())
{
dims = {1};
}
size_t shape_size = std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<size_t>());
if(!t.tensor_content().empty()) // has raw data
{
......@@ -658,17 +751,17 @@ struct tf_parser
case tensorflow::DataType::DT_FLOAT:
return literal{{shape::float_type, dims}, s.data()};
case tensorflow::DataType::DT_UINT8: throw std::runtime_error("");
case tensorflow::DataType::DT_INT8: return literal{{shape::int32_type, dims}, s.data()};
case tensorflow::DataType::DT_INT8: return literal{{shape::int8_type, dims}, s.data()};
case tensorflow::DataType::DT_UINT16:
return literal{{shape::int32_type, dims}, s.data()};
return literal{{shape::uint16_type, dims}, s.data()};
case tensorflow::DataType::DT_INT16:
return literal{{shape::int32_type, dims}, s.data()};
return literal{{shape::int16_type, dims}, s.data()};
case tensorflow::DataType::DT_INT32:
return literal{{shape::int32_type, dims}, s.data()};
case tensorflow::DataType::DT_INT64:
return literal{{shape::int64_type, dims}, s.data()};
case tensorflow::DataType::DT_STRING: throw std::runtime_error("");
case tensorflow::DataType::DT_BOOL: return literal{{shape::int32_type, dims}, s.data()};
case tensorflow::DataType::DT_BOOL: return literal{{shape::int8_type, dims}, s.data()};
case tensorflow::DataType::DT_HALF: return literal{{shape::half_type, dims}, s.data()};
case tensorflow::DataType::DT_DOUBLE:
return literal{{shape::double_type, dims}, s.data()};
......@@ -718,21 +811,23 @@ struct tf_parser
{
case tensorflow::DataType::DT_INVALID: throw std::runtime_error("");
case tensorflow::DataType::DT_FLOAT:
return literal{{shape::float_type, dims}, get_data_vals(t.float_val(), shape_size)};
return create_literal(
shape::float_type, dims, get_data_vals(t.float_val(), shape_size));
case tensorflow::DataType::DT_UINT8: throw std::runtime_error("");
case tensorflow::DataType::DT_INT8:
return literal{{shape::int32_type, dims}, get_data_vals(t.int_val(), shape_size)};
return create_literal(shape::int8_type, dims, get_data_vals(t.int_val(), shape_size));
case tensorflow::DataType::DT_UINT16:
return literal{{shape::int32_type, dims}, get_data_vals(t.int_val(), shape_size)};
return create_literal(shape::uint16_type, dims, get_data_vals(t.int_val(), shape_size));
case tensorflow::DataType::DT_INT16:
return literal{{shape::int32_type, dims}, get_data_vals(t.int_val(), shape_size)};
return create_literal(shape::int16_type, dims, get_data_vals(t.int_val(), shape_size));
case tensorflow::DataType::DT_INT32:
return literal{{shape::int32_type, dims}, get_data_vals(t.int_val(), shape_size)};
return create_literal(shape::int32_type, dims, get_data_vals(t.int_val(), shape_size));
case tensorflow::DataType::DT_INT64:
return literal{{shape::int64_type, dims}, get_data_vals(t.int64_val(), shape_size)};
return create_literal(
shape::int64_type, dims, get_data_vals(t.int64_val(), shape_size));
case tensorflow::DataType::DT_STRING: throw std::runtime_error("");
case tensorflow::DataType::DT_BOOL:
return literal{{shape::int32_type, dims}, get_data_vals(t.bool_val(), shape_size)};
return create_literal(shape::int32_type, dims, get_data_vals(t.bool_val(), shape_size));
case tensorflow::DataType::DT_HALF:
{
std::vector<int> data_int32 = get_data_vals(t.half_val(), shape_size);
......@@ -742,7 +837,7 @@ struct tf_parser
data_uint16.end(),
std::back_inserter(data_half),
[](uint16_t raw_val) { return *reinterpret_cast<half*>(&raw_val); });
return literal{{shape::half_type, dims}, data_half};
return create_literal(shape::half_type, dims, data_half);
}
case tensorflow::DataType::DT_DOUBLE:
return literal{{shape::double_type, dims}, get_data_vals(t.double_val(), shape_size)};
......@@ -811,9 +906,19 @@ struct tf_parser
std::transform(input_dims.begin(),
input_dims.end(),
std::back_inserter(dims),
[](tensorflow::TensorShapeProto_Dim dim) { return dim.size(); });
[](const tensorflow::TensorShapeProto_Dim& dim) { return dim.size(); });
return dims;
}
template <class T>
static literal
create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, std::vector<T> data)
{
// assume if explicit value is mentioned in protobuf and dim size <= 1, treat as scalar
if(dims.empty() or (dims.size() == 1 and dims.front() == 1))
return literal{{shape_type}, data};
return literal{{shape_type, dims}, data};
}
};
program parse_tf(const std::string& name, bool is_nhwc)
......
......@@ -60,7 +60,7 @@ TEST_CASE(after_literal_broadcast)
auto l2 = p.add_literal(get_2());
EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().broadcasted());
auto b = p.add_instruction(migraphx::op::broadcast{0, l1->get_shape()}, l2);
auto b = p.add_instruction(migraphx::op::broadcast{0, l1->get_shape().lens()}, l2);
p.add_instruction(pass_op{}, b);
EXPECT(not p.get_shape().standard());
EXPECT(p.get_shape().broadcasted());
......@@ -91,7 +91,7 @@ TEST_CASE(after_param_broadcast)
auto l2 = p.add_parameter("2", {migraphx::shape::float_type, {2}});
EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().broadcasted());
auto b = p.add_instruction(migraphx::op::broadcast{0, l1->get_shape()}, l2);
auto b = p.add_instruction(migraphx::op::broadcast{0, l1->get_shape().lens()}, l2);
p.add_instruction(pass_op{}, b);
EXPECT(not p.get_shape().standard());
EXPECT(p.get_shape().broadcasted());
......
......@@ -351,7 +351,7 @@ TEST_CASE(gemm_mutli_dim1_2_3)
float beta = 0.41;
auto m12_alpha = p.add_instruction(migraphx::op::dot{alpha, beta}, l1, l2);
auto l_beta = p.add_literal(beta);
auto b_beta = p.add_instruction(migraphx::op::scalar{m12_alpha->get_shape()}, l_beta);
auto b_beta = p.add_instruction(migraphx::op::scalar{m12_alpha->get_shape().lens()}, l_beta);
auto m3_beta = p.add_instruction(migraphx::op::mul{}, b_beta, l3);
p.add_instruction(migraphx::op::add{}, m3_beta, m12_alpha);
p.compile(migraphx::cpu::target{});
......
......@@ -651,7 +651,7 @@ TEST_CASE(broadcast_test)
uint64_t axis = 0;
auto l1 = p.add_literal(migraphx::literal{a_shape, a_data});
auto l2 = p.add_literal(migraphx::literal{b_shape, b_data});
p.add_instruction(migraphx::op::broadcast{axis, l1->get_shape()}, l2);
p.add_instruction(migraphx::op::broadcast{axis, l1->get_shape().lens()}, l2);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
auto output = result.get<int32_t>();
......@@ -671,7 +671,7 @@ TEST_CASE(add_broadcast_test)
uint64_t axis = 0;
auto l1 = p.add_literal(migraphx::literal{a_shape, a_data});
auto l2 = p.add_literal(migraphx::literal{b_shape, b_data});
auto l3 = p.add_instruction(migraphx::op::broadcast{axis, l1->get_shape()}, l2);
auto l3 = p.add_instruction(migraphx::op::broadcast{axis, l1->get_shape().lens()}, l2);
p.add_instruction(migraphx::op::add{}, l1, l3);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
......@@ -809,11 +809,11 @@ TEST_CASE(imagescaler_test)
0.35,
0.45}});
auto scale_val = p.add_literal(2.f);
auto scaled_tensor = p.add_instruction(migraphx::op::scalar{s}, scale_val);
auto scaled_tensor = p.add_instruction(migraphx::op::scalar{s.lens()}, scale_val);
auto img_scaled = p.add_instruction(migraphx::op::mul{}, img, scaled_tensor);
auto bias_vals = p.add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::float_type, {3}}, {0.01, 0.02, 0.03}});
auto bias_bcast = p.add_instruction(migraphx::op::broadcast{1, s}, bias_vals);
auto bias_bcast = p.add_instruction(migraphx::op::broadcast{1, s.lens()}, bias_vals);
p.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
......
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/sin.hpp>
#include <migraphx/op/slice.hpp>
#include <migraphx/op/transpose.hpp>
#include <migraphx/op/contiguous.hpp>
#include <basic_ops.hpp>
......@@ -36,7 +40,46 @@ TEST_CASE(non_standard_op)
p.add_instruction(pass_op{}, c);
auto count = std::distance(p.begin(), p.end());
p.compile(eliminate_contiguous_target{});
EXPECT(std::distance(p.begin(), p.end()) == count);
}
TEST_CASE(transpose_gemm)
{
migraphx::program p;
auto l = p.add_literal(get_2x2());
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c = p.add_instruction(migraphx::op::contiguous{}, t);
auto ic = p.add_instruction(migraphx::op::identity{}, c);
p.add_instruction(migraphx::op::dot{}, ic, l);
auto count = std::distance(p.begin(), p.end());
p.compile(eliminate_contiguous_target{});
EXPECT(std::distance(p.begin(), p.end()) == (count - 1));
}
TEST_CASE(transpose_standard_op)
{
migraphx::program p;
auto l = p.add_literal(get_2x2());
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c = p.add_instruction(migraphx::op::contiguous{}, t);
auto sn = p.add_instruction(migraphx::op::sin{}, c);
p.add_instruction(pass_standard_op{}, sn);
auto count = std::distance(p.begin(), p.end());
p.compile(eliminate_contiguous_target{});
EXPECT(std::distance(p.begin(), p.end()) == count);
}
TEST_CASE(no_packed_unary_op)
{
migraphx::program p;
auto l = p.add_literal(get_2x2());
auto t = p.add_instruction(migraphx::op::slice{{1}, {1}, {2}}, l);
auto c = p.add_instruction(migraphx::op::contiguous{}, t);
auto sn = p.add_instruction(migraphx::op::sin{}, c);
p.add_instruction(pass_standard_op{}, sn);
auto count = std::distance(p.begin(), p.end());
p.compile(eliminate_contiguous_target{});
EXPECT(std::distance(p.begin(), p.end()) == count - 1);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include <migraphx/gpu/adjust_allocation.hpp>
#include <migraphx/gpu/target.hpp>
#include <migraphx/gpu/lowering.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/auto_contiguous.hpp>
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/transpose.hpp>
#include <migraphx/op/contiguous.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/op/tanh.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
struct lowering_target
{
std::string name() const { return "gpu::lowering"; }
std::vector<migraphx::pass> get_passes(migraphx::context& gctx) const
{
auto& ctx = migraphx::any_cast<migraphx::gpu::context>(gctx);
return {migraphx::auto_contiguous{},
migraphx::gpu::lowering{ctx},
migraphx::dead_code_elimination{},
migraphx::eliminate_contiguous{},
migraphx::dead_code_elimination{}};
}
migraphx::gpu::context get_context() const { return migraphx::gpu::context{}; }
};
TEST_CASE(tanh_shape)
{
auto create_program = [] {
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
auto x = p.add_parameter("x", s);
auto tx = p.add_instruction(migraphx::op::transpose{{1, 0}}, x);
auto txh = p.add_instruction(migraphx::op::tanh{}, tx);
auto sum = p.add_instruction(migraphx::op::add{}, txh, txh);
p.add_instruction(migraphx::op::contiguous{}, sum);
return p;
};
auto p1 = create_program();
auto p2 = create_program();
EXPECT(p1 == p2);
p1.compile(lowering_target{});
p2.compile(lowering_target());
EXPECT(p1 == p2);
for(auto ins : iterator_for(p1))
{
if(ins->name() == "hip::allocate")
{
migraphx::shape new_s{migraphx::shape::float_type, {3, 2}, {1, 3}};
migraphx::instruction::replace(ins, ins->get_operator(), new_s, ins->inputs());
}
}
EXPECT(p1 != p2);
migraphx::run_passes(p2,
{migraphx::gpu::adjust_allocation{}, migraphx::dead_code_elimination{}});
EXPECT(p1 == p2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -236,8 +236,7 @@ struct test_exp : verify_program<test_exp>
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {6}};
std::vector<float> data{0.1f, 0.2f, 1.f, 2.f, 0.6f, 10.f};
auto x = p.add_literal(s, data);
auto x = p.add_instruction(migraphx::op::abs{}, p.add_parameter("x", s));
p.add_instruction(migraphx::op::exp{}, x);
return p;
}
......@@ -249,8 +248,7 @@ struct test_log : verify_program<test_log>
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {6}};
std::vector<float> data{0.1f, 0.2f, 1.f, 2.f, 0.6f, 100.f};
auto x = p.add_literal(s, data);
auto x = p.add_instruction(migraphx::op::abs{}, p.add_parameter("x", s));
p.add_instruction(migraphx::op::log{}, x);
return p;
}
......@@ -327,6 +325,34 @@ struct test_tanh : verify_program<test_tanh>
}
};
struct test_trans_tanh : verify_program<test_trans_tanh>
{
migraphx::program create_program() const
{
migraphx::program p;
auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto tx = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, x);
auto tanhx = p.add_instruction(migraphx::op::tanh{}, tx);
auto r = p.add_instruction(migraphx::op::add{}, tanhx, tanhx);
p.add_instruction(migraphx::op::contiguous{}, r);
return p;
}
};
struct test_slice_sin : verify_program<test_slice_sin>
{
migraphx::program create_program() const
{
migraphx::program p;
auto l = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto t = p.add_instruction(migraphx::op::slice{{1}, {1}, {2}}, l);
p.add_instruction(migraphx::op::sin{}, t);
return p;
}
};
struct test_asin : verify_program<test_asin>
{
migraphx::program create_program() const
......@@ -371,7 +397,7 @@ struct test_scale : verify_program<test_scale>
migraphx::shape s{migraphx::shape::float_type, {3}};
auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", migraphx::shape::float_type);
auto scale = p.add_instruction(migraphx::op::scalar{s}, y);
auto scale = p.add_instruction(migraphx::op::scalar{s.lens()}, y);
p.add_instruction(migraphx::op::mul{}, x, scale);
return p;
}
......@@ -417,7 +443,7 @@ struct test_triadd2 : verify_program<test_triadd2>
auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", s);
auto z = p.add_parameter("z", b);
auto zb = p.add_instruction(migraphx::op::broadcast{1, s}, z);
auto zb = p.add_instruction(migraphx::op::broadcast{1, s.lens()}, z);
auto sum = p.add_instruction(migraphx::op::add{}, x, y);
p.add_instruction(migraphx::op::add{}, sum, zb);
return p;
......@@ -432,7 +458,7 @@ struct test_add_broadcast : verify_program<test_add_broadcast>
migraphx::shape s{migraphx::shape::float_type, {3}};
auto x = p.add_parameter("x", {migraphx::shape::float_type, {2, 2, 3}});
auto y = p.add_parameter("y", {migraphx::shape::float_type, {2, 2}});
auto by = p.add_instruction(migraphx::op::broadcast{0, x->get_shape()}, y);
auto by = p.add_instruction(migraphx::op::broadcast{0, x->get_shape().lens()}, y);
p.add_instruction(migraphx::op::add{}, x, by);
return p;
}
......@@ -446,7 +472,7 @@ struct test_add_broadcast2 : verify_program<test_add_broadcast2>
migraphx::shape s{migraphx::shape::float_type, {3}};
auto x = p.add_parameter("x", {migraphx::shape::float_type, {2, 3, 4}});
auto y = p.add_parameter("y", {migraphx::shape::float_type, {3}});
auto by = p.add_instruction(migraphx::op::broadcast{1, x->get_shape()}, y);
auto by = p.add_instruction(migraphx::op::broadcast{1, x->get_shape().lens()}, y);
p.add_instruction(migraphx::op::add{}, x, by);
return p;
}
......@@ -460,7 +486,7 @@ struct test_add_broadcast3 : verify_program<test_add_broadcast3>
migraphx::shape s{migraphx::shape::float_type, {3}};
auto x = p.add_parameter("x", {migraphx::shape::float_type, {2, 4, 5}});
auto y = p.add_parameter("y", {migraphx::shape::float_type, {4}});
auto by = p.add_instruction(migraphx::op::broadcast{1, x->get_shape()}, y);
auto by = p.add_instruction(migraphx::op::broadcast{1, x->get_shape().lens()}, y);
p.add_instruction(migraphx::op::add{}, x, by);
return p;
}
......@@ -474,7 +500,7 @@ struct test_add_broadcast4 : verify_program<test_add_broadcast4>
migraphx::shape s{migraphx::shape::float_type, {3}};
auto x = p.add_parameter("x", {migraphx::shape::float_type, {2, 3, 5}});
auto y = p.add_parameter("y", {migraphx::shape::float_type, {3}});
auto by = p.add_instruction(migraphx::op::broadcast{1, x->get_shape()}, y);
auto by = p.add_instruction(migraphx::op::broadcast{1, x->get_shape().lens()}, y);
p.add_instruction(migraphx::op::add{}, x, by);
return p;
}
......@@ -488,7 +514,7 @@ struct test_add_broadcast5 : verify_program<test_add_broadcast5>
migraphx::shape s{migraphx::shape::float_type, {3}};
auto x = p.add_parameter("x", {migraphx::shape::float_type, {2, 4, 8}});
auto y = p.add_parameter("y", {migraphx::shape::float_type, {4}});
auto by = p.add_instruction(migraphx::op::broadcast{1, x->get_shape()}, y);
auto by = p.add_instruction(migraphx::op::broadcast{1, x->get_shape().lens()}, y);
p.add_instruction(migraphx::op::add{}, x, by);
return p;
}
......@@ -503,7 +529,7 @@ struct test_triadd_broadcast : verify_program<test_triadd_broadcast>
auto x = p.add_parameter("x", {migraphx::shape::float_type, {2, 2, 3}});
auto y = p.add_parameter("y", {migraphx::shape::float_type, {2, 2}});
auto z = p.add_parameter("z", {migraphx::shape::float_type, {2, 2, 3}});
auto by = p.add_instruction(migraphx::op::broadcast{0, x->get_shape()}, y);
auto by = p.add_instruction(migraphx::op::broadcast{0, x->get_shape().lens()}, y);
auto sum = p.add_instruction(migraphx::op::add{}, x, by);
p.add_instruction(migraphx::op::add{}, sum, z);
return p;
......@@ -535,7 +561,7 @@ struct test_sub2 : verify_program<test_sub2>
auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", s);
auto z = p.add_parameter("z", b);
auto zb = p.add_instruction(migraphx::op::broadcast{1, s}, z);
auto zb = p.add_instruction(migraphx::op::broadcast{1, s.lens()}, z);
auto diff = p.add_instruction(migraphx::op::sub{}, x, y);
p.add_instruction(migraphx::op::sub{}, diff, zb);
return p;
......@@ -674,6 +700,21 @@ struct test_abs : verify_program<test_abs>
}
};
struct test_trans_abs : verify_program<test_trans_abs>
{
migraphx::program create_program() const
{
migraphx::program p;
auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto tx = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, x);
auto absx = p.add_instruction(migraphx::op::abs{}, tx);
auto r = p.add_instruction(migraphx::op::add{}, absx, absx);
p.add_instruction(migraphx::op::contiguous{}, r);
return p;
}
};
struct test_leaky_relu : verify_program<test_leaky_relu>
{
migraphx::program create_program() const
......
......@@ -154,7 +154,7 @@ TEST_CASE(rnn_test_one_direction)
auto out_hs =
p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::reverse,
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
......@@ -339,7 +339,7 @@ TEST_CASE(gru_test_args)
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
{migraphx::op::relu{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::reverse,
clip},
seq,
......@@ -373,7 +373,10 @@ TEST_CASE(gru_test_args)
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
{migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::relu{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
......@@ -414,14 +417,20 @@ TEST_CASE(gru_test_actv_funcs)
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction(
migraphx::op::gru{hs, {}, migraphx::op::rnn_direction::bidirectional, clip},
seq,
w,
r,
bias,
seq_len,
ih);
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_0.onnx");
......@@ -445,15 +454,20 @@ TEST_CASE(gru_test_actv_funcs)
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction(
migraphx::op::gru{
hs, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip},
seq,
w,
r,
bias,
seq_len,
ih);
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::sigmoid{},
migraphx::op::sigmoid{},
migraphx::op::sigmoid{},
migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_1.onnx");
......@@ -479,7 +493,10 @@ TEST_CASE(gru_test_actv_funcs)
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
{migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
......@@ -511,17 +528,20 @@ TEST_CASE(gru_test_actv_funcs)
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction(
migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_3.onnx");
......@@ -546,7 +566,10 @@ TEST_CASE(gru_test_actv_funcs)
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs, {}, migraphx::op::rnn_direction::forward, clip},
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r,
......@@ -576,15 +599,17 @@ TEST_CASE(gru_test_actv_funcs)
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction(
migraphx::op::gru{
hs, {migraphx::op::relu{}}, migraphx::op::rnn_direction::reverse, clip},
seq,
w,
r,
bias,
seq_len,
ih);
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::relu{}, migraphx::op::relu{}},
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_reverse_1.onnx");
......@@ -826,7 +851,12 @@ TEST_CASE(lstm_forward_actv_func)
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(
migraphx::op::lstm{hs, {}, migraphx::op::rnn_direction::forward, clip, input_forget},
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip,
input_forget},
seq,
w,
r,
......@@ -851,19 +881,21 @@ TEST_CASE(lstm_forward_actv_func)
auto bias = p.add_parameter("bias", bias_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::forward,
clip,
input_forget},
seq,
w,
r,
bias,
und,
und,
und,
und);
auto out_hs = p.add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::sigmoid{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::forward,
clip,
input_forget},
seq,
w,
r,
bias,
und,
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_f1af.onnx");
......@@ -881,20 +913,21 @@ TEST_CASE(lstm_forward_actv_func)
auto seq_len = p.add_parameter("seq_len", sl_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs =
p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::forward,
clip,
input_forget},
seq,
w,
r,
bias,
seq_len,
und,
und,
und);
auto out_hs = p.add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::forward,
clip,
input_forget},
seq,
w,
r,
bias,
seq_len,
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_f2af.onnx");
......@@ -993,7 +1026,12 @@ TEST_CASE(lstm_reverse)
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(
migraphx::op::lstm{hs, {}, migraphx::op::rnn_direction::forward, clip, input_forget},
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::reverse,
clip,
input_forget},
seq,
w,
r,
......@@ -1037,21 +1075,25 @@ TEST_CASE(lstm_bidirectional)
auto ic = p.add_parameter("c0", ih_shape);
auto pph = p.add_parameter("pph", pph_shape);
auto out_hs = p.add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
bias,
seq_len,
ih,
ic,
pph);
auto out_hs =
p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
bias,
seq_len,
ih,
ic,
pph);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_bi.onnx");
......@@ -1067,21 +1109,25 @@ TEST_CASE(lstm_bidirectional)
auto r = p.add_parameter("r", r_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
und,
und,
und,
und,
und);
auto out_hs =
p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
und,
und,
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_bi3args.onnx");
......@@ -1098,21 +1144,25 @@ TEST_CASE(lstm_bidirectional)
auto bias = p.add_parameter("bias", bias_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
bias,
und,
und,
und,
und);
auto out_hs =
p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
bias,
und,
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_bi4args.onnx");
......@@ -1130,21 +1180,25 @@ TEST_CASE(lstm_bidirectional)
auto seq_len = p.add_parameter("seq_len", sl_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
bias,
seq_len,
und,
und,
und);
auto out_hs =
p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
bias,
seq_len,
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_bi5args.onnx");
......@@ -1163,21 +1217,25 @@ TEST_CASE(lstm_bidirectional)
auto ih = p.add_parameter("h0", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
bias,
seq_len,
ih,
und,
und);
auto out_hs =
p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
bias,
seq_len,
ih,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_bi6args.onnx");
......@@ -1197,21 +1255,25 @@ TEST_CASE(lstm_bidirectional)
auto ic = p.add_parameter("c0", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
bias,
seq_len,
ih,
ic,
und);
auto out_hs =
p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
bias,
seq_len,
ih,
ic,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_bi7args.onnx");
......@@ -1244,17 +1306,25 @@ TEST_CASE(lstm_bi_actv_funcs)
auto r = p.add_parameter("r", r_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(
migraphx::op::lstm{
hs, {}, migraphx::op::rnn_direction::bidirectional, clip, input_forget},
seq,
w,
r,
und,
und,
und,
und,
und);
auto out_hs =
p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
und,
und,
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_bi0af.onnx");
......@@ -1273,7 +1343,12 @@ TEST_CASE(lstm_bi_actv_funcs)
auto out_hs =
p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{}},
{migraphx::op::sigmoid{},
migraphx::op::sigmoid{},
migraphx::op::sigmoid{},
migraphx::op::sigmoid{},
migraphx::op::sigmoid{},
migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
......@@ -1304,7 +1379,12 @@ TEST_CASE(lstm_bi_actv_funcs)
auto out_hs =
p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
......@@ -1337,6 +1417,8 @@ TEST_CASE(lstm_bi_actv_funcs)
auto out_hs =
p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
......@@ -1376,6 +1458,7 @@ TEST_CASE(lstm_bi_actv_funcs)
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional,
clip,
......
......@@ -15,7 +15,7 @@ TEST_CASE(pytorch_conv_bias_test)
auto l2 = p.add_parameter("2", {migraphx::shape::float_type, {1}});
uint64_t axis = 1;
auto l3 = p.add_instruction(migraphx::op::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape()}, l2);
auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape().lens()}, l2);
p.add_instruction(migraphx::op::add{}, l3, l4);
auto prog = migraphx::parse_onnx("conv.onnx");
......@@ -30,7 +30,7 @@ TEST_CASE(pytorch_conv_relu_maxpool)
auto l2 = p.add_parameter("2", {migraphx::shape::float_type, {1}});
uint64_t axis = 1;
auto l3 = p.add_instruction(migraphx::op::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape()}, l2);
auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape().lens()}, l2);
auto l5 = p.add_instruction(migraphx::op::add{}, l3, l4);
auto l6 = p.add_instruction(migraphx::op::relu{}, l5);
p.add_instruction(migraphx::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l6);
......@@ -52,7 +52,7 @@ TEST_CASE(pytorch_conv_bn_relu_maxpool)
auto p6 = p.add_parameter("6", {migraphx::shape::float_type, {1}});
uint64_t axis = 1;
auto l3 = p.add_instruction(migraphx::op::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape()}, l2);
auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape().lens()}, l2);
auto l5 = p.add_instruction(migraphx::op::add{}, l3, l4);
auto l6 = p.add_instruction(migraphx::op::batch_norm_inference{1.0e-5f}, l5, p3, p4, p5, p6);
auto l7 = p.add_instruction(migraphx::op::relu{}, l6);
......@@ -70,7 +70,7 @@ TEST_CASE(pytorch_conv_relu_maxpool_x2)
auto l2 = p.add_parameter("2", {migraphx::shape::float_type, {5}});
uint64_t axis = 1;
auto l3 = p.add_instruction(migraphx::op::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape()}, l2);
auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape().lens()}, l2);
auto l5 = p.add_instruction(migraphx::op::add{}, l3, l4);
auto l6 = p.add_instruction(migraphx::op::relu{}, l5);
auto l7 = p.add_instruction(migraphx::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l6);
......@@ -78,7 +78,7 @@ TEST_CASE(pytorch_conv_relu_maxpool_x2)
auto l8 = p.add_parameter("3", {migraphx::shape::float_type, {1, 5, 5, 5}});
auto l9 = p.add_parameter("4", {migraphx::shape::float_type, {1}});
auto l10 = p.add_instruction(migraphx::op::convolution{}, l7, l8);
auto l11 = p.add_instruction(migraphx::op::broadcast{axis, l10->get_shape()}, l9);
auto l11 = p.add_instruction(migraphx::op::broadcast{axis, l10->get_shape().lens()}, l9);
auto l12 = p.add_instruction(migraphx::op::add{}, l10, l11);
auto l13 = p.add_instruction(migraphx::op::relu{}, l12);
p.add_instruction(migraphx::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l13);
......@@ -108,9 +108,9 @@ TEST_CASE(imagescaler_test)
auto scale_val = p.add_literal(0.5f);
auto bias_vals = p.add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::float_type, {3}}, {0.01, 0.02, 0.03}});
auto scaled_tensor = p.add_instruction(migraphx::op::scalar{s}, scale_val);
auto scaled_tensor = p.add_instruction(migraphx::op::scalar{s.lens()}, scale_val);
auto img_scaled = p.add_instruction(migraphx::op::mul{}, l0, scaled_tensor);
auto bias_bcast = p.add_instruction(migraphx::op::broadcast{1, s}, bias_vals);
auto bias_bcast = p.add_instruction(migraphx::op::broadcast{1, s.lens()}, bias_vals);
p.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast);
auto prog = migraphx::parse_onnx("imagescaler_test.onnx");
......@@ -338,7 +338,7 @@ TEST_CASE(add_bcast_test)
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}});
auto l2 = p.add_instruction(migraphx::op::broadcast{1, l0->get_shape()}, l1);
auto l2 = p.add_instruction(migraphx::op::broadcast{1, l0->get_shape().lens()}, l1);
p.add_instruction(migraphx::op::add{}, l0, l2);
auto prog = migraphx::parse_onnx("add_bcast_test.onnx");
......@@ -365,7 +365,7 @@ TEST_CASE(sub_bcast_test)
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}});
auto l2 = p.add_instruction(migraphx::op::broadcast{1, l0->get_shape()}, l1);
auto l2 = p.add_instruction(migraphx::op::broadcast{1, l0->get_shape().lens()}, l1);
p.add_instruction(migraphx::op::sub{}, l0, l2);
auto prog = migraphx::parse_onnx("sub_bcast_test.onnx");
......@@ -699,8 +699,7 @@ TEST_CASE(add_scalar_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 =
p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type, {1}}, {1}});
auto l1 = p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {1}});
auto m0 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l0);
auto m1 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
p.add_instruction(migraphx::op::add{}, m0, m1);
......
......@@ -229,6 +229,36 @@ TEST_CASE(multibroadcast)
}
}
TEST_CASE(broadcast)
{
{
std::vector<std::size_t> lens{1, 1};
migraphx::shape input{migraphx::shape::float_type, {4, 1, 3}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1}, {0, 0}},
migraphx::op::broadcast{0, lens},
input);
}
{
std::vector<std::size_t> lens{1, 1};
migraphx::shape input{migraphx::shape::float_type, {4, 1, 3}};
throws_shape(migraphx::op::broadcast{1, lens}, input);
}
{
std::vector<std::size_t> lens{3, 2, 4, 3};
migraphx::shape input{migraphx::shape::float_type, {4, 3}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 2, 4, 3}, {0, 0, 3, 1}},
migraphx::op::broadcast{2, lens},
input);
}
{
std::vector<std::size_t> lens{3, 2, 4, 3};
migraphx::shape input{migraphx::shape::float_type, {4, 4}};
throws_shape(migraphx::op::broadcast{2, lens}, input);
}
}
TEST_CASE(gather)
{
{
......
#include <migraphx/program.hpp>
#include <migraphx/ranges.hpp>
#include <sstream>
#include "test.hpp"
#include <basic_ops.hpp>
migraphx::program create_program()
{
migraphx::program p;
auto x = p.add_parameter("x", {migraphx::shape::int64_type});
auto y = p.add_parameter("y", {migraphx::shape::int64_type});
auto sum = p.add_instruction(sum_op{}, x, y);
auto one = p.add_literal(1);
p.add_instruction(sum_op{}, sum, one);
return p;
}
TEST_CASE(basic_graph_test)
{
migraphx::program p = create_program();
std::stringstream ss;
p.print_graph(ss);
std::string test = ss.str();
EXPECT(migraphx::contains(test, "digraph"));
EXPECT(migraphx::contains(test, "rankdir=LR"));
EXPECT(migraphx::contains(test, "\"@0\"[label=\"@literal\"]"));
EXPECT(migraphx::contains(test, "\"y\"[label=\"@param:y\"]"));
EXPECT(migraphx::contains(test, "\"x\"[label=\"@param:x\"]"));
EXPECT(migraphx::contains(test, "\"@3\"[label=\"sum\"]"));
EXPECT(migraphx::contains(test, "\"@4\"[label=\"sum\"]"));
EXPECT(migraphx::contains(test, "\"x\" -> \"@3\""));
EXPECT(migraphx::contains(test, "\"y\" -> \"@3\""));
EXPECT(migraphx::contains(test, "\"@3\" -> \"@4\""));
EXPECT(migraphx::contains(test, "\"@0\" -> \"@4\""));
EXPECT(migraphx::contains(test, "[label=\"int64_type, {1}, {0}\"]"));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -2,6 +2,10 @@
#include <migraphx/program.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/cpu/target.hpp>
#include <sstream>
#include "test.hpp"
#include <basic_ops.hpp>
......@@ -27,4 +31,78 @@ TEST_CASE(program_equality)
EXPECT(x == y);
}
TEST_CASE(program_copy)
{
auto create_program_1 = [] {
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 4, 5}};
std::vector<float> data(3 * 4 * 5);
std::iota(data.begin(), data.end(), 1.0f);
auto l2 = p.add_literal(migraphx::literal(s, data));
auto p1 = p.add_parameter("x", s);
auto po = p.add_outline(s);
auto sum = p.add_instruction(migraphx::op::add{}, l2, po);
p.add_instruction(migraphx::op::mul{}, sum, p1);
return p;
};
{
auto p1 = create_program_1();
migraphx::program p2{};
p2 = p1;
p2.compile(migraphx::cpu::target{});
EXPECT(p1 != p2);
p1.compile(migraphx::cpu::target{});
EXPECT(p1 == p2);
}
{
auto p1 = create_program_1();
auto p2(p1);
EXPECT(p1 == p2);
p1.compile(migraphx::cpu::target{});
EXPECT(p1 != p2);
p2 = p1;
EXPECT(p1 == p2);
}
{
auto p1 = create_program_1();
auto p2 = create_program();
EXPECT(p1 != p2);
p2 = p1;
EXPECT(p1 == p2);
p1.compile(migraphx::cpu::target{});
p2.compile(migraphx::cpu::target{});
EXPECT(p1 == p2);
}
{
migraphx::program p1;
migraphx::shape s1{migraphx::shape::float_type, {2, 3}};
migraphx::shape s2{migraphx::shape::float_type, {3, 6}};
migraphx::shape s3{migraphx::shape::float_type, {2, 6}};
auto para1 = p1.add_parameter("m1", s1);
auto para2 = p1.add_parameter("m2", s2);
auto para3 = p1.add_parameter("m3", s3);
p1.add_instruction(migraphx::op::dot{0.31f, 0.28f}, para1, para2, para3);
migraphx::program p2{};
p2 = p1;
EXPECT(p2 == p1);
p1.compile(migraphx::cpu::target{});
p2.compile(migraphx::cpu::target{});
EXPECT(p2 == p1);
}
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include <migraphx/constant_propagate.hpp>
#include <migraphx/propagate_constant.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/scalar.hpp>
#include <migraphx/op/mul.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
......@@ -9,12 +11,12 @@ struct const_prop_target
std::string name() const { return "const_prop"; }
std::vector<migraphx::pass> get_passes(migraphx::context&) const
{
return {migraphx::constant_propagate{}, migraphx::dead_code_elimination{}};
return {migraphx::propagate_constant{}, migraphx::dead_code_elimination{}};
}
migraphx::context get_context() const { return {}; }
};
TEST_CASE(const_add1)
TEST_CASE(const_add)
{
migraphx::program p1;
auto one = p1.add_literal(1);
......@@ -29,7 +31,7 @@ TEST_CASE(const_add1)
EXPECT(p1 == p2);
}
TEST_CASE(const_add2)
TEST_CASE(const_add_parameter)
{
migraphx::program p1;
auto one = p1.add_parameter("one", {migraphx::shape::int32_type, {1}});
......@@ -44,7 +46,7 @@ TEST_CASE(const_add2)
EXPECT(p1 != p2);
}
TEST_CASE(const_add3)
TEST_CASE(const_multiadd)
{
migraphx::program p1;
auto one = p1.add_literal(1);
......@@ -60,4 +62,54 @@ TEST_CASE(const_add3)
EXPECT(p1 == p2);
}
TEST_CASE(const_add_mul)
{
migraphx::program p1;
auto one = p1.add_literal(1);
auto two = p1.add_literal(2);
auto mul = p1.add_instruction(migraphx::op::mul{}, two, two);
auto sum1 = p1.add_instruction(migraphx::op::add{}, one, mul);
auto sum2 = p1.add_instruction(migraphx::op::add{}, sum1, two);
p1.add_instruction(pass_op{}, sum2);
p1.compile(const_prop_target{});
migraphx::program p2;
auto total = p2.add_literal(7);
p2.add_instruction(pass_op{}, total);
EXPECT(p1 == p2);
}
TEST_CASE(const_add_scalar)
{
migraphx::program p1;
auto one = p1.add_instruction(migraphx::op::scalar{{2, 2}}, p1.add_literal(1));
auto two = p1.add_instruction(migraphx::op::scalar{{2, 2}}, p1.add_literal(2));
auto sum = p1.add_instruction(migraphx::op::add{}, one, two);
p1.add_instruction(pass_op{}, sum);
p1.compile(const_prop_target{});
migraphx::program p2;
auto total =
p2.add_literal(migraphx::literal{{migraphx::shape::int32_type, {2, 2}}, {3, 3, 3, 3}});
p2.add_instruction(pass_op{}, total);
EXPECT(p1 == p2);
}
TEST_CASE(const_scalar)
{
migraphx::program p1;
{
auto one = p1.add_instruction(migraphx::op::scalar{{2, 2}}, p1.add_literal(1));
p1.add_instruction(pass_op{}, one);
}
p1.compile(const_prop_target{});
migraphx::program p2;
{
auto one = p2.add_instruction(migraphx::op::scalar{{2, 2}}, p2.add_literal(1));
p2.add_instruction(pass_op{}, one);
}
EXPECT(p1 == p2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
2
0 Placeholder*
dtype0*
shape
:
2
1 Placeholder*
dtype0*
shape
:
F
matmul1MatMul01*
T0*
transpose_a(*
transpose_b("
\ No newline at end of file
:
0 Placeholder*
shape:*
dtype0
:
1 Placeholder*
dtype0*
shape:

mul1Mul01*
T0"
\ No newline at end of file
.
0 Placeholder*
shape:*
dtype0
.
1 Placeholder*
dtype0*
shape:
.
2 Placeholder*
dtype0*
shape:
4
pack1Pack012*
T0*
axis*
N"
\ No newline at end of file
:
0 Placeholder*
dtype0*
shape:
:
1 Placeholder*
dtype0*
shape:
:
2 Placeholder*
dtype0*
shape:
4
pack1Pack012*
T0*
axis*
N"
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment