Unverified Commit 2466dd6f authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Refactor program to module (#684)



* code backup

* clang format

* change corresponding tool files

* clang format
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent de10423f
...@@ -9,16 +9,17 @@ struct test_pad : verify_program<test_pad> ...@@ -9,16 +9,17 @@ struct test_pad : verify_program<test_pad>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s0{migraphx::shape::int32_type, {1, 96, 165, 165}}; migraphx::shape s0{migraphx::shape::int32_type, {1, 96, 165, 165}};
std::vector<int64_t> pads0 = {0, 0, 0, 0, 0, 0, 1, 1}; std::vector<int64_t> pads0 = {0, 0, 0, 0, 0, 0, 1, 1};
std::vector<int64_t> pads1 = {0, 0, 0, 0, 1, 1, 1, 1}; std::vector<int64_t> pads1 = {0, 0, 0, 0, 1, 1, 1, 1};
std::vector<int64_t> pads2 = {1, 1, 1, 1, 0, 0, 0, 0}; std::vector<int64_t> pads2 = {1, 1, 1, 1, 0, 0, 0, 0};
std::vector<int64_t> pads3 = {1, 0, 1, 0, 1, 0, 2, 0}; std::vector<int64_t> pads3 = {1, 0, 1, 0, 1, 0, 2, 0};
auto l0 = p.add_parameter("x", s0); auto l0 = mm->add_parameter("x", s0);
p.add_instruction(migraphx::op::pad{pads0}, l0); mm->add_instruction(migraphx::op::pad{pads0}, l0);
p.add_instruction(migraphx::op::pad{pads1}, l0); mm->add_instruction(migraphx::op::pad{pads1}, l0);
p.add_instruction(migraphx::op::pad{pads2}, l0); mm->add_instruction(migraphx::op::pad{pads2}, l0);
p.add_instruction(migraphx::op::pad{pads3}, l0); mm->add_instruction(migraphx::op::pad{pads3}, l0);
return p; return p;
} }
}; };
...@@ -9,14 +9,15 @@ struct test_pad_highest : verify_program<test_pad_highest> ...@@ -9,14 +9,15 @@ struct test_pad_highest : verify_program<test_pad_highest>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
std::vector<migraphx::half> data0(4); std::vector<migraphx::half> data0(4);
std::iota(data0.begin(), data0.end(), 0); std::iota(data0.begin(), data0.end(), 0);
migraphx::shape s0{migraphx::shape::half_type, {2, 2}}; migraphx::shape s0{migraphx::shape::half_type, {2, 2}};
auto l0 = p.add_literal(migraphx::literal{s0, data0}); auto l0 = mm->add_literal(migraphx::literal{s0, data0});
migraphx::op::pad op{}; migraphx::op::pad op{};
op.value = std::numeric_limits<float>::max(); op.value = std::numeric_limits<float>::max();
op.pads = {0, 0, 1, 1}; op.pads = {0, 0, 1, 1};
p.add_instruction(op, l0); mm->add_instruction(op, l0);
return p; return p;
} }
}; };
...@@ -9,13 +9,14 @@ struct test_pad_int8 : verify_program<test_pad_int8> ...@@ -9,13 +9,14 @@ struct test_pad_int8 : verify_program<test_pad_int8>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
std::vector<int8_t> data0 = {0, 1, 2, 3}; std::vector<int8_t> data0 = {0, 1, 2, 3};
migraphx::shape s0{migraphx::shape::float_type, {2, 2}}; migraphx::shape s0{migraphx::shape::float_type, {2, 2}};
auto l0 = p.add_literal(migraphx::literal{s0, data0}); auto l0 = mm->add_literal(migraphx::literal{s0, data0});
migraphx::op::pad op{}; migraphx::op::pad op{};
op.value = std::numeric_limits<int8_t>::lowest(); op.value = std::numeric_limits<int8_t>::lowest();
op.pads = {0, 0, 1, 1}; op.pads = {0, 0, 1, 1};
p.add_instruction(op, l0); mm->add_instruction(op, l0);
return p; return p;
} }
}; };
...@@ -9,14 +9,15 @@ struct test_pad_lowest : verify_program<test_pad_lowest> ...@@ -9,14 +9,15 @@ struct test_pad_lowest : verify_program<test_pad_lowest>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
std::vector<migraphx::half> data0(4); std::vector<migraphx::half> data0(4);
std::iota(data0.begin(), data0.end(), 0); std::iota(data0.begin(), data0.end(), 0);
migraphx::shape s0{migraphx::shape::half_type, {2, 2}}; migraphx::shape s0{migraphx::shape::half_type, {2, 2}};
auto l0 = p.add_literal(migraphx::literal{s0, data0}); auto l0 = mm->add_literal(migraphx::literal{s0, data0});
migraphx::op::pad op{}; migraphx::op::pad op{};
op.value = std::numeric_limits<float>::lowest(); op.value = std::numeric_limits<float>::lowest();
op.pads = {0, 0, 1, 1}; op.pads = {0, 0, 1, 1};
p.add_instruction(op, l0); mm->add_instruction(op, l0);
return p; return p;
} }
}; };
...@@ -9,10 +9,11 @@ struct test_pad_transposed : verify_program<test_pad_transposed> ...@@ -9,10 +9,11 @@ struct test_pad_transposed : verify_program<test_pad_transposed>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::int32_type, {1, 224, 224, 3}}; migraphx::shape s{migraphx::shape::int32_type, {1, 224, 224, 3}};
auto x = p.add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto t = p.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, x); auto t = mm->add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, x);
p.add_instruction(migraphx::op::pad{{0, 0, 2, 2, 0, 0, 3, 3}}, t); mm->add_instruction(migraphx::op::pad{{0, 0, 2, 2, 0, 0, 3, 3}}, t);
return p; return p;
} }
}; };
...@@ -9,12 +9,13 @@ struct test_pooling_autopad : verify_program<test_pooling_autopad> ...@@ -9,12 +9,13 @@ struct test_pooling_autopad : verify_program<test_pooling_autopad>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s0{migraphx::shape::float_type, {1, 3, 63, 63}}; migraphx::shape s0{migraphx::shape::float_type, {1, 3, 63, 63}};
auto l0 = p.add_parameter("x", s0); auto l0 = mm->add_parameter("x", s0);
migraphx::op::pooling op{"max"}; migraphx::op::pooling op{"max"};
op.lengths = {2, 2}; op.lengths = {2, 2};
op.stride = {2, 2}; op.stride = {2, 2};
p.add_instruction(op, l0); mm->add_instruction(op, l0);
return p; return p;
} }
}; };
...@@ -9,11 +9,12 @@ struct test_pow : verify_program<test_pow> ...@@ -9,11 +9,12 @@ struct test_pow : verify_program<test_pow>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {6}}; migraphx::shape s{migraphx::shape::float_type, {6}};
std::vector<float> vec_e(s.elements(), 2.0f); std::vector<float> vec_e(s.elements(), 2.0f);
auto b = p.add_parameter("x", s); auto b = mm->add_parameter("x", s);
auto e = p.add_literal(migraphx::literal(s, vec_e)); auto e = mm->add_literal(migraphx::literal(s, vec_e));
p.add_instruction(migraphx::op::pow{}, b, e); mm->add_instruction(migraphx::op::pow{}, b, e);
return p; return p;
} }
}; };
...@@ -9,11 +9,12 @@ struct test_prelu_brcst : verify_program<test_prelu_brcst> ...@@ -9,11 +9,12 @@ struct test_prelu_brcst : verify_program<test_prelu_brcst>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {6}}; migraphx::shape s{migraphx::shape::float_type, {6}};
auto x = p.add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto slp = p.add_parameter("slp", s); auto slp = mm->add_parameter("slp", s);
auto r = p.add_instruction(migraphx::op::prelu{}, x, slp); auto r = mm->add_instruction(migraphx::op::prelu{}, x, slp);
p.add_return({r}); mm->add_return({r});
return p; return p;
} }
......
...@@ -9,9 +9,10 @@ struct test_recip : verify_program<test_recip> ...@@ -9,9 +9,10 @@ struct test_recip : verify_program<test_recip>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::double_type, {3}}; migraphx::shape s{migraphx::shape::double_type, {3}};
auto x = p.add_parameter("x", s); auto x = mm->add_parameter("x", s);
p.add_instruction(migraphx::op::recip{}, x); mm->add_instruction(migraphx::op::recip{}, x);
return p; return p;
} }
}; };
...@@ -10,9 +10,10 @@ struct test_reduce_op_large : verify_program<test_reduce_op_large<Op, Axis, T>> ...@@ -10,9 +10,10 @@ struct test_reduce_op_large : verify_program<test_reduce_op_large<Op, Axis, T>>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{T, {3, 1026, 4, 3}}; migraphx::shape s{T, {3, 1026, 4, 3}};
auto x = p.add_parameter("x", s); auto x = mm->add_parameter("x", s);
p.add_instruction(Op{{1}}, x); mm->add_instruction(Op{{1}}, x);
return p; return p;
}; };
}; };
......
...@@ -10,9 +10,10 @@ struct test_reduce_op_small : verify_program<test_reduce_op_small<Op, Axis, T>> ...@@ -10,9 +10,10 @@ struct test_reduce_op_small : verify_program<test_reduce_op_small<Op, Axis, T>>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{T, {3, 4, 8, 8}}; migraphx::shape s{T, {3, 4, 8, 8}};
auto x = p.add_parameter("x", s); auto x = mm->add_parameter("x", s);
p.add_instruction(Op{{1}}, x); mm->add_instruction(Op{{1}}, x);
return p; return p;
}; };
}; };
......
...@@ -9,9 +9,10 @@ struct test_relu_lrn : verify_program<test_relu_lrn> ...@@ -9,9 +9,10 @@ struct test_relu_lrn : verify_program<test_relu_lrn>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 5, 2, 2}}); auto* mm = p.get_main_module();
auto y = p.add_instruction(migraphx::op::relu{}, x); auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 5, 2, 2}});
p.add_instruction(migraphx::op::lrn{0.0001, 0.75, 1.0, 5}, y); auto y = mm->add_instruction(migraphx::op::relu{}, x);
mm->add_instruction(migraphx::op::lrn{0.0001, 0.75, 1.0, 5}, y);
return p; return p;
} }
}; };
...@@ -16,21 +16,22 @@ struct test_rnn_3args : verify_program<test_rnn_3args> ...@@ -16,21 +16,22 @@ struct test_rnn_3args : verify_program<test_rnn_3args>
float clip = 0.0f; float clip = 0.0f;
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}}; migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape); auto seq = mm->add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape); auto w = mm->add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape); auto r = mm->add_parameter("r", r_shape);
p.add_instruction(migraphx::op::rnn{hidden_size, mm->add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}}, {migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::reverse, migraphx::op::rnn_direction::reverse,
clip}, clip},
seq, seq,
w, w,
r); r);
return p; return p;
} }
......
...@@ -16,24 +16,25 @@ struct test_rnn_4args : verify_program<test_rnn_4args> ...@@ -16,24 +16,25 @@ struct test_rnn_4args : verify_program<test_rnn_4args>
float clip = 0.0f; float clip = 0.0f;
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}}; migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
auto seq = p.add_parameter("seq", in_shape); auto seq = mm->add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape); auto w = mm->add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape); auto r = mm->add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape); auto bias = mm->add_parameter("bias", b_shape);
p.add_instruction(migraphx::op::rnn{hidden_size, mm->add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}}, {migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::reverse, migraphx::op::rnn_direction::reverse,
clip}, clip},
seq, seq,
w, w,
r, r,
bias); bias);
return p; return p;
} }
......
...@@ -16,28 +16,29 @@ struct test_rnn_5args : verify_program<test_rnn_5args> ...@@ -16,28 +16,29 @@ struct test_rnn_5args : verify_program<test_rnn_5args>
float clip = 0.0f; float clip = 0.0f;
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}}; migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
auto seq = p.add_parameter("seq", in_shape); auto seq = mm->add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape); auto w = mm->add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape); auto r = mm->add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape); auto bias = mm->add_parameter("bias", b_shape);
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = mm->add_instruction(migraphx::op::undefined{});
auto output = auto output =
p.add_instruction(migraphx::op::rnn{hidden_size, mm->add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}}, {migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward, migraphx::op::rnn_direction::forward,
clip}, clip},
seq, seq,
w, w,
r, r,
bias, bias,
und); und);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, output); mm->add_instruction(migraphx::op::rnn_last_hs_output{}, output);
return p; return p;
} }
......
...@@ -16,24 +16,25 @@ struct test_rnn_bi_3args : verify_program<test_rnn_bi_3args> ...@@ -16,24 +16,25 @@ struct test_rnn_bi_3args : verify_program<test_rnn_bi_3args>
float clip = 0.0f; float clip = 0.0f;
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}}; migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape); auto seq = mm->add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape); auto w = mm->add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape); auto r = mm->add_parameter("r", r_shape);
auto output = auto output =
p.add_instruction(migraphx::op::rnn{hidden_size, mm->add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}}, {migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip}, clip},
seq, seq,
w, w,
r); r);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, output); mm->add_instruction(migraphx::op::rnn_last_hs_output{}, output);
return p; return p;
} }
......
...@@ -16,31 +16,32 @@ struct test_rnn_bidirectional : verify_program<test_rnn_bidirectional> ...@@ -16,31 +16,32 @@ struct test_rnn_bidirectional : verify_program<test_rnn_bidirectional>
float clip = 0.0f; float clip = 0.0f;
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}}; migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape); auto seq = mm->add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape); auto w = mm->add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape); auto r = mm->add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape); auto bias = mm->add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape); auto ih = mm->add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = mm->add_instruction(migraphx::op::undefined{});
auto output = auto output =
p.add_instruction(migraphx::op::rnn{hidden_size, mm->add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}}, {migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip}, clip},
seq, seq,
w, w,
r, r,
bias, bias,
und, und,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, output); mm->add_instruction(migraphx::op::rnn_last_hs_output{}, output);
return p; return p;
} }
......
...@@ -16,30 +16,31 @@ struct test_rnn_bidirectional10 : verify_program<test_rnn_bidirectional10> ...@@ -16,30 +16,31 @@ struct test_rnn_bidirectional10 : verify_program<test_rnn_bidirectional10>
float clip = 0.0f; float clip = 0.0f;
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}}; migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape); auto seq = mm->add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape); auto w = mm->add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape); auto r = mm->add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape); auto bias = mm->add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape); auto ih = mm->add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = mm->add_instruction(migraphx::op::undefined{});
auto output = auto output =
p.add_instruction(migraphx::op::rnn{hidden_size, mm->add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}}, {migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip}, clip},
seq, seq,
w, w,
r, r,
bias, bias,
und, und,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, output); mm->add_instruction(migraphx::op::rnn_last_hs_output{}, output);
return p; return p;
} }
......
...@@ -16,31 +16,33 @@ struct test_rnn_forward : verify_program<test_rnn_forward> ...@@ -16,31 +16,33 @@ struct test_rnn_forward : verify_program<test_rnn_forward>
float clip = 0.0f; float clip = 0.0f;
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}}; migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape); auto seq = mm->add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape); auto w = mm->add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape); auto r = mm->add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape); auto bias = mm->add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape); auto ih = mm->add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = mm->add_instruction(migraphx::op::undefined{});
auto hs = p.add_instruction(migraphx::op::rnn{hidden_size, auto hs =
{migraphx::op::tanh{}, migraphx::op::tanh{}}, mm->add_instruction(migraphx::op::rnn{hidden_size,
migraphx::op::rnn_direction::forward, {migraphx::op::tanh{}, migraphx::op::tanh{}},
clip}, migraphx::op::rnn_direction::forward,
seq, clip},
w, seq,
r, w,
bias, r,
und, bias,
ih); und,
auto lho = p.add_instruction(migraphx::op::rnn_last_hs_output{}, hs); ih);
p.add_return({hs, lho}); auto lho = mm->add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
mm->add_return({hs, lho});
return p; return p;
} }
......
...@@ -16,31 +16,33 @@ struct test_rnn_forward10 : verify_program<test_rnn_forward10> ...@@ -16,31 +16,33 @@ struct test_rnn_forward10 : verify_program<test_rnn_forward10>
float clip = 0.0f; float clip = 0.0f;
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}}; migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape); auto seq = mm->add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape); auto w = mm->add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape); auto r = mm->add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape); auto bias = mm->add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape); auto ih = mm->add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = mm->add_instruction(migraphx::op::undefined{});
auto hs = p.add_instruction(migraphx::op::rnn{hidden_size, auto hs =
{migraphx::op::tanh{}, migraphx::op::tanh{}}, mm->add_instruction(migraphx::op::rnn{hidden_size,
migraphx::op::rnn_direction::forward, {migraphx::op::tanh{}, migraphx::op::tanh{}},
clip}, migraphx::op::rnn_direction::forward,
seq, clip},
w, seq,
r, w,
bias, r,
und, bias,
ih); und,
auto lho = p.add_instruction(migraphx::op::rnn_last_hs_output{}, hs); ih);
p.add_return({hs, lho}); auto lho = mm->add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
mm->add_return({hs, lho});
return p; return p;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment