#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/op/common.hpp>

#include <cstddef>
#include <string>
#include <vector>

// Verification case for a single-direction (forward) "lstm" instruction whose
// initial hidden state ("ih"), initial cell state ("ic") and peephole weights
// ("pph") are all supplied as program parameters.
struct test_lstm_forward_hs : verify_program<test_lstm_forward_hs>
{
    migraphx::program create_program() const
    {
        // std::size_t (not int) so these can populate shape lens, which are
        // std::vector<std::size_t>, without a narrowing conversion inside the
        // braced initializers below.
        std::size_t batch_size  = 2;
        std::size_t seq_len     = 3;
        std::size_t hidden_size = 5;
        std::size_t input_size  = 8;
        std::size_t num_dirct   = 1; // forward only => one direction
        float clip              = 0.0f; // NOTE(review): presumably "no clipping" — confirm in lstm op

        migraphx::program p;
        auto* mm = p.get_main_module();

        migraphx::shape in_shape{migraphx::shape::float_type,
                                 {seq_len, batch_size, input_size}};
        // Gate-stacked weights: 4 gates x hidden_size rows per direction.
        migraphx::shape w_shape{migraphx::shape::float_type,
                                {num_dirct, 4 * hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type,
                                {num_dirct, 4 * hidden_size, hidden_size}};
        // 8 * hidden_size: presumably input and recurrent biases concatenated
        // (ONNX-style layout) — confirm against the lstm op.
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type,
                                 {num_dirct, batch_size, hidden_size}};
        migraphx::shape ic_shape{migraphx::shape::float_type,
                                 {num_dirct, batch_size, hidden_size}};
        // 3 * hidden_size: presumably one peephole vector per gate that has one.
        migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};

        auto seq  = mm->add_parameter("seq", in_shape);
        auto w    = mm->add_parameter("w", w_shape);
        auto r    = mm->add_parameter("r", r_shape);
        auto bias = mm->add_parameter("bias", b_shape);
        auto ih   = mm->add_parameter("ih", ih_shape);
        auto ic   = mm->add_parameter("ic", ic_shape);
        auto pph  = mm->add_parameter("pph", pph_shape);

        // Placeholder for the optional input between bias and ih (presumably the
        // sequence-lengths slot, as in the ONNX LSTM input order — confirm).
        auto und = mm->add_instruction(migraphx::make_op("undefined"));

        mm->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", hidden_size},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
                                                                      migraphx::make_op("tanh"),
                                                                      migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
                 {"clip", clip}}),
            seq,
            w,
            r,
            bias,
            und,
            ih,
            ic,
            pph);
        return p;
    }

    std::string section() const { return "rnn"; }
};