Commit 3272b22e authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 94e3a2e4
...@@ -56,11 +56,11 @@ struct tf_parser ...@@ -56,11 +56,11 @@ struct tf_parser
std::vector<tensorflow::NodeDef> input_nodes; std::vector<tensorflow::NodeDef> input_nodes;
std::vector<std::string> output_node_names; std::vector<std::string> output_node_names;
std::unordered_map<std::string, instruction_ref> instructions; std::unordered_map<std::string, instruction_ref> instructions;
program prog = program(); program prog = program();
module* mm = prog.get_main_module(); module* mm = prog.get_main_module();
bool is_nhwc = true; bool is_nhwc = true;
unsigned int batch_size = 1; unsigned int batch_size = 1;
int default_dim_value = 1; int default_dim_value = 1;
std::unordered_map<std::string, std::vector<int>> map_input_dims; std::unordered_map<std::string, std::vector<int>> map_input_dims;
std::unordered_map<std::string, op_func> ops; std::unordered_map<std::string, op_func> ops;
......
...@@ -19,8 +19,8 @@ struct parse_concat : op_parser<parse_concat> ...@@ -19,8 +19,8 @@ struct parse_concat : op_parser<parse_concat>
{ {
// get index for axis within args // get index for axis within args
int axis_idx = info.attributes.at("N").i(); int axis_idx = info.attributes.at("N").i();
int64_t axis = args[axis_idx]->eval().at<int64_t>(); int64_t axis = args[axis_idx]->eval().at<int64_t>();
auto op = make_op("concat", {{"axis", axis}}); auto op = make_op("concat", {{"axis", axis}});
// return only first N arguments (assuming last index is the axis value) // return only first N arguments (assuming last index is the axis value)
return info.add_instruction( return info.add_instruction(
op, std::vector<instruction_ref>(args.begin(), args.begin() + args.size() - 1)); op, std::vector<instruction_ref>(args.begin(), args.begin() + args.size() - 1));
......
...@@ -19,9 +19,9 @@ struct parse_slice : op_parser<parse_slice> ...@@ -19,9 +19,9 @@ struct parse_slice : op_parser<parse_slice>
const tf_parser::node_info& info, const tf_parser::node_info& info,
std::vector<instruction_ref> args) const std::vector<instruction_ref> args) const
{ {
auto starts = args[1]->eval().get<int32_t>().to_vector(); auto starts = args[1]->eval().get<int32_t>().to_vector();
auto size = args[2]->eval().get<int32_t>().to_vector(); auto size = args[2]->eval().get<int32_t>().to_vector();
auto axes = args[0]->get_shape().lens(); auto axes = args[0]->get_shape().lens();
int num_axes = axes.size(); int num_axes = axes.size();
std::vector<int64_t> axes_int64(axes.begin(), axes.end()); std::vector<int64_t> axes_int64(axes.begin(), axes.end());
......
...@@ -245,7 +245,7 @@ void tf_parser::parse_graph(const tensorflow::GraphDef& graph) ...@@ -245,7 +245,7 @@ void tf_parser::parse_graph(const tensorflow::GraphDef& graph)
const std::string& name = input.name(); const std::string& name = input.name();
attribute_map input_attrs = get_attributes(input); attribute_map input_attrs = get_attributes(input);
shape::type_t shape_type = parse_type(input_attrs.at("dtype").type()); shape::type_t shape_type = parse_type(input_attrs.at("dtype").type());
std::vector<int> dims = parse_dims(input_attrs.at("shape").shape()); std::vector<int> dims = parse_dims(input_attrs.at("shape").shape());
if(contains(map_input_dims, name)) if(contains(map_input_dims, name))
{ {
...@@ -424,7 +424,7 @@ shape::type_t tf_parser::parse_type(const tensorflow::DataType t) const ...@@ -424,7 +424,7 @@ shape::type_t tf_parser::parse_type(const tensorflow::DataType t) const
literal tf_parser::parse_tensor(const tensorflow::TensorProto& t) const literal tf_parser::parse_tensor(const tensorflow::TensorProto& t) const
{ {
std::vector<int> dims = parse_dims(t.tensor_shape()); std::vector<int> dims = parse_dims(t.tensor_shape());
int shape_size = std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<int>()); int shape_size = std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<int>());
if(!t.tensor_content().empty()) // has raw data if(!t.tensor_content().empty()) // has raw data
{ {
const std::string& s = t.tensor_content(); const std::string& s = t.tensor_content();
......
...@@ -150,7 +150,7 @@ TEST_CASE(strided_shape) ...@@ -150,7 +150,7 @@ TEST_CASE(strided_shape)
{ {
std::vector<int> lens = {2, 2}; std::vector<int> lens = {2, 2};
std::vector<int> strides = {1, 2}; std::vector<int> strides = {1, 2};
auto s = migraphx::shape(migraphx_shape_float_type, lens, strides); auto s = migraphx::shape(migraphx_shape_float_type, lens, strides);
EXPECT(s.lengths() == lens); EXPECT(s.lengths() == lens);
EXPECT(s.strides() == strides); EXPECT(s.strides() == strides);
} }
...@@ -167,8 +167,8 @@ TEST_CASE(set_loop_default_iter_num) ...@@ -167,8 +167,8 @@ TEST_CASE(set_loop_default_iter_num)
{ {
migraphx::onnx_options option; migraphx::onnx_options option;
option.set_default_loop_iterations(15); option.set_default_loop_iterations(15);
auto p = migraphx::parse_onnx("loop_default_test.onnx", option); auto p = migraphx::parse_onnx("loop_default_test.onnx", option);
auto out_shapes = p.get_output_shapes(); auto out_shapes = p.get_output_shapes();
std::vector<int> out_lens0 = {1}; std::vector<int> out_lens0 = {1};
EXPECT(out_shapes[0].lengths() == out_lens0); EXPECT(out_shapes[0].lengths() == out_lens0);
std::vector<int> out_lens1 = {15, 1}; std::vector<int> out_lens1 = {15, 1};
......
...@@ -48,11 +48,10 @@ TEST_CASE(if_pl_test) ...@@ -48,11 +48,10 @@ TEST_CASE(if_pl_test)
char ccond = cond; char ccond = cond;
pp.add("cond", migraphx::argument(param_shapes["cond"], &ccond)); pp.add("cond", migraphx::argument(param_shapes["cond"], &ccond));
auto outputs = p.eval(pp); auto outputs = p.eval(pp);
auto output = outputs[0]; auto output = outputs[0];
auto lens = output.get_shape().lengths(); auto lens = output.get_shape().lengths();
auto elem_num = auto elem_num = std::accumulate(lens.begin(), lens.end(), 1, std::multiplies<int>());
std::accumulate(lens.begin(), lens.end(), 1, std::multiplies<int>());
float* data_ptr = reinterpret_cast<float*>(output.data()); float* data_ptr = reinterpret_cast<float*>(output.data());
std::vector<float> ret(data_ptr, data_ptr + elem_num); std::vector<float> ret(data_ptr, data_ptr + elem_num);
...@@ -97,11 +96,10 @@ TEST_CASE(loop_test) ...@@ -97,11 +96,10 @@ TEST_CASE(loop_test)
std::vector<float> yd = {2.0}; std::vector<float> yd = {2.0};
pp.add("b", migraphx::argument(bbs, yd.data())); pp.add("b", migraphx::argument(bbs, yd.data()));
auto outputs = p.eval(pp); auto outputs = p.eval(pp);
auto output = outputs[0]; auto output = outputs[0];
auto lens = output.get_shape().lengths(); auto lens = output.get_shape().lengths();
auto elem_num = auto elem_num = std::accumulate(lens.begin(), lens.end(), 1, std::multiplies<int>());
std::accumulate(lens.begin(), lens.end(), 1, std::multiplies<int>());
float* data_ptr = reinterpret_cast<float*>(output.data()); float* data_ptr = reinterpret_cast<float*>(output.data());
std::vector<std::vector<float>> ret; std::vector<std::vector<float>> ret;
ret.push_back({data_ptr, data_ptr + elem_num}); ret.push_back({data_ptr, data_ptr + elem_num});
......
...@@ -116,12 +116,12 @@ TEST_CASE(simple) ...@@ -116,12 +116,12 @@ TEST_CASE(simple)
auto create_test_program = [] { auto create_test_program = [] {
migraphx::module m; migraphx::module m;
auto a1 = m.add_instruction(allocate{create_shape(1)}); auto a1 = m.add_instruction(allocate{create_shape(1)});
auto m1 = m.add_instruction(simple_op{}, a1); auto m1 = m.add_instruction(simple_op{}, a1);
auto a2 = m.add_instruction(allocate{create_shape(1)}); auto a2 = m.add_instruction(allocate{create_shape(1)});
auto m2 = m.add_instruction(simple_op{}, a2); auto m2 = m.add_instruction(simple_op{}, a2);
int axis = 0; int axis = 0;
auto a3 = m.add_instruction(allocate{create_shape(2)}); auto a3 = m.add_instruction(allocate{create_shape(2)});
m.add_instruction(concat(axis), m1, m2, a3); m.add_instruction(concat(axis), m1, m2, a3);
return m; return m;
}; };
...@@ -149,12 +149,12 @@ TEST_CASE(negative_axis1) ...@@ -149,12 +149,12 @@ TEST_CASE(negative_axis1)
auto create_test_program = [] { auto create_test_program = [] {
migraphx::module m; migraphx::module m;
auto a1 = m.add_instruction(allocate{create_shape(2, 2)}); auto a1 = m.add_instruction(allocate{create_shape(2, 2)});
auto m1 = m.add_instruction(simple_op{}, a1); auto m1 = m.add_instruction(simple_op{}, a1);
auto a2 = m.add_instruction(allocate{create_shape(2, 2)}); auto a2 = m.add_instruction(allocate{create_shape(2, 2)});
auto m2 = m.add_instruction(simple_op{}, a2); auto m2 = m.add_instruction(simple_op{}, a2);
int axis = -1; int axis = -1;
auto a3 = m.add_instruction(allocate{create_shape(4, 2)}); auto a3 = m.add_instruction(allocate{create_shape(4, 2)});
m.add_instruction(concat(axis), m1, m2, a3); m.add_instruction(concat(axis), m1, m2, a3);
return m; return m;
}; };
...@@ -172,12 +172,12 @@ TEST_CASE(negative_axis2) ...@@ -172,12 +172,12 @@ TEST_CASE(negative_axis2)
auto create_test_program = [] { auto create_test_program = [] {
migraphx::module m; migraphx::module m;
auto a1 = m.add_instruction(allocate{create_shape(2, 2)}); auto a1 = m.add_instruction(allocate{create_shape(2, 2)});
auto m1 = m.add_instruction(simple_op{}, a1); auto m1 = m.add_instruction(simple_op{}, a1);
auto a2 = m.add_instruction(allocate{create_shape(2, 2)}); auto a2 = m.add_instruction(allocate{create_shape(2, 2)});
auto m2 = m.add_instruction(simple_op{}, a2); auto m2 = m.add_instruction(simple_op{}, a2);
int axis = -2; int axis = -2;
auto a3 = m.add_instruction(allocate{create_shape(4, 2)}); auto a3 = m.add_instruction(allocate{create_shape(4, 2)});
m.add_instruction(concat(axis), m1, m2, a3); m.add_instruction(concat(axis), m1, m2, a3);
return m; return m;
}; };
...@@ -205,12 +205,12 @@ TEST_CASE(negative_axis3) ...@@ -205,12 +205,12 @@ TEST_CASE(negative_axis3)
auto create_test_program = [] { auto create_test_program = [] {
migraphx::module m; migraphx::module m;
auto a1 = m.add_instruction(allocate{create_shape(1, 2, 2)}); auto a1 = m.add_instruction(allocate{create_shape(1, 2, 2)});
auto m1 = m.add_instruction(simple_op{}, a1); auto m1 = m.add_instruction(simple_op{}, a1);
auto a2 = m.add_instruction(allocate{create_shape(1, 2, 2)}); auto a2 = m.add_instruction(allocate{create_shape(1, 2, 2)});
auto m2 = m.add_instruction(simple_op{}, a2); auto m2 = m.add_instruction(simple_op{}, a2);
int axis = -2; int axis = -2;
auto a3 = m.add_instruction(allocate{create_shape(1, 4, 2)}); auto a3 = m.add_instruction(allocate{create_shape(1, 4, 2)});
m.add_instruction(concat(axis), m1, m2, a3); m.add_instruction(concat(axis), m1, m2, a3);
return m; return m;
}; };
...@@ -238,12 +238,12 @@ TEST_CASE(reversed) ...@@ -238,12 +238,12 @@ TEST_CASE(reversed)
auto create_test_program = [] { auto create_test_program = [] {
migraphx::module m; migraphx::module m;
auto a1 = m.add_instruction(allocate{create_shape(1)}); auto a1 = m.add_instruction(allocate{create_shape(1)});
auto m1 = m.add_instruction(simple_op{}, a1); auto m1 = m.add_instruction(simple_op{}, a1);
auto a2 = m.add_instruction(allocate{create_shape(1)}); auto a2 = m.add_instruction(allocate{create_shape(1)});
auto m2 = m.add_instruction(simple_op{}, a2); auto m2 = m.add_instruction(simple_op{}, a2);
int axis = 0; int axis = 0;
auto a3 = m.add_instruction(allocate{create_shape(2)}); auto a3 = m.add_instruction(allocate{create_shape(2)});
m.add_instruction(concat(axis), m2, m1, a3); m.add_instruction(concat(axis), m2, m1, a3);
return m; return m;
}; };
...@@ -269,20 +269,20 @@ TEST_CASE(reversed) ...@@ -269,20 +269,20 @@ TEST_CASE(reversed)
TEST_CASE(nested) TEST_CASE(nested)
{ {
auto concat_test_program = [](auto& m) { auto concat_test_program = [](auto& m) {
auto a1 = m.add_instruction(allocate{create_shape(1)}); auto a1 = m.add_instruction(allocate{create_shape(1)});
auto m1 = m.add_instruction(simple_op{}, a1); auto m1 = m.add_instruction(simple_op{}, a1);
auto a2 = m.add_instruction(allocate{create_shape(1)}); auto a2 = m.add_instruction(allocate{create_shape(1)});
auto m2 = m.add_instruction(simple_op{}, a2); auto m2 = m.add_instruction(simple_op{}, a2);
int axis = 0; int axis = 0;
auto a3 = m.add_instruction(allocate{create_shape(2)}); auto a3 = m.add_instruction(allocate{create_shape(2)});
return m.add_instruction(concat(axis), m1, m2, a3); return m.add_instruction(concat(axis), m1, m2, a3);
}; };
auto create_test_program = [&] { auto create_test_program = [&] {
migraphx::module m; migraphx::module m;
auto concat1 = concat_test_program(m); auto concat1 = concat_test_program(m);
auto concat2 = concat_test_program(m); auto concat2 = concat_test_program(m);
int axis = 0; int axis = 0;
auto a1 = m.add_instruction(allocate{create_shape(4)}); auto a1 = m.add_instruction(allocate{create_shape(4)});
m.add_instruction(concat(axis), concat1, concat2, a1); m.add_instruction(concat(axis), concat1, concat2, a1);
return m; return m;
}; };
...@@ -323,9 +323,9 @@ TEST_CASE(basic) ...@@ -323,9 +323,9 @@ TEST_CASE(basic)
auto m2 = m.add_instruction(simple_op{}, a2); auto m2 = m.add_instruction(simple_op{}, a2);
auto a3 = auto a3 =
m.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1, 5, 8, 8}}}); m.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1, 5, 8, 8}}});
auto p3 = m.add_instruction(simple_op{}, a3); auto p3 = m.add_instruction(simple_op{}, a3);
int axis = 1; int axis = 1;
auto a4 = m.add_instruction( auto a4 = m.add_instruction(
allocate{migraphx::shape{migraphx::shape::float_type, {1, 10, 8, 8}}}); allocate{migraphx::shape{migraphx::shape::float_type, {1, 10, 8, 8}}});
m.add_instruction(concat(axis), m1, m2, p3, a4); m.add_instruction(concat(axis), m1, m2, p3, a4);
return m; return m;
...@@ -366,9 +366,9 @@ TEST_CASE(wont_work) ...@@ -366,9 +366,9 @@ TEST_CASE(wont_work)
auto m2 = m.add_instruction(simple_op{}, a2); auto m2 = m.add_instruction(simple_op{}, a2);
auto a3 = auto a3 =
m.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 5, 8, 8}}}); m.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 5, 8, 8}}});
auto p3 = m.add_instruction(simple_op{}, a3); auto p3 = m.add_instruction(simple_op{}, a3);
int axis = 1; int axis = 1;
auto a4 = m.add_instruction( auto a4 = m.add_instruction(
allocate{migraphx::shape{migraphx::shape::float_type, {2, 10, 8, 8}}}); allocate{migraphx::shape{migraphx::shape::float_type, {2, 10, 8, 8}}});
m.add_instruction(concat(axis), m1, m2, p3, a4); m.add_instruction(concat(axis), m1, m2, p3, a4);
return m; return m;
...@@ -383,9 +383,9 @@ TEST_CASE(wont_work) ...@@ -383,9 +383,9 @@ TEST_CASE(wont_work)
auto m2 = m.add_instruction(simple_op{}, a2); auto m2 = m.add_instruction(simple_op{}, a2);
auto a3 = auto a3 =
m.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 5, 8, 8}}}); m.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 5, 8, 8}}});
auto p3 = m.add_instruction(simple_op{}, a3); auto p3 = m.add_instruction(simple_op{}, a3);
int axis = 1; int axis = 1;
auto a4 = m.add_instruction( auto a4 = m.add_instruction(
allocate{migraphx::shape{migraphx::shape::float_type, {2, 10, 8, 8}}}); allocate{migraphx::shape{migraphx::shape::float_type, {2, 10, 8, 8}}});
m.add_instruction(concat(axis), m1, m2, p3, a4); m.add_instruction(concat(axis), m1, m2, p3, a4);
return m; return m;
......
...@@ -615,7 +615,7 @@ struct driver ...@@ -615,7 +615,7 @@ struct driver
[](const std::string& name) -> std::vector<std::string> { return {name}; }; [](const std::string& name) -> std::vector<std::string> { return {name}; };
std::vector<argument> arguments = {}; std::vector<argument> arguments = {};
std::vector<std::string> failed = {}; std::vector<std::string> failed = {};
int ran = 0; int ran = 0;
bool quiet = false; bool quiet = false;
}; };
......
...@@ -33,12 +33,12 @@ migraphx::program optimize_onnx(const std::string& name, bool eliminate_deadcode ...@@ -33,12 +33,12 @@ migraphx::program optimize_onnx(const std::string& name, bool eliminate_deadcode
TEST_CASE(rnn_test_bidirectional) TEST_CASE(rnn_test_bidirectional)
{ {
int sl = 5; // sequence len int sl = 5; // sequence len
int bs = 3; // batch size int bs = 3; // batch size
int hs = 20; // hidden size int hs = 20; // hidden size
int is = 10; // input size int is = 10; // input size
int nd = 2; // num directions int nd = 2; // num directions
float clip = 0.0f; float clip = 0.0f;
migraphx::shape seq_shape{migraphx::shape::float_type, {sl, bs, is}}; migraphx::shape seq_shape{migraphx::shape::float_type, {sl, bs, is}};
migraphx::shape w_shape{migraphx::shape::float_type, {nd, hs, is}}; migraphx::shape w_shape{migraphx::shape::float_type, {nd, hs, is}};
migraphx::shape r_shape{migraphx::shape::float_type, {nd, hs, hs}}; migraphx::shape r_shape{migraphx::shape::float_type, {nd, hs, hs}};
...@@ -79,12 +79,12 @@ TEST_CASE(rnn_test_bidirectional) ...@@ -79,12 +79,12 @@ TEST_CASE(rnn_test_bidirectional)
TEST_CASE(rnn_test_one_direction) TEST_CASE(rnn_test_one_direction)
{ {
int sl = 5; // sequence len int sl = 5; // sequence len
int bs = 3; // batch size int bs = 3; // batch size
int hs = 20; // hidden size int hs = 20; // hidden size
int is = 10; // input size int is = 10; // input size
int nd = 1; // num directions int nd = 1; // num directions
float clip = 0.0f; float clip = 0.0f;
migraphx::shape seq_shape{migraphx::shape::float_type, {sl, bs, is}}; migraphx::shape seq_shape{migraphx::shape::float_type, {sl, bs, is}};
migraphx::shape w_shape{migraphx::shape::float_type, {nd, hs, is}}; migraphx::shape w_shape{migraphx::shape::float_type, {nd, hs, is}};
migraphx::shape r_shape{migraphx::shape::float_type, {nd, hs, hs}}; migraphx::shape r_shape{migraphx::shape::float_type, {nd, hs, hs}};
...@@ -220,12 +220,12 @@ TEST_CASE(rnn_test_one_direction) ...@@ -220,12 +220,12 @@ TEST_CASE(rnn_test_one_direction)
TEST_CASE(gru_test) TEST_CASE(gru_test)
{ {
int sl = 5; // sequence len int sl = 5; // sequence len
int bs = 3; // batch size int bs = 3; // batch size
int hs = 20; // hidden size int hs = 20; // hidden size
int is = 10; // input size int is = 10; // input size
int nd = 2; // num directions int nd = 2; // num directions
float clip = 0.0f; float clip = 0.0f;
// forward // forward
{ {
nd = 1; nd = 1;
...@@ -352,12 +352,12 @@ TEST_CASE(gru_test) ...@@ -352,12 +352,12 @@ TEST_CASE(gru_test)
TEST_CASE(gru_test_args) TEST_CASE(gru_test_args)
{ {
int sl = 5; // sequence len int sl = 5; // sequence len
int bs = 3; // batch size int bs = 3; // batch size
int hs = 20; // hidden size int hs = 20; // hidden size
int is = 10; // input size int is = 10; // input size
int nd = 2; // num directions int nd = 2; // num directions
float clip = 0.0f; float clip = 0.0f;
// 3 arguments // 3 arguments
{ {
...@@ -474,12 +474,12 @@ TEST_CASE(gru_test_args) ...@@ -474,12 +474,12 @@ TEST_CASE(gru_test_args)
TEST_CASE(gru_test_actv_funcs) TEST_CASE(gru_test_actv_funcs)
{ {
int sl = 5; // sequence len int sl = 5; // sequence len
int bs = 3; // batch size int bs = 3; // batch size
int hs = 20; // hidden size int hs = 20; // hidden size
int is = 10; // input size int is = 10; // input size
int nd = 2; // num directions int nd = 2; // num directions
float clip = 0.0f; float clip = 0.0f;
// bidirection, 0 actv function // bidirection, 0 actv function
{ {
nd = 2; nd = 2;
...@@ -733,11 +733,11 @@ TEST_CASE(gru_test_actv_funcs) ...@@ -733,11 +733,11 @@ TEST_CASE(gru_test_actv_funcs)
TEST_CASE(lstm_forward) TEST_CASE(lstm_forward)
{ {
int sl = 5; // sequence len int sl = 5; // sequence len
int bs = 3; // batch size int bs = 3; // batch size
int hs = 20; // hidden size int hs = 20; // hidden size
int is = 10; // input size int is = 10; // input size
int nd = 1; // num directions int nd = 1; // num directions
float clip = 0.0f; float clip = 0.0f;
int input_forget = 1; int input_forget = 1;
migraphx::shape seq_shape{migraphx::shape::float_type, {sl, bs, is}}; migraphx::shape seq_shape{migraphx::shape::float_type, {sl, bs, is}};
...@@ -1072,11 +1072,11 @@ TEST_CASE(lstm_forward) ...@@ -1072,11 +1072,11 @@ TEST_CASE(lstm_forward)
// activation functions // activation functions
TEST_CASE(lstm_forward_actv_func) TEST_CASE(lstm_forward_actv_func)
{ {
int sl = 5; // sequence len int sl = 5; // sequence len
int bs = 3; // batch size int bs = 3; // batch size
int hs = 20; // hidden size int hs = 20; // hidden size
int is = 10; // input size int is = 10; // input size
int nd = 1; // num directions int nd = 1; // num directions
float clip = 0.0f; float clip = 0.0f;
int input_forget = 1; int input_forget = 1;
migraphx::shape seq_shape{migraphx::shape::float_type, {sl, bs, is}}; migraphx::shape seq_shape{migraphx::shape::float_type, {sl, bs, is}};
...@@ -1196,11 +1196,11 @@ TEST_CASE(lstm_forward_actv_func) ...@@ -1196,11 +1196,11 @@ TEST_CASE(lstm_forward_actv_func)
TEST_CASE(lstm_reverse) TEST_CASE(lstm_reverse)
{ {
int sl = 5; // sequence len int sl = 5; // sequence len
int bs = 3; // batch size int bs = 3; // batch size
int hs = 20; // hidden size int hs = 20; // hidden size
int is = 10; // input size int is = 10; // input size
int nd = 1; // num directions int nd = 1; // num directions
float clip = 0.0f; float clip = 0.0f;
int input_forget = 1; int input_forget = 1;
migraphx::shape seq_shape{migraphx::shape::float_type, {sl, bs, is}}; migraphx::shape seq_shape{migraphx::shape::float_type, {sl, bs, is}};
...@@ -1321,11 +1321,11 @@ TEST_CASE(lstm_reverse) ...@@ -1321,11 +1321,11 @@ TEST_CASE(lstm_reverse)
TEST_CASE(lstm_bidirectional) TEST_CASE(lstm_bidirectional)
{ {
int sl = 5; // sequence len int sl = 5; // sequence len
int bs = 3; // batch size int bs = 3; // batch size
int hs = 20; // hidden size int hs = 20; // hidden size
int is = 10; // input size int is = 10; // input size
int nd = 2; // num directions int nd = 2; // num directions
float clip = 0.0f; float clip = 0.0f;
int input_forget = 1; int input_forget = 1;
migraphx::shape seq_shape{migraphx::shape::float_type, {sl, bs, is}}; migraphx::shape seq_shape{migraphx::shape::float_type, {sl, bs, is}};
...@@ -1573,11 +1573,11 @@ TEST_CASE(lstm_bidirectional) ...@@ -1573,11 +1573,11 @@ TEST_CASE(lstm_bidirectional)
TEST_CASE(lstm_bi_actv_funcs) TEST_CASE(lstm_bi_actv_funcs)
{ {
int sl = 5; // sequence len int sl = 5; // sequence len
int bs = 3; // batch size int bs = 3; // batch size
int hs = 20; // hidden size int hs = 20; // hidden size
int is = 10; // input size int is = 10; // input size
int nd = 2; // num directions int nd = 2; // num directions
float clip = 0.0f; float clip = 0.0f;
int input_forget = 1; int input_forget = 1;
migraphx::shape seq_shape{migraphx::shape::float_type, {sl, bs, is}}; migraphx::shape seq_shape{migraphx::shape::float_type, {sl, bs, is}};
......
...@@ -2552,9 +2552,9 @@ TEST_CASE(min_test) ...@@ -2552,9 +2552,9 @@ TEST_CASE(min_test)
TEST_CASE(multinomial_test) TEST_CASE(multinomial_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
int sample_size = 10; int sample_size = 10;
float seed = 0.0f; float seed = 0.0f;
auto input = mm->add_parameter("input", migraphx::shape{migraphx::shape::float_type, {1, 10}}); auto input = mm->add_parameter("input", migraphx::shape{migraphx::shape::float_type, {1, 10}});
auto maxes = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {1}}}), input); auto maxes = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {1}}}), input);
...@@ -2596,7 +2596,7 @@ TEST_CASE(multinomial_int64_test) ...@@ -2596,7 +2596,7 @@ TEST_CASE(multinomial_int64_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
int sample_size = 10; int sample_size = 10;
float seed = 1.0f; float seed = 1.0f;
migraphx::shape::type_t dtype = migraphx::shape::type_t::int64_type; migraphx::shape::type_t dtype = migraphx::shape::type_t::int64_type;
...@@ -3973,7 +3973,7 @@ TEST_CASE(scatter_test) ...@@ -3973,7 +3973,7 @@ TEST_CASE(scatter_test)
TEST_CASE(selu_test) TEST_CASE(selu_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
std::vector<int> lens = {2, 3}; std::vector<int> lens = {2, 3};
migraphx::shape s{migraphx::shape::double_type, lens}; migraphx::shape s{migraphx::shape::double_type, lens};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
......
...@@ -371,7 +371,7 @@ TEST_CASE(gru) ...@@ -371,7 +371,7 @@ TEST_CASE(gru)
int hidden_size = 4; int hidden_size = 4;
int input_size = 3; int input_size = 3;
int num_dirct = 1; int num_dirct = 1;
float clip = 0.0f; float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, migraphx::shape w_shape{migraphx::shape::float_type,
...@@ -404,7 +404,7 @@ TEST_CASE(gru) ...@@ -404,7 +404,7 @@ TEST_CASE(gru)
int hidden_size = 4; int hidden_size = 4;
int input_size = 3; int input_size = 3;
int num_dirct = 1; int num_dirct = 1;
float clip = 0.0f; float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, migraphx::shape w_shape{migraphx::shape::float_type,
...@@ -437,7 +437,7 @@ TEST_CASE(gru) ...@@ -437,7 +437,7 @@ TEST_CASE(gru)
int hidden_size = 4; int hidden_size = 4;
int input_size = 3; int input_size = 3;
int num_dirct = 2; int num_dirct = 2;
float clip = 0.0f; float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, migraphx::shape w_shape{migraphx::shape::float_type,
...@@ -470,7 +470,7 @@ TEST_CASE(gru) ...@@ -470,7 +470,7 @@ TEST_CASE(gru)
int hidden_size = 4; int hidden_size = 4;
int input_size = 3; int input_size = 3;
int num_dirct = 1; int num_dirct = 1;
float clip = 0.0f; float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, migraphx::shape w_shape{migraphx::shape::float_type,
...@@ -501,7 +501,7 @@ TEST_CASE(gru) ...@@ -501,7 +501,7 @@ TEST_CASE(gru)
int hidden_size = 4; int hidden_size = 4;
int input_size = 3; int input_size = 3;
int num_dirct = 1; int num_dirct = 1;
float clip = 0.0f; float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, migraphx::shape w_shape{migraphx::shape::float_type,
...@@ -532,7 +532,7 @@ TEST_CASE(gru) ...@@ -532,7 +532,7 @@ TEST_CASE(gru)
int hidden_size = 4; int hidden_size = 4;
int input_size = 3; int input_size = 3;
int num_dirct = 2; int num_dirct = 2;
float clip = 0.0f; float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, migraphx::shape w_shape{migraphx::shape::float_type,
...@@ -615,7 +615,7 @@ TEST_CASE(lstm) ...@@ -615,7 +615,7 @@ TEST_CASE(lstm)
int hidden_size = 4; int hidden_size = 4;
int input_size = 3; int input_size = 3;
int num_dirct = 1; int num_dirct = 1;
float clip = 0.0f; float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, migraphx::shape w_shape{migraphx::shape::float_type,
...@@ -644,7 +644,7 @@ TEST_CASE(lstm) ...@@ -644,7 +644,7 @@ TEST_CASE(lstm)
int hidden_size = 4; int hidden_size = 4;
int input_size = 3; int input_size = 3;
int num_dirct = 1; int num_dirct = 1;
float clip = 0.0f; float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, migraphx::shape w_shape{migraphx::shape::float_type,
...@@ -677,7 +677,7 @@ TEST_CASE(lstm) ...@@ -677,7 +677,7 @@ TEST_CASE(lstm)
int hidden_size = 4; int hidden_size = 4;
int input_size = 3; int input_size = 3;
int num_dirct = 2; int num_dirct = 2;
float clip = 0.0f; float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, migraphx::shape w_shape{migraphx::shape::float_type,
...@@ -710,7 +710,7 @@ TEST_CASE(lstm) ...@@ -710,7 +710,7 @@ TEST_CASE(lstm)
int hidden_size = 4; int hidden_size = 4;
int input_size = 3; int input_size = 3;
int num_dirct = 1; int num_dirct = 1;
float clip = 0.0f; float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, migraphx::shape w_shape{migraphx::shape::float_type,
...@@ -741,7 +741,7 @@ TEST_CASE(lstm) ...@@ -741,7 +741,7 @@ TEST_CASE(lstm)
int hidden_size = 4; int hidden_size = 4;
int input_size = 3; int input_size = 3;
int num_dirct = 1; int num_dirct = 1;
float clip = 0.0f; float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, migraphx::shape w_shape{migraphx::shape::float_type,
...@@ -772,7 +772,7 @@ TEST_CASE(lstm) ...@@ -772,7 +772,7 @@ TEST_CASE(lstm)
int hidden_size = 4; int hidden_size = 4;
int input_size = 3; int input_size = 3;
int num_dirct = 2; int num_dirct = 2;
float clip = 0.0f; float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, migraphx::shape w_shape{migraphx::shape::float_type,
...@@ -1155,7 +1155,7 @@ TEST_CASE(rnn) ...@@ -1155,7 +1155,7 @@ TEST_CASE(rnn)
int hidden_size = 4; int hidden_size = 4;
int input_size = 3; int input_size = 3;
int num_dirct = 1; int num_dirct = 1;
float clip = 0.0f; float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
...@@ -1186,7 +1186,7 @@ TEST_CASE(rnn) ...@@ -1186,7 +1186,7 @@ TEST_CASE(rnn)
int hidden_size = 4; int hidden_size = 4;
int input_size = 3; int input_size = 3;
int num_dirct = 1; int num_dirct = 1;
float clip = 0.0f; float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
...@@ -1217,7 +1217,7 @@ TEST_CASE(rnn) ...@@ -1217,7 +1217,7 @@ TEST_CASE(rnn)
int hidden_size = 4; int hidden_size = 4;
int input_size = 3; int input_size = 3;
int num_dirct = 2; int num_dirct = 2;
float clip = 0.0f; float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
...@@ -1248,7 +1248,7 @@ TEST_CASE(rnn) ...@@ -1248,7 +1248,7 @@ TEST_CASE(rnn)
int hidden_size = 4; int hidden_size = 4;
int input_size = 3; int input_size = 3;
int num_dirct = 1; int num_dirct = 1;
float clip = 0.0f; float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
...@@ -1277,7 +1277,7 @@ TEST_CASE(rnn) ...@@ -1277,7 +1277,7 @@ TEST_CASE(rnn)
int hidden_size = 4; int hidden_size = 4;
int input_size = 3; int input_size = 3;
int num_dirct = 1; int num_dirct = 1;
float clip = 0.0f; float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
...@@ -1306,7 +1306,7 @@ TEST_CASE(rnn) ...@@ -1306,7 +1306,7 @@ TEST_CASE(rnn)
int hidden_size = 4; int hidden_size = 4;
int input_size = 3; int input_size = 3;
int num_dirct = 2; int num_dirct = 2;
float clip = 0.0f; float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
......
...@@ -462,10 +462,10 @@ TEST_CASE(op_capture) ...@@ -462,10 +462,10 @@ TEST_CASE(op_capture)
}; };
{ {
auto p = create_program_float(); auto p = create_program_float();
auto op_capture_p = create_program_op(); auto op_capture_p = create_program_op();
migraphx::target t = migraphx::ref::target{}; migraphx::target t = migraphx::ref::target{};
int param_index = 0; int param_index = 0;
migraphx::run_passes( migraphx::run_passes(
p, {migraphx::capture_arguments_pass{{"dot", "convolution"}, {}, &param_index}}); p, {migraphx::capture_arguments_pass{{"dot", "convolution"}, {}, &param_index}});
EXPECT(p == op_capture_p); EXPECT(p == op_capture_p);
...@@ -537,10 +537,10 @@ TEST_CASE(op_capture_subgraph) ...@@ -537,10 +537,10 @@ TEST_CASE(op_capture_subgraph)
}; };
{ {
auto p = create_program(); auto p = create_program();
auto op_capture_p = create_program_op(); auto op_capture_p = create_program_op();
migraphx::target t = migraphx::ref::target{}; migraphx::target t = migraphx::ref::target{};
int param_index = 0; int param_index = 0;
migraphx::run_passes( migraphx::run_passes(
p, {migraphx::capture_arguments_pass{{"dot", "convolution"}, {}, &param_index}}); p, {migraphx::capture_arguments_pass{{"dot", "convolution"}, {}, &param_index}});
...@@ -630,7 +630,7 @@ TEST_CASE(dot_float) ...@@ -630,7 +630,7 @@ TEST_CASE(dot_float)
const std::vector<std::pair<float, float>> quant_params = { const std::vector<std::pair<float, float>> quant_params = {
{0.1f, 0.0f}, {0.1f, 0.0f}, {0.1f, 100.0f}}; {0.1f, 0.0f}, {0.1f, 0.0f}, {0.1f, 100.0f}};
auto p = create_program(); auto p = create_program();
int param_index = 0; int param_index = 0;
migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"dot"}, {}, &param_index}}); migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"dot"}, {}, &param_index}});
migraphx::run_passes( migraphx::run_passes(
...@@ -1239,7 +1239,7 @@ TEST_CASE(test_op_capture) ...@@ -1239,7 +1239,7 @@ TEST_CASE(test_op_capture)
migraphx::program capture_p = p; migraphx::program capture_p = p;
migraphx::target t = migraphx::ref::target{}; migraphx::target t = migraphx::ref::target{};
int param_index = 0; int param_index = 0;
migraphx::run_passes(capture_p, migraphx::run_passes(capture_p,
{migraphx::capture_arguments_pass{{"dot"}, calc, &param_index}}); {migraphx::capture_arguments_pass{{"dot"}, calc, &param_index}});
......
...@@ -622,10 +622,10 @@ TEST_CASE(batch_norm_inference_test) ...@@ -622,10 +622,10 @@ TEST_CASE(batch_norm_inference_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
const int width = 2; const int width = 2;
const int height = 2; const int height = 2;
const int channels = 4; const int channels = 4;
const int batches = 2; const int batches = 2;
const float x_val = 8.0; const float x_val = 8.0;
const float mean_val = 2.0; const float mean_val = 2.0;
const float variance_val = 4.0; const float variance_val = 4.0;
...@@ -749,8 +749,7 @@ TEST_CASE(concat_test) ...@@ -749,8 +749,7 @@ TEST_CASE(concat_test)
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
EXPECT(migraphx::verify_range(result.get_shape().lens(), std::vector<int>({2, 6}))); EXPECT(migraphx::verify_range(result.get_shape().lens(), std::vector<int>({2, 6})));
EXPECT( EXPECT(migraphx::verify_range(result.get_shape().strides(), std::vector<int>({6, 1})));
migraphx::verify_range(result.get_shape().strides(), std::vector<int>({6, 1})));
} }
{ {
...@@ -774,8 +773,7 @@ TEST_CASE(concat_test) ...@@ -774,8 +773,7 @@ TEST_CASE(concat_test)
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
EXPECT(migraphx::verify_range(result.get_shape().lens(), std::vector<int>({2, 6}))); EXPECT(migraphx::verify_range(result.get_shape().lens(), std::vector<int>({2, 6})));
EXPECT( EXPECT(migraphx::verify_range(result.get_shape().strides(), std::vector<int>({6, 1})));
migraphx::verify_range(result.get_shape().strides(), std::vector<int>({6, 1})));
} }
{ {
...@@ -799,8 +797,7 @@ TEST_CASE(concat_test) ...@@ -799,8 +797,7 @@ TEST_CASE(concat_test)
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
EXPECT(migraphx::verify_range(result.get_shape().lens(), std::vector<int>({6, 2}))); EXPECT(migraphx::verify_range(result.get_shape().lens(), std::vector<int>({6, 2})));
EXPECT( EXPECT(migraphx::verify_range(result.get_shape().strides(), std::vector<int>({2, 1})));
migraphx::verify_range(result.get_shape().strides(), std::vector<int>({2, 1})));
} }
{ {
...@@ -824,8 +821,7 @@ TEST_CASE(concat_test) ...@@ -824,8 +821,7 @@ TEST_CASE(concat_test)
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
EXPECT(migraphx::verify_range(result.get_shape().lens(), std::vector<int>({6, 2}))); EXPECT(migraphx::verify_range(result.get_shape().lens(), std::vector<int>({6, 2})));
EXPECT( EXPECT(migraphx::verify_range(result.get_shape().strides(), std::vector<int>({2, 1})));
migraphx::verify_range(result.get_shape().strides(), std::vector<int>({2, 1})));
} }
} }
...@@ -2718,7 +2714,7 @@ TEST_CASE(multinomial_test) ...@@ -2718,7 +2714,7 @@ TEST_CASE(multinomial_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
int sample_size = 100000; int sample_size = 100000;
float seed = 0.0f; float seed = 0.0f;
std::mt19937 gen(seed); std::mt19937 gen(seed);
std::uniform_real_distribution<> dis(0.0, 1.0); std::uniform_real_distribution<> dis(0.0, 1.0);
std::vector<float> rand_samples(sample_size); std::vector<float> rand_samples(sample_size);
......
...@@ -82,8 +82,7 @@ struct stream_free_op ...@@ -82,8 +82,7 @@ struct stream_free_op
struct wait_event struct wait_event
{ {
std::shared_ptr<std::vector<int>> wait_for = std::shared_ptr<std::vector<int>> wait_for = std::make_shared<std::vector<int>>();
std::make_shared<std::vector<int>>();
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
{ {
...@@ -104,8 +103,7 @@ struct wait_event ...@@ -104,8 +103,7 @@ struct wait_event
using instruction_map = std::unordered_map<migraphx::instruction_ref, int>; using instruction_map = std::unordered_map<migraphx::instruction_ref, int>;
using int_map = std::unordered_map<int, int>; using int_map = std::unordered_map<int, int>;
using wait_map = using wait_map = std::unordered_map<migraphx::instruction_ref, std::shared_ptr<std::vector<int>>>;
std::unordered_map<migraphx::instruction_ref, std::shared_ptr<std::vector<int>>>;
struct schedule_model_test struct schedule_model_test
{ {
...@@ -211,10 +209,7 @@ std::vector<T> unique(std::vector<T> x) ...@@ -211,10 +209,7 @@ std::vector<T> unique(std::vector<T> x)
return x; return x;
} }
std::vector<int> get_wait_for(std::vector<int> wait_for) std::vector<int> get_wait_for(std::vector<int> wait_for) { return unique(std::move(wait_for)); }
{
return unique(std::move(wait_for));
}
std::vector<int> get_wait_for(int wait_on, std::vector<int> wait_for) std::vector<int> get_wait_for(int wait_on, std::vector<int> wait_for)
{ {
......
...@@ -22,8 +22,8 @@ struct reflectable_type ...@@ -22,8 +22,8 @@ struct reflectable_type
class3 class3
}; };
std::vector<int> ints = {}; std::vector<int> ints = {};
std::string name = ""; std::string name = "";
float fvalue = 0.0; float fvalue = 0.0;
empty_type et{}; empty_type et{};
simple_enum se = simple1; simple_enum se = simple1;
class_enum ce = class_enum::class1; class_enum ce = class_enum::class1;
...@@ -74,7 +74,7 @@ TEST_CASE(serialize_reflectable_type) ...@@ -74,7 +74,7 @@ TEST_CASE(serialize_reflectable_type)
TEST_CASE(serialize_empty_array) TEST_CASE(serialize_empty_array)
{ {
std::vector<int> ints = {}; std::vector<int> ints = {};
migraphx::value v = migraphx::to_value(ints); migraphx::value v = migraphx::to_value(ints);
EXPECT(v.is_array()); EXPECT(v.is_array());
EXPECT(v.empty()); EXPECT(v.empty());
v.push_back(1); v.push_back(1);
......
...@@ -337,15 +337,12 @@ TEST_CASE(test_shape4_nonpacked) ...@@ -337,15 +337,12 @@ TEST_CASE(test_shape4_nonpacked)
std::array<int, 4> offsets = {{5, 10, 0, 6}}; std::array<int, 4> offsets = {{5, 10, 0, 6}};
std::array<int, 4> adj_lens = {{0, 0, 0, 0}}; std::array<int, 4> adj_lens = {{0, 0, 0, 0}};
std::transform( std::transform(lens.begin(), lens.end(), offsets.begin(), adj_lens.begin(), std::plus<int>());
lens.begin(), lens.end(), offsets.begin(), adj_lens.begin(), std::plus<int>());
// adj_lens should be: { 105, 42, 8, 14 } // adj_lens should be: { 105, 42, 8, 14 }
std::vector<int> strides(4); std::vector<int> strides(4);
strides.back() = 1; strides.back() = 1;
std::partial_sum(adj_lens.rbegin(), std::partial_sum(
adj_lens.rend() - 1, adj_lens.rbegin(), adj_lens.rend() - 1, strides.rbegin() + 1, std::multiplies<int>());
strides.rbegin() + 1,
std::multiplies<int>());
migraphx::shape s{migraphx::shape::float_type, lens, strides}; migraphx::shape s{migraphx::shape::float_type, lens, strides};
EXPECT(not s.standard()); EXPECT(not s.standard());
......
...@@ -20,11 +20,10 @@ ...@@ -20,11 +20,10 @@
#include "test.hpp" #include "test.hpp"
migraphx::program migraphx::program parse_tf(const std::string& name,
parse_tf(const std::string& name, bool is_nhwc,
bool is_nhwc, const std::unordered_map<std::string, std::vector<int>>& dim_params = {},
const std::unordered_map<std::string, std::vector<int>>& dim_params = {}, const std::vector<std::string>& output_node_names = {})
const std::vector<std::string>& output_node_names = {})
{ {
return migraphx::parse_tf(name, return migraphx::parse_tf(name,
migraphx::tf_options{is_nhwc, 1, dim_params, output_node_names}); migraphx::tf_options{is_nhwc, 1, dim_params, output_node_names});
...@@ -750,9 +749,9 @@ TEST_CASE(slice_test) ...@@ -750,9 +749,9 @@ TEST_CASE(slice_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
int num_axes = 2; int num_axes = 2;
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 10}}); auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 10}});
migraphx::shape s0{migraphx::shape::int32_type, {num_axes}}; migraphx::shape s0{migraphx::shape::int32_type, {num_axes}};
mm->add_literal(migraphx::literal{s0, {1, 0}}); mm->add_literal(migraphx::literal{s0, {1, 0}});
mm->add_literal(migraphx::literal{s0, {2, -1}}); mm->add_literal(migraphx::literal{s0, {2, -1}});
......
...@@ -6,10 +6,8 @@ ...@@ -6,10 +6,8 @@
struct test_conv_bn_add : verify_program<test_conv_bn_add> struct test_conv_bn_add : verify_program<test_conv_bn_add>
{ {
static migraphx::instruction_ref add_bn(migraphx::module& m, static migraphx::instruction_ref
migraphx::instruction_ref x, add_bn(migraphx::module& m, migraphx::instruction_ref x, int channels, int seed = 1)
int channels,
int seed = 1)
{ {
migraphx::shape vars{migraphx::shape::float_type, {channels}}; migraphx::shape vars{migraphx::shape::float_type, {channels}};
auto scale = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1 + seed))); auto scale = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1 + seed)));
...@@ -23,7 +21,7 @@ struct test_conv_bn_add : verify_program<test_conv_bn_add> ...@@ -23,7 +21,7 @@ struct test_conv_bn_add : verify_program<test_conv_bn_add>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
int ichannels = 64; int ichannels = 64;
int ochannels = 256; int ochannels = 256;
auto x = mm->add_parameter("x", {migraphx::shape::float_type, {1, ichannels, 56, 56}}); auto x = mm->add_parameter("x", {migraphx::shape::float_type, {1, ichannels, 56, 56}});
......
...@@ -17,7 +17,7 @@ struct test_gru_bidirct : verify_program<test_gru_bidirct> ...@@ -17,7 +17,7 @@ struct test_gru_bidirct : verify_program<test_gru_bidirct>
int hidden_size = 5; int hidden_size = 5;
int input_size = 8; int input_size = 8;
int num_dirct = 2; int num_dirct = 2;
float clip = 0.0f; float clip = 0.0f;
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
......
...@@ -17,7 +17,7 @@ struct test_gru_bidirct_3args : verify_program<test_gru_bidirct_3args> ...@@ -17,7 +17,7 @@ struct test_gru_bidirct_3args : verify_program<test_gru_bidirct_3args>
int hidden_size = 5; int hidden_size = 5;
int input_size = 8; int input_size = 8;
int num_dirct = 2; int num_dirct = 2;
float clip = 0.0f; float clip = 0.0f;
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment