Unverified Commit 2466dd6f authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Refactor program to module (#684)



* code backup

* clang format

* change corresponding tool files

* clang format
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent de10423f
......@@ -16,6 +16,7 @@ struct test_gru_forward : verify_program<test_gru_forward>
float clip = 0.0f;
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
......@@ -24,26 +25,26 @@ struct test_gru_forward : verify_program<test_gru_forward>
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto seq = mm->add_parameter("seq", in_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", b_shape);
auto ih = mm->add_parameter("ih", ih_shape);
auto und = mm->add_instruction(migraphx::op::undefined{});
auto hs =
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r,
bias,
und,
ih);
auto lho = p.add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
p.add_return({lho, hs});
mm->add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r,
bias,
und,
ih);
auto lho = mm->add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
mm->add_return({lho, hs});
return p;
}
......
......@@ -16,21 +16,22 @@ struct test_gru_forward_3args : verify_program<test_gru_forward_3args>
float clip = 0.0f;
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r);
auto seq = mm->add_parameter("seq", in_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
mm->add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r);
return p;
}
......
......@@ -16,25 +16,26 @@ struct test_gru_forward_3args_und : verify_program<test_gru_forward_3args_und>
float clip = 0.0f;
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r,
und,
und,
und);
auto seq = mm->add_parameter("seq", in_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto und = mm->add_instruction(migraphx::op::undefined{});
mm->add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r,
und,
und,
und);
return p;
}
......
......@@ -16,15 +16,16 @@ struct test_gru_forward_default_actv : verify_program<test_gru_forward_default_a
float clip = 0.0f;
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
p.add_instruction(
auto seq = mm->add_parameter("seq", in_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
mm->add_instruction(
migraphx::op::gru{hidden_size, {}, migraphx::op::rnn_direction::forward, clip},
seq,
w,
......
......@@ -16,6 +16,7 @@ struct test_gru_forward_default_actv1 : verify_program<test_gru_forward_default_
float clip = 0.0f;
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
......@@ -24,14 +25,14 @@ struct test_gru_forward_default_actv1 : verify_program<test_gru_forward_default_
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto seq = mm->add_parameter("seq", in_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", b_shape);
auto ih = mm->add_parameter("ih", ih_shape);
auto und = mm->add_instruction(migraphx::op::undefined{});
p.add_instruction(
mm->add_instruction(
migraphx::op::gru{
hidden_size, {migraphx::op::sigmoid{}}, migraphx::op::rnn_direction::forward, clip},
seq,
......
......@@ -16,21 +16,22 @@ struct test_gru_forward_seq1 : verify_program<test_gru_forward_seq1>
float clip = 0.0f;
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r);
auto seq = mm->add_parameter("seq", in_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
mm->add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r);
return p;
}
......
......@@ -16,21 +16,22 @@ struct test_gru_reverse_3args : verify_program<test_gru_reverse_3args>
float clip = 0.0f;
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
r);
auto seq = mm->add_parameter("seq", in_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
mm->add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
r);
return p;
}
......
......@@ -16,6 +16,7 @@ struct test_gru_reverse_last : verify_program<test_gru_reverse_last>
float clip = 0.0f;
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
......@@ -24,25 +25,25 @@ struct test_gru_reverse_last : verify_program<test_gru_reverse_last>
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto seq = mm->add_parameter("seq", in_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", b_shape);
auto ih = mm->add_parameter("ih", ih_shape);
auto und = mm->add_instruction(migraphx::op::undefined{});
auto output =
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
r,
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, output);
mm->add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
r,
bias,
und,
ih);
mm->add_instruction(migraphx::op::rnn_last_hs_output{}, output);
return p;
}
......
......@@ -16,21 +16,22 @@ struct test_gru_two_outputs : verify_program<test_gru_two_outputs>
float clip = 0.0f;
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto hs = p.add_instruction(
auto seq = mm->add_parameter("seq", in_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto hs = mm->add_instruction(
migraphx::op::gru{hidden_size, {}, migraphx::op::rnn_direction::forward, clip},
seq,
w,
r);
auto last_hs = p.add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
p.add_return({hs, last_hs});
auto last_hs = mm->add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
mm->add_return({hs, last_hs});
return p;
}
......
......@@ -6,30 +6,31 @@
migraphx::instruction_ref add_layernorm(migraphx::program& p, std::vector<size_t> dims)
{
auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, dims});
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, dims});
auto scale =
p.add_parameter("scale", migraphx::shape{migraphx::shape::float_type, {dims.back()}});
mm->add_parameter("scale", migraphx::shape{migraphx::shape::float_type, {dims.back()}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {dims.back()}});
auto epsilon = p.add_literal(1e-12f);
auto exponent = p.add_literal(2.0f);
mm->add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {dims.back()}});
auto epsilon = mm->add_literal(1e-12f);
auto exponent = mm->add_literal(2.0f);
auto mean = p.add_instruction(migraphx::op::reduce_mean({2}), x);
auto mean_mbcast = p.add_instruction(migraphx::op::multibroadcast{{dims}}, mean);
auto sub = p.add_instruction(migraphx::op::sub{}, x, mean_mbcast);
auto exponent_mbcast = p.add_instruction(migraphx::op::multibroadcast{{dims}}, exponent);
auto pow = p.add_instruction(migraphx::op::pow{}, sub, exponent_mbcast);
auto var = p.add_instruction(migraphx::op::reduce_mean({2}), pow);
auto mean = mm->add_instruction(migraphx::op::reduce_mean({2}), x);
auto mean_mbcast = mm->add_instruction(migraphx::op::multibroadcast{{dims}}, mean);
auto sub = mm->add_instruction(migraphx::op::sub{}, x, mean_mbcast);
auto exponent_mbcast = mm->add_instruction(migraphx::op::multibroadcast{{dims}}, exponent);
auto pow = mm->add_instruction(migraphx::op::pow{}, sub, exponent_mbcast);
auto var = mm->add_instruction(migraphx::op::reduce_mean({2}), pow);
auto epsilon_mbcast =
p.add_instruction(migraphx::op::multibroadcast{{1, dims.at(1), 1}}, epsilon);
auto add_epsilon = p.add_instruction(migraphx::op::add{}, var, epsilon_mbcast);
auto sqrt = p.add_instruction(migraphx::op::sqrt{}, add_epsilon);
auto sqrt_mbcast = p.add_instruction(migraphx::op::multibroadcast{dims}, sqrt);
auto div = p.add_instruction(migraphx::op::div{}, sub, sqrt_mbcast);
auto scale_mbcast = p.add_instruction(migraphx::op::multibroadcast{dims}, scale);
auto mul = p.add_instruction(migraphx::op::mul{}, scale_mbcast, div);
auto bias_mbcast = p.add_instruction(migraphx::op::multibroadcast{dims}, bias);
return p.add_instruction(migraphx::op::add{}, mul, bias_mbcast);
mm->add_instruction(migraphx::op::multibroadcast{{1, dims.at(1), 1}}, epsilon);
auto add_epsilon = mm->add_instruction(migraphx::op::add{}, var, epsilon_mbcast);
auto sqrt = mm->add_instruction(migraphx::op::sqrt{}, add_epsilon);
auto sqrt_mbcast = mm->add_instruction(migraphx::op::multibroadcast{dims}, sqrt);
auto div = mm->add_instruction(migraphx::op::div{}, sub, sqrt_mbcast);
auto scale_mbcast = mm->add_instruction(migraphx::op::multibroadcast{dims}, scale);
auto mul = mm->add_instruction(migraphx::op::mul{}, scale_mbcast, div);
auto bias_mbcast = mm->add_instruction(migraphx::op::multibroadcast{dims}, bias);
return mm->add_instruction(migraphx::op::add{}, mul, bias_mbcast);
}
struct test_layernorm : verify_program<test_layernorm>
......
......@@ -9,8 +9,9 @@ struct test_leaky_relu : verify_program<test_leaky_relu>
migraphx::program create_program() const
{
migraphx::program p;
auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
p.add_instruction(migraphx::op::leaky_relu{0.01}, x);
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
mm->add_instruction(migraphx::op::leaky_relu{0.01}, x);
return p;
}
};
......@@ -9,12 +9,13 @@ struct test_less : verify_program<test_less>
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::double_type, {2, 3, 4, 6}};
auto input1 = p.add_parameter("x", s);
auto input2 = p.add_parameter("y", s);
auto r = p.add_instruction(migraphx::op::less{}, input1, input2);
p.add_return({r});
auto input1 = mm->add_parameter("x", s);
auto input2 = mm->add_parameter("y", s);
auto r = mm->add_instruction(migraphx::op::less{}, input1, input2);
mm->add_return({r});
return p;
};
};
......@@ -9,13 +9,14 @@ struct test_less_brcst : verify_program<test_less_brcst>
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s0{migraphx::shape::float_type, {3, 3}};
auto l0 = p.add_parameter("x", s0);
auto l0 = mm->add_parameter("x", s0);
migraphx::shape s1{migraphx::shape::float_type, {3, 1}};
auto l1 = p.add_parameter("y", s1);
auto bl1 = p.add_instruction(migraphx::op::multibroadcast{s0.lens()}, l1);
auto r = p.add_instruction(migraphx::op::less{}, l0, bl1);
p.add_return({r});
auto l1 = mm->add_parameter("y", s1);
auto bl1 = mm->add_instruction(migraphx::op::multibroadcast{s0.lens()}, l1);
auto r = mm->add_instruction(migraphx::op::less{}, l0, bl1);
mm->add_return({r});
return p;
};
......
......@@ -9,12 +9,13 @@ struct test_literals : verify_program<test_literals>
migraphx::program create_program() const
{
migraphx::program p;
auto input = p.add_literal(
auto* mm = p.get_main_module();
auto input = mm->add_literal(
generate_literal(migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}));
auto weights = p.add_literal(
auto weights = mm->add_literal(
generate_literal(migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}));
auto conv = p.add_instruction(migraphx::op::convolution{}, input, weights);
p.add_instruction(migraphx::op::relu{}, conv);
auto conv = mm->add_instruction(migraphx::op::convolution{}, input, weights);
mm->add_instruction(migraphx::op::relu{}, conv);
return p;
}
};
......@@ -9,9 +9,10 @@ struct test_log : verify_program<test_log>
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {6}};
auto x = p.add_instruction(migraphx::op::abs{}, p.add_parameter("x", s));
p.add_instruction(migraphx::op::log{}, x);
auto x = mm->add_instruction(migraphx::op::abs{}, mm->add_parameter("x", s));
mm->add_instruction(migraphx::op::log{}, x);
return p;
}
};
......@@ -10,9 +10,10 @@ struct test_logsoftmax : verify_program<test_logsoftmax<Axis, T>>
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{T, {10, 4, 2080, 6}};
auto param = p.add_parameter("0", s);
p.add_instruction(migraphx::op::logsoftmax{Axis}, param);
auto param = mm->add_parameter("0", s);
mm->add_instruction(migraphx::op::logsoftmax{Axis}, param);
return p;
}
......
......@@ -16,21 +16,22 @@ struct test_lstm_bidirct_3args : verify_program<test_lstm_bidirct_3args>
float clip = 0.0f;
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
p.add_instruction(migraphx::op::lstm{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r);
auto seq = mm->add_parameter("seq", in_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
mm->add_instruction(migraphx::op::lstm{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r);
return p;
}
......
......@@ -16,16 +16,17 @@ struct test_lstm_bidirct_3args_und : verify_program<test_lstm_bidirct_3args_und>
float clip = 0.0f;
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(
auto seq = mm->add_parameter("seq", in_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto und = mm->add_instruction(migraphx::op::undefined{});
mm->add_instruction(
migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
......
......@@ -16,15 +16,16 @@ struct test_lstm_bidirct_default_actv : verify_program<test_lstm_bidirct_default
float clip = 0.0f;
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
p.add_instruction(
auto seq = mm->add_parameter("seq", in_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
mm->add_instruction(
migraphx::op::lstm{hidden_size, {}, migraphx::op::rnn_direction::bidirectional, clip},
seq,
w,
......
......@@ -16,6 +16,7 @@ struct test_lstm_bidirct_default_actv1 : verify_program<test_lstm_bidirct_defaul
float clip = 0.0f;
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, input_size}};
......@@ -25,24 +26,24 @@ struct test_lstm_bidirct_default_actv1 : verify_program<test_lstm_bidirct_defaul
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape sl_shape{migraphx::shape::int32_type, {batch_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape);
auto seq = mm->add_parameter("seq", in_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", b_shape);
auto ih = mm->add_parameter("ih", ih_shape);
std::vector<int> sl_data(batch_size, 2);
auto sql = p.add_literal(migraphx::literal{sl_shape, sl_data});
auto sql = mm->add_literal(migraphx::literal{sl_shape, sl_data});
p.add_instruction(migraphx::op::lstm{hidden_size,
{migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
sql,
ih);
mm->add_instruction(migraphx::op::lstm{hidden_size,
{migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
sql,
ih);
return p;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment