Unverified Commit 2466dd6f authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Refactor program to module (#684)



* code backup

* clang format

* change corresponding tool files

* clang format
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent de10423f
...@@ -16,6 +16,7 @@ struct test_gru_forward : verify_program<test_gru_forward> ...@@ -16,6 +16,7 @@ struct test_gru_forward : verify_program<test_gru_forward>
float clip = 0.0f; float clip = 0.0f;
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}}; {num_dirct, 3 * hidden_size, input_size}};
...@@ -24,26 +25,26 @@ struct test_gru_forward : verify_program<test_gru_forward> ...@@ -24,26 +25,26 @@ struct test_gru_forward : verify_program<test_gru_forward>
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape); auto seq = mm->add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape); auto w = mm->add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape); auto r = mm->add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape); auto bias = mm->add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape); auto ih = mm->add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = mm->add_instruction(migraphx::op::undefined{});
auto hs = auto hs =
p.add_instruction(migraphx::op::gru{hidden_size, mm->add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward, migraphx::op::rnn_direction::forward,
clip}, clip},
seq, seq,
w, w,
r, r,
bias, bias,
und, und,
ih); ih);
auto lho = p.add_instruction(migraphx::op::rnn_last_hs_output{}, hs); auto lho = mm->add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
p.add_return({lho, hs}); mm->add_return({lho, hs});
return p; return p;
} }
......
...@@ -16,21 +16,22 @@ struct test_gru_forward_3args : verify_program<test_gru_forward_3args> ...@@ -16,21 +16,22 @@ struct test_gru_forward_3args : verify_program<test_gru_forward_3args>
float clip = 0.0f; float clip = 0.0f;
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}}; {num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}}; {num_dirct, 3 * hidden_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape); auto seq = mm->add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape); auto w = mm->add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape); auto r = mm->add_parameter("r", r_shape);
p.add_instruction(migraphx::op::gru{hidden_size, mm->add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward, migraphx::op::rnn_direction::forward,
clip}, clip},
seq, seq,
w, w,
r); r);
return p; return p;
} }
......
...@@ -16,25 +16,26 @@ struct test_gru_forward_3args_und : verify_program<test_gru_forward_3args_und> ...@@ -16,25 +16,26 @@ struct test_gru_forward_3args_und : verify_program<test_gru_forward_3args_und>
float clip = 0.0f; float clip = 0.0f;
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}}; {num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}}; {num_dirct, 3 * hidden_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape); auto seq = mm->add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape); auto w = mm->add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape); auto r = mm->add_parameter("r", r_shape);
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = mm->add_instruction(migraphx::op::undefined{});
p.add_instruction(migraphx::op::gru{hidden_size, mm->add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward, migraphx::op::rnn_direction::forward,
clip}, clip},
seq, seq,
w, w,
r, r,
und, und,
und, und,
und); und);
return p; return p;
} }
......
...@@ -16,15 +16,16 @@ struct test_gru_forward_default_actv : verify_program<test_gru_forward_default_a ...@@ -16,15 +16,16 @@ struct test_gru_forward_default_actv : verify_program<test_gru_forward_default_a
float clip = 0.0f; float clip = 0.0f;
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}}; {num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}}; {num_dirct, 3 * hidden_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape); auto seq = mm->add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape); auto w = mm->add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape); auto r = mm->add_parameter("r", r_shape);
p.add_instruction( mm->add_instruction(
migraphx::op::gru{hidden_size, {}, migraphx::op::rnn_direction::forward, clip}, migraphx::op::gru{hidden_size, {}, migraphx::op::rnn_direction::forward, clip},
seq, seq,
w, w,
......
...@@ -16,6 +16,7 @@ struct test_gru_forward_default_actv1 : verify_program<test_gru_forward_default_ ...@@ -16,6 +16,7 @@ struct test_gru_forward_default_actv1 : verify_program<test_gru_forward_default_
float clip = 0.0f; float clip = 0.0f;
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}}; {num_dirct, 3 * hidden_size, input_size}};
...@@ -24,14 +25,14 @@ struct test_gru_forward_default_actv1 : verify_program<test_gru_forward_default_ ...@@ -24,14 +25,14 @@ struct test_gru_forward_default_actv1 : verify_program<test_gru_forward_default_
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape); auto seq = mm->add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape); auto w = mm->add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape); auto r = mm->add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape); auto bias = mm->add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape); auto ih = mm->add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = mm->add_instruction(migraphx::op::undefined{});
p.add_instruction( mm->add_instruction(
migraphx::op::gru{ migraphx::op::gru{
hidden_size, {migraphx::op::sigmoid{}}, migraphx::op::rnn_direction::forward, clip}, hidden_size, {migraphx::op::sigmoid{}}, migraphx::op::rnn_direction::forward, clip},
seq, seq,
......
...@@ -16,21 +16,22 @@ struct test_gru_forward_seq1 : verify_program<test_gru_forward_seq1> ...@@ -16,21 +16,22 @@ struct test_gru_forward_seq1 : verify_program<test_gru_forward_seq1>
float clip = 0.0f; float clip = 0.0f;
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}}; {num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}}; {num_dirct, 3 * hidden_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape); auto seq = mm->add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape); auto w = mm->add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape); auto r = mm->add_parameter("r", r_shape);
p.add_instruction(migraphx::op::gru{hidden_size, mm->add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward, migraphx::op::rnn_direction::forward,
clip}, clip},
seq, seq,
w, w,
r); r);
return p; return p;
} }
......
...@@ -16,21 +16,22 @@ struct test_gru_reverse_3args : verify_program<test_gru_reverse_3args> ...@@ -16,21 +16,22 @@ struct test_gru_reverse_3args : verify_program<test_gru_reverse_3args>
float clip = 0.0f; float clip = 0.0f;
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}}; {num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}}; {num_dirct, 3 * hidden_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape); auto seq = mm->add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape); auto w = mm->add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape); auto r = mm->add_parameter("r", r_shape);
p.add_instruction(migraphx::op::gru{hidden_size, mm->add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::reverse, migraphx::op::rnn_direction::reverse,
clip}, clip},
seq, seq,
w, w,
r); r);
return p; return p;
} }
......
...@@ -16,6 +16,7 @@ struct test_gru_reverse_last : verify_program<test_gru_reverse_last> ...@@ -16,6 +16,7 @@ struct test_gru_reverse_last : verify_program<test_gru_reverse_last>
float clip = 0.0f; float clip = 0.0f;
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}}; {num_dirct, 3 * hidden_size, input_size}};
...@@ -24,25 +25,25 @@ struct test_gru_reverse_last : verify_program<test_gru_reverse_last> ...@@ -24,25 +25,25 @@ struct test_gru_reverse_last : verify_program<test_gru_reverse_last>
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape); auto seq = mm->add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape); auto w = mm->add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape); auto r = mm->add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape); auto bias = mm->add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape); auto ih = mm->add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = mm->add_instruction(migraphx::op::undefined{});
auto output = auto output =
p.add_instruction(migraphx::op::gru{hidden_size, mm->add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::reverse, migraphx::op::rnn_direction::reverse,
clip}, clip},
seq, seq,
w, w,
r, r,
bias, bias,
und, und,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, output); mm->add_instruction(migraphx::op::rnn_last_hs_output{}, output);
return p; return p;
} }
......
...@@ -16,21 +16,22 @@ struct test_gru_two_outputs : verify_program<test_gru_two_outputs> ...@@ -16,21 +16,22 @@ struct test_gru_two_outputs : verify_program<test_gru_two_outputs>
float clip = 0.0f; float clip = 0.0f;
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}}; {num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}}; {num_dirct, 3 * hidden_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape); auto seq = mm->add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape); auto w = mm->add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape); auto r = mm->add_parameter("r", r_shape);
auto hs = p.add_instruction( auto hs = mm->add_instruction(
migraphx::op::gru{hidden_size, {}, migraphx::op::rnn_direction::forward, clip}, migraphx::op::gru{hidden_size, {}, migraphx::op::rnn_direction::forward, clip},
seq, seq,
w, w,
r); r);
auto last_hs = p.add_instruction(migraphx::op::rnn_last_hs_output{}, hs); auto last_hs = mm->add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
p.add_return({hs, last_hs}); mm->add_return({hs, last_hs});
return p; return p;
} }
......
...@@ -6,30 +6,31 @@ ...@@ -6,30 +6,31 @@
migraphx::instruction_ref add_layernorm(migraphx::program& p, std::vector<size_t> dims) migraphx::instruction_ref add_layernorm(migraphx::program& p, std::vector<size_t> dims)
{ {
auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, dims}); auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, dims});
auto scale = auto scale =
p.add_parameter("scale", migraphx::shape{migraphx::shape::float_type, {dims.back()}}); mm->add_parameter("scale", migraphx::shape{migraphx::shape::float_type, {dims.back()}});
auto bias = auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {dims.back()}}); mm->add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {dims.back()}});
auto epsilon = p.add_literal(1e-12f); auto epsilon = mm->add_literal(1e-12f);
auto exponent = p.add_literal(2.0f); auto exponent = mm->add_literal(2.0f);
auto mean = p.add_instruction(migraphx::op::reduce_mean({2}), x); auto mean = mm->add_instruction(migraphx::op::reduce_mean({2}), x);
auto mean_mbcast = p.add_instruction(migraphx::op::multibroadcast{{dims}}, mean); auto mean_mbcast = mm->add_instruction(migraphx::op::multibroadcast{{dims}}, mean);
auto sub = p.add_instruction(migraphx::op::sub{}, x, mean_mbcast); auto sub = mm->add_instruction(migraphx::op::sub{}, x, mean_mbcast);
auto exponent_mbcast = p.add_instruction(migraphx::op::multibroadcast{{dims}}, exponent); auto exponent_mbcast = mm->add_instruction(migraphx::op::multibroadcast{{dims}}, exponent);
auto pow = p.add_instruction(migraphx::op::pow{}, sub, exponent_mbcast); auto pow = mm->add_instruction(migraphx::op::pow{}, sub, exponent_mbcast);
auto var = p.add_instruction(migraphx::op::reduce_mean({2}), pow); auto var = mm->add_instruction(migraphx::op::reduce_mean({2}), pow);
auto epsilon_mbcast = auto epsilon_mbcast =
p.add_instruction(migraphx::op::multibroadcast{{1, dims.at(1), 1}}, epsilon); mm->add_instruction(migraphx::op::multibroadcast{{1, dims.at(1), 1}}, epsilon);
auto add_epsilon = p.add_instruction(migraphx::op::add{}, var, epsilon_mbcast); auto add_epsilon = mm->add_instruction(migraphx::op::add{}, var, epsilon_mbcast);
auto sqrt = p.add_instruction(migraphx::op::sqrt{}, add_epsilon); auto sqrt = mm->add_instruction(migraphx::op::sqrt{}, add_epsilon);
auto sqrt_mbcast = p.add_instruction(migraphx::op::multibroadcast{dims}, sqrt); auto sqrt_mbcast = mm->add_instruction(migraphx::op::multibroadcast{dims}, sqrt);
auto div = p.add_instruction(migraphx::op::div{}, sub, sqrt_mbcast); auto div = mm->add_instruction(migraphx::op::div{}, sub, sqrt_mbcast);
auto scale_mbcast = p.add_instruction(migraphx::op::multibroadcast{dims}, scale); auto scale_mbcast = mm->add_instruction(migraphx::op::multibroadcast{dims}, scale);
auto mul = p.add_instruction(migraphx::op::mul{}, scale_mbcast, div); auto mul = mm->add_instruction(migraphx::op::mul{}, scale_mbcast, div);
auto bias_mbcast = p.add_instruction(migraphx::op::multibroadcast{dims}, bias); auto bias_mbcast = mm->add_instruction(migraphx::op::multibroadcast{dims}, bias);
return p.add_instruction(migraphx::op::add{}, mul, bias_mbcast); return mm->add_instruction(migraphx::op::add{}, mul, bias_mbcast);
} }
struct test_layernorm : verify_program<test_layernorm> struct test_layernorm : verify_program<test_layernorm>
......
...@@ -9,8 +9,9 @@ struct test_leaky_relu : verify_program<test_leaky_relu> ...@@ -9,8 +9,9 @@ struct test_leaky_relu : verify_program<test_leaky_relu>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); auto* mm = p.get_main_module();
p.add_instruction(migraphx::op::leaky_relu{0.01}, x); auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
mm->add_instruction(migraphx::op::leaky_relu{0.01}, x);
return p; return p;
} }
}; };
...@@ -9,12 +9,13 @@ struct test_less : verify_program<test_less> ...@@ -9,12 +9,13 @@ struct test_less : verify_program<test_less>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::double_type, {2, 3, 4, 6}}; migraphx::shape s{migraphx::shape::double_type, {2, 3, 4, 6}};
auto input1 = p.add_parameter("x", s); auto input1 = mm->add_parameter("x", s);
auto input2 = p.add_parameter("y", s); auto input2 = mm->add_parameter("y", s);
auto r = p.add_instruction(migraphx::op::less{}, input1, input2); auto r = mm->add_instruction(migraphx::op::less{}, input1, input2);
p.add_return({r}); mm->add_return({r});
return p; return p;
}; };
}; };
...@@ -9,13 +9,14 @@ struct test_less_brcst : verify_program<test_less_brcst> ...@@ -9,13 +9,14 @@ struct test_less_brcst : verify_program<test_less_brcst>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s0{migraphx::shape::float_type, {3, 3}}; migraphx::shape s0{migraphx::shape::float_type, {3, 3}};
auto l0 = p.add_parameter("x", s0); auto l0 = mm->add_parameter("x", s0);
migraphx::shape s1{migraphx::shape::float_type, {3, 1}}; migraphx::shape s1{migraphx::shape::float_type, {3, 1}};
auto l1 = p.add_parameter("y", s1); auto l1 = mm->add_parameter("y", s1);
auto bl1 = p.add_instruction(migraphx::op::multibroadcast{s0.lens()}, l1); auto bl1 = mm->add_instruction(migraphx::op::multibroadcast{s0.lens()}, l1);
auto r = p.add_instruction(migraphx::op::less{}, l0, bl1); auto r = mm->add_instruction(migraphx::op::less{}, l0, bl1);
p.add_return({r}); mm->add_return({r});
return p; return p;
}; };
......
...@@ -9,12 +9,13 @@ struct test_literals : verify_program<test_literals> ...@@ -9,12 +9,13 @@ struct test_literals : verify_program<test_literals>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto input = p.add_literal( auto* mm = p.get_main_module();
auto input = mm->add_literal(
generate_literal(migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}})); generate_literal(migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}));
auto weights = p.add_literal( auto weights = mm->add_literal(
generate_literal(migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}})); generate_literal(migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}));
auto conv = p.add_instruction(migraphx::op::convolution{}, input, weights); auto conv = mm->add_instruction(migraphx::op::convolution{}, input, weights);
p.add_instruction(migraphx::op::relu{}, conv); mm->add_instruction(migraphx::op::relu{}, conv);
return p; return p;
} }
}; };
...@@ -9,9 +9,10 @@ struct test_log : verify_program<test_log> ...@@ -9,9 +9,10 @@ struct test_log : verify_program<test_log>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {6}}; migraphx::shape s{migraphx::shape::float_type, {6}};
auto x = p.add_instruction(migraphx::op::abs{}, p.add_parameter("x", s)); auto x = mm->add_instruction(migraphx::op::abs{}, mm->add_parameter("x", s));
p.add_instruction(migraphx::op::log{}, x); mm->add_instruction(migraphx::op::log{}, x);
return p; return p;
} }
}; };
...@@ -10,9 +10,10 @@ struct test_logsoftmax : verify_program<test_logsoftmax<Axis, T>> ...@@ -10,9 +10,10 @@ struct test_logsoftmax : verify_program<test_logsoftmax<Axis, T>>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{T, {10, 4, 2080, 6}}; migraphx::shape s{T, {10, 4, 2080, 6}};
auto param = p.add_parameter("0", s); auto param = mm->add_parameter("0", s);
p.add_instruction(migraphx::op::logsoftmax{Axis}, param); mm->add_instruction(migraphx::op::logsoftmax{Axis}, param);
return p; return p;
} }
......
...@@ -16,21 +16,22 @@ struct test_lstm_bidirct_3args : verify_program<test_lstm_bidirct_3args> ...@@ -16,21 +16,22 @@ struct test_lstm_bidirct_3args : verify_program<test_lstm_bidirct_3args>
float clip = 0.0f; float clip = 0.0f;
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, input_size}}; {num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, hidden_size}}; {num_dirct, 4 * hidden_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape); auto seq = mm->add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape); auto w = mm->add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape); auto r = mm->add_parameter("r", r_shape);
p.add_instruction(migraphx::op::lstm{hidden_size, mm->add_instruction(migraphx::op::lstm{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip}, clip},
seq, seq,
w, w,
r); r);
return p; return p;
} }
......
...@@ -16,16 +16,17 @@ struct test_lstm_bidirct_3args_und : verify_program<test_lstm_bidirct_3args_und> ...@@ -16,16 +16,17 @@ struct test_lstm_bidirct_3args_und : verify_program<test_lstm_bidirct_3args_und>
float clip = 0.0f; float clip = 0.0f;
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, input_size}}; {num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, hidden_size}}; {num_dirct, 4 * hidden_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape); auto seq = mm->add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape); auto w = mm->add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape); auto r = mm->add_parameter("r", r_shape);
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = mm->add_instruction(migraphx::op::undefined{});
p.add_instruction( mm->add_instruction(
migraphx::op::gru{hidden_size, migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional, migraphx::op::rnn_direction::bidirectional,
......
...@@ -16,15 +16,16 @@ struct test_lstm_bidirct_default_actv : verify_program<test_lstm_bidirct_default ...@@ -16,15 +16,16 @@ struct test_lstm_bidirct_default_actv : verify_program<test_lstm_bidirct_default
float clip = 0.0f; float clip = 0.0f;
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, input_size}}; {num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, hidden_size}}; {num_dirct, 4 * hidden_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape); auto seq = mm->add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape); auto w = mm->add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape); auto r = mm->add_parameter("r", r_shape);
p.add_instruction( mm->add_instruction(
migraphx::op::lstm{hidden_size, {}, migraphx::op::rnn_direction::bidirectional, clip}, migraphx::op::lstm{hidden_size, {}, migraphx::op::rnn_direction::bidirectional, clip},
seq, seq,
w, w,
......
...@@ -16,6 +16,7 @@ struct test_lstm_bidirct_default_actv1 : verify_program<test_lstm_bidirct_defaul ...@@ -16,6 +16,7 @@ struct test_lstm_bidirct_default_actv1 : verify_program<test_lstm_bidirct_defaul
float clip = 0.0f; float clip = 0.0f;
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, input_size}}; {num_dirct, 4 * hidden_size, input_size}};
...@@ -25,24 +26,24 @@ struct test_lstm_bidirct_default_actv1 : verify_program<test_lstm_bidirct_defaul ...@@ -25,24 +26,24 @@ struct test_lstm_bidirct_default_actv1 : verify_program<test_lstm_bidirct_defaul
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape sl_shape{migraphx::shape::int32_type, {batch_size}}; migraphx::shape sl_shape{migraphx::shape::int32_type, {batch_size}};
auto seq = p.add_parameter("seq", in_shape); auto seq = mm->add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape); auto w = mm->add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape); auto r = mm->add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape); auto bias = mm->add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape); auto ih = mm->add_parameter("ih", ih_shape);
std::vector<int> sl_data(batch_size, 2); std::vector<int> sl_data(batch_size, 2);
auto sql = p.add_literal(migraphx::literal{sl_shape, sl_data}); auto sql = mm->add_literal(migraphx::literal{sl_shape, sl_data});
p.add_instruction(migraphx::op::lstm{hidden_size, mm->add_instruction(migraphx::op::lstm{hidden_size,
{migraphx::op::sigmoid{}}, {migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip}, clip},
seq, seq,
w, w,
r, r,
bias, bias,
sql, sql,
ih); ih);
return p; return p;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment