"MacGPUEnv.md" did not exist on "cb5530032831726bd7358342e7d8913536be8033"
Unverified Commit 2466dd6f authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Refactor program to module (#684)



* code backup

* clang format

* change corresponding tool files

* clang format
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent de10423f
...@@ -9,14 +9,15 @@ struct test_gemm_copy : verify_program<test_gemm_copy> ...@@ -9,14 +9,15 @@ struct test_gemm_copy : verify_program<test_gemm_copy>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sa{migraphx::shape::float_type, {2, 16}}; migraphx::shape sa{migraphx::shape::float_type, {2, 16}};
migraphx::shape sb{migraphx::shape::float_type, {16, 8}}; migraphx::shape sb{migraphx::shape::float_type, {16, 8}};
migraphx::shape sc{migraphx::shape::float_type, {2, 8}}; migraphx::shape sc{migraphx::shape::float_type, {2, 8}};
auto pa = p.add_parameter("a", sa); auto pa = mm->add_parameter("a", sa);
auto pb = p.add_parameter("b", sb); auto pb = mm->add_parameter("b", sb);
auto pc = p.add_parameter("c", sc); auto pc = mm->add_parameter("c", sc);
auto dr = p.add_instruction(migraphx::op::dot{}, pa, pb, pc); auto dr = mm->add_instruction(migraphx::op::dot{}, pa, pb, pc);
p.add_instruction(migraphx::op::add{}, dr, dr); mm->add_instruction(migraphx::op::add{}, dr, dr);
return p; return p;
} }
......
...@@ -9,9 +9,10 @@ struct test_gemm_ex : verify_program<test_gemm_ex> ...@@ -9,9 +9,10 @@ struct test_gemm_ex : verify_program<test_gemm_ex>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto a = p.add_parameter("a", migraphx::shape{migraphx::shape::float_type, {1, 1, 4, 5}}); auto* mm = p.get_main_module();
auto b = p.add_parameter("b", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 3}}); auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {1, 1, 4, 5}});
p.add_instruction(migraphx::op::dot{}, a, b); auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 3}});
mm->add_instruction(migraphx::op::dot{}, a, b);
return p; return p;
} }
}; };
...@@ -9,9 +9,10 @@ struct test_gemm_half : verify_program<test_gemm_half> ...@@ -9,9 +9,10 @@ struct test_gemm_half : verify_program<test_gemm_half>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto a = p.add_parameter("a", migraphx::shape{migraphx::shape::half_type, {4, 5}}); auto* mm = p.get_main_module();
auto b = p.add_parameter("b", migraphx::shape{migraphx::shape::half_type, {5, 3}}); auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::half_type, {4, 5}});
p.add_instruction(migraphx::op::dot{}, a, b); auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::half_type, {5, 3}});
mm->add_instruction(migraphx::op::dot{}, a, b);
return p; return p;
} }
}; };
...@@ -9,11 +9,12 @@ struct test_gemm_ld //: verify_program<test_gemm_ld> ...@@ -9,11 +9,12 @@ struct test_gemm_ld //: verify_program<test_gemm_ld>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
auto a = auto a =
p.add_parameter("a", migraphx::shape{migraphx::shape::float_type, {4, 5}, {10, 1}}); mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {4, 5}, {10, 1}});
auto b = auto b =
p.add_parameter("b", migraphx::shape{migraphx::shape::float_type, {5, 3}, {20, 1}}); mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {5, 3}, {20, 1}});
p.add_instruction(migraphx::op::dot{}, a, b); mm->add_instruction(migraphx::op::dot{}, a, b);
return p; return p;
} }
}; };
...@@ -9,10 +9,11 @@ struct test_gemm_transposea : verify_program<test_gemm_transposea> ...@@ -9,10 +9,11 @@ struct test_gemm_transposea : verify_program<test_gemm_transposea>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto a = p.add_parameter("a", migraphx::shape{migraphx::shape::float_type, {5, 4}}); auto* mm = p.get_main_module();
auto b = p.add_parameter("b", migraphx::shape{migraphx::shape::float_type, {5, 3}}); auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {5, 4}});
auto at = p.add_instruction(migraphx::op::transpose{{1, 0}}, a); auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {5, 3}});
p.add_instruction(migraphx::op::dot{}, at, b); auto at = mm->add_instruction(migraphx::op::transpose{{1, 0}}, a);
mm->add_instruction(migraphx::op::dot{}, at, b);
return p; return p;
} }
}; };
...@@ -9,10 +9,11 @@ struct test_gemm_transposea_ex : verify_program<test_gemm_transposea_ex> ...@@ -9,10 +9,11 @@ struct test_gemm_transposea_ex : verify_program<test_gemm_transposea_ex>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto a = p.add_parameter("a", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 4}}); auto* mm = p.get_main_module();
auto b = p.add_parameter("b", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 3}}); auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 4}});
auto at = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, a); auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 3}});
p.add_instruction(migraphx::op::dot{}, at, b); auto at = mm->add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, a);
mm->add_instruction(migraphx::op::dot{}, at, b);
return p; return p;
} }
}; };
...@@ -9,11 +9,12 @@ struct test_gemm_transposeab : verify_program<test_gemm_transposeab> ...@@ -9,11 +9,12 @@ struct test_gemm_transposeab : verify_program<test_gemm_transposeab>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto a = p.add_parameter("a", migraphx::shape{migraphx::shape::float_type, {5, 4}}); auto* mm = p.get_main_module();
auto b = p.add_parameter("b", migraphx::shape{migraphx::shape::float_type, {3, 5}}); auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {5, 4}});
auto at = p.add_instruction(migraphx::op::transpose{{1, 0}}, a); auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {3, 5}});
auto bt = p.add_instruction(migraphx::op::transpose{{1, 0}}, b); auto at = mm->add_instruction(migraphx::op::transpose{{1, 0}}, a);
p.add_instruction(migraphx::op::dot{}, at, bt); auto bt = mm->add_instruction(migraphx::op::transpose{{1, 0}}, b);
mm->add_instruction(migraphx::op::dot{}, at, bt);
return p; return p;
} }
}; };
...@@ -9,10 +9,11 @@ struct test_gemm_transposeb : verify_program<test_gemm_transposeb> ...@@ -9,10 +9,11 @@ struct test_gemm_transposeb : verify_program<test_gemm_transposeb>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto a = p.add_parameter("a", migraphx::shape{migraphx::shape::float_type, {4, 5}}); auto* mm = p.get_main_module();
auto b = p.add_parameter("b", migraphx::shape{migraphx::shape::float_type, {3, 5}}); auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {4, 5}});
auto bt = p.add_instruction(migraphx::op::transpose{{1, 0}}, b); auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {3, 5}});
p.add_instruction(migraphx::op::dot{}, a, bt); auto bt = mm->add_instruction(migraphx::op::transpose{{1, 0}}, b);
mm->add_instruction(migraphx::op::dot{}, a, bt);
return p; return p;
} }
}; };
...@@ -9,10 +9,11 @@ struct test_gemm_transposeb_ex : verify_program<test_gemm_transposeb_ex> ...@@ -9,10 +9,11 @@ struct test_gemm_transposeb_ex : verify_program<test_gemm_transposeb_ex>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto a = p.add_parameter("a", migraphx::shape{migraphx::shape::float_type, {1, 4, 5}}); auto* mm = p.get_main_module();
auto b = p.add_parameter("b", migraphx::shape{migraphx::shape::float_type, {1, 3, 5}}); auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {1, 4, 5}});
auto bt = p.add_instruction(migraphx::op::transpose{{0, 2, 1}}, b); auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {1, 3, 5}});
p.add_instruction(migraphx::op::dot{}, a, bt); auto bt = mm->add_instruction(migraphx::op::transpose{{0, 2, 1}}, b);
mm->add_instruction(migraphx::op::dot{}, a, bt);
return p; return p;
} }
}; };
...@@ -10,12 +10,13 @@ struct test_global_avg_pooling : verify_program<test_global_avg_pooling> ...@@ -10,12 +10,13 @@ struct test_global_avg_pooling : verify_program<test_global_avg_pooling>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
auto input = auto input =
p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
auto op = migraphx::op::pooling{"average"}; auto op = migraphx::op::pooling{"average"};
auto lens = input->get_shape().lens(); auto lens = input->get_shape().lens();
op.lengths = {lens[2], lens[3]}; op.lengths = {lens[2], lens[3]};
p.add_instruction(op, input); mm->add_instruction(op, input);
return p; return p;
} }
}; };
...@@ -10,12 +10,13 @@ struct test_global_max_pooling : verify_program<test_global_max_pooling> ...@@ -10,12 +10,13 @@ struct test_global_max_pooling : verify_program<test_global_max_pooling>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
auto input = auto input =
p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
auto op = migraphx::op::pooling{"max"}; auto op = migraphx::op::pooling{"max"};
auto lens = input->get_shape().lens(); auto lens = input->get_shape().lens();
op.lengths = {lens[2], lens[3]}; op.lengths = {lens[2], lens[3]};
p.add_instruction(op, input); mm->add_instruction(op, input);
return p; return p;
} }
}; };
...@@ -9,12 +9,13 @@ struct test_greater : verify_program<test_greater> ...@@ -9,12 +9,13 @@ struct test_greater : verify_program<test_greater>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::double_type, {2, 3, 4, 6}}; migraphx::shape s{migraphx::shape::double_type, {2, 3, 4, 6}};
auto input1 = p.add_parameter("x", s); auto input1 = mm->add_parameter("x", s);
auto input2 = p.add_parameter("y", s); auto input2 = mm->add_parameter("y", s);
auto r = p.add_instruction(migraphx::op::greater{}, input1, input2); auto r = mm->add_instruction(migraphx::op::greater{}, input1, input2);
p.add_return({r}); mm->add_return({r});
return p; return p;
}; };
}; };
...@@ -9,13 +9,14 @@ struct test_greater_brcst : verify_program<test_greater_brcst> ...@@ -9,13 +9,14 @@ struct test_greater_brcst : verify_program<test_greater_brcst>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s0{migraphx::shape::float_type, {3, 3}}; migraphx::shape s0{migraphx::shape::float_type, {3, 3}};
auto l0 = p.add_parameter("x", s0); auto l0 = mm->add_parameter("x", s0);
migraphx::shape s1{migraphx::shape::float_type, {3, 1}}; migraphx::shape s1{migraphx::shape::float_type, {3, 1}};
auto l1 = p.add_parameter("y", s1); auto l1 = mm->add_parameter("y", s1);
auto bl1 = p.add_instruction(migraphx::op::multibroadcast{s0.lens()}, l1); auto bl1 = mm->add_instruction(migraphx::op::multibroadcast{s0.lens()}, l1);
auto r = p.add_instruction(migraphx::op::greater{}, l0, bl1); auto r = mm->add_instruction(migraphx::op::greater{}, l0, bl1);
p.add_return({r}); mm->add_return({r});
return p; return p;
}; };
......
...@@ -9,13 +9,14 @@ struct test_group_conv : verify_program<test_group_conv> ...@@ -9,13 +9,14 @@ struct test_group_conv : verify_program<test_group_conv>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
auto input = auto input =
p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 4, 16, 16}}); mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 4, 16, 16}});
auto weights = auto weights =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 1, 3, 3}}); mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 1, 3, 3}});
migraphx::op::convolution op; migraphx::op::convolution op;
op.group = 4; op.group = 4;
p.add_instruction(op, input, weights); mm->add_instruction(op, input, weights);
return p; return p;
} }
}; };
...@@ -16,6 +16,7 @@ struct test_gru_bidirct : verify_program<test_gru_bidirct> ...@@ -16,6 +16,7 @@ struct test_gru_bidirct : verify_program<test_gru_bidirct>
float clip = 0.0f; float clip = 0.0f;
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}}; {num_dirct, 3 * hidden_size, input_size}};
...@@ -24,26 +25,26 @@ struct test_gru_bidirct : verify_program<test_gru_bidirct> ...@@ -24,26 +25,26 @@ struct test_gru_bidirct : verify_program<test_gru_bidirct>
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape); auto seq = mm->add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape); auto w = mm->add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape); auto r = mm->add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape); auto bias = mm->add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape); auto ih = mm->add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = mm->add_instruction(migraphx::op::undefined{});
auto hs = auto hs =
p.add_instruction(migraphx::op::gru{hidden_size, mm->add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip}, clip},
seq, seq,
w, w,
r, r,
bias, bias,
und, und,
ih); ih);
auto lho = p.add_instruction(migraphx::op::rnn_last_hs_output{}, hs); auto lho = mm->add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
p.add_return({hs, lho}); mm->add_return({hs, lho});
return p; return p;
} }
......
...@@ -16,21 +16,22 @@ struct test_gru_bidirct_3args : verify_program<test_gru_bidirct_3args> ...@@ -16,21 +16,22 @@ struct test_gru_bidirct_3args : verify_program<test_gru_bidirct_3args>
float clip = 0.0f; float clip = 0.0f;
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}}; {num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}}; {num_dirct, 3 * hidden_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape); auto seq = mm->add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape); auto w = mm->add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape); auto r = mm->add_parameter("r", r_shape);
p.add_instruction(migraphx::op::gru{hidden_size, mm->add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip}, clip},
seq, seq,
w, w,
r); r);
return p; return p;
} }
......
...@@ -16,25 +16,26 @@ struct test_gru_bidirct_3args_und : verify_program<test_gru_bidirct_3args_und> ...@@ -16,25 +16,26 @@ struct test_gru_bidirct_3args_und : verify_program<test_gru_bidirct_3args_und>
float clip = 0.0f; float clip = 0.0f;
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}}; {num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}}; {num_dirct, 3 * hidden_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape); auto seq = mm->add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape); auto w = mm->add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape); auto r = mm->add_parameter("r", r_shape);
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = mm->add_instruction(migraphx::op::undefined{});
p.add_instruction(migraphx::op::gru{hidden_size, mm->add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip}, clip},
seq, seq,
w, w,
r, r,
und, und,
und, und,
und); und);
return p; return p;
} }
......
...@@ -16,15 +16,16 @@ struct test_gru_bidirct_default_actv : verify_program<test_gru_bidirct_default_a ...@@ -16,15 +16,16 @@ struct test_gru_bidirct_default_actv : verify_program<test_gru_bidirct_default_a
float clip = 0.0f; float clip = 0.0f;
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}}; {num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}}; {num_dirct, 3 * hidden_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape); auto seq = mm->add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape); auto w = mm->add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape); auto r = mm->add_parameter("r", r_shape);
p.add_instruction( mm->add_instruction(
migraphx::op::gru{hidden_size, {}, migraphx::op::rnn_direction::bidirectional, clip}, migraphx::op::gru{hidden_size, {}, migraphx::op::rnn_direction::bidirectional, clip},
seq, seq,
w, w,
......
...@@ -16,6 +16,7 @@ struct test_gru_bidirct_default_actv1 : verify_program<test_gru_bidirct_default_ ...@@ -16,6 +16,7 @@ struct test_gru_bidirct_default_actv1 : verify_program<test_gru_bidirct_default_
float clip = 0.0f; float clip = 0.0f;
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}}; {num_dirct, 3 * hidden_size, input_size}};
...@@ -24,23 +25,23 @@ struct test_gru_bidirct_default_actv1 : verify_program<test_gru_bidirct_default_ ...@@ -24,23 +25,23 @@ struct test_gru_bidirct_default_actv1 : verify_program<test_gru_bidirct_default_
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape); auto seq = mm->add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape); auto w = mm->add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape); auto r = mm->add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape); auto bias = mm->add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape); auto ih = mm->add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = mm->add_instruction(migraphx::op::undefined{});
p.add_instruction(migraphx::op::gru{hidden_size, mm->add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}}, {migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip}, clip},
seq, seq,
w, w,
r, r,
bias, bias,
und, und,
ih); ih);
return p; return p;
} }
......
...@@ -16,21 +16,22 @@ struct test_gru_bidirct_seq1 : verify_program<test_gru_bidirct_seq1> ...@@ -16,21 +16,22 @@ struct test_gru_bidirct_seq1 : verify_program<test_gru_bidirct_seq1>
float clip = 0.0f; float clip = 0.0f;
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}}; {num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}}; {num_dirct, 3 * hidden_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape); auto seq = mm->add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape); auto w = mm->add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape); auto r = mm->add_parameter("r", r_shape);
p.add_instruction(migraphx::op::gru{hidden_size, mm->add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip}, clip},
seq, seq,
w, w,
r); r);
return p; return p;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment