Commit 82fe652a authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 9d65c568
...@@ -2726,8 +2726,8 @@ TEST_CASE(gru_bidirectional) ...@@ -2726,8 +2726,8 @@ TEST_CASE(gru_bidirectional)
{ {
migraphx::program p; migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input}); auto seq = p.add_literal(migraphx::literal{in_shape, input});
auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data});
p.add_instruction(migraphx::op::gru{hidden_size, p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::bidirectional, migraphx::op::gru::bidirectional,
...@@ -2743,15 +2743,15 @@ TEST_CASE(gru_bidirectional) ...@@ -2743,15 +2743,15 @@ TEST_CASE(gru_bidirectional)
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{ std::vector<float> hs_data_gold{
0.0863793, -0.227845, 0.0283059, -0.258645, 0.14187, 0.43541, 0.190748, 0.0863793, -0.227845, 0.0283059, -0.258645, 0.14187, 0.43541, 0.190748,
-0.530196, -0.440444, 0.293767, 0.0402142, 0.0788687, -0.013, -0.233298, -0.530196, -0.440444, 0.293767, 0.0402142, 0.0788687, -0.013, -0.233298,
-0.0739615, 0.467104, 0.446285, 0.306097, 0.125636, 0.272524, 0.0949838, -0.0739615, 0.467104, 0.446285, 0.306097, 0.125636, 0.272524, 0.0949838,
0.0522264, -0.0872712, -0.084203, 0.140013, 0.12739, -0.0111171, -0.431119, 0.0522264, -0.0872712, -0.084203, 0.140013, 0.12739, -0.0111171, -0.431119,
-0.468382, 0.388067, -0.109174, -0.119064, -0.0242958, -0.180555, 0.118983, -0.468382, 0.388067, -0.109174, -0.119064, -0.0242958, -0.180555, 0.118983,
0.341578, 0.275472, 0.0853083, 0.332205, -0.0498387, 0.140338, 0.0319435, 0.341578, 0.275472, 0.0853083, 0.332205, -0.0498387, 0.140338, 0.0319435,
0.247019, 0.275848, -0.158223, 0.0495464, -0.0681034, -0.418158, -0.523234, 0.247019, 0.275848, -0.158223, 0.0495464, -0.0681034, -0.418158, -0.523234,
0.469122, -0.306578, -0.221095, -0.106449, -0.248934, -0.00682121, 0.288407, 0.469122, -0.306578, -0.221095, -0.106449, -0.248934, -0.00682121, 0.288407,
0.198708, 0.0695644, 0.211621, 0.00246037}; 0.198708, 0.0695644, 0.211621, 0.00246037};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
} }
...@@ -2759,9 +2759,9 @@ TEST_CASE(gru_bidirectional) ...@@ -2759,9 +2759,9 @@ TEST_CASE(gru_bidirectional)
// 4 args (bias is used) // 4 args (bias is used)
{ {
migraphx::program p; migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input}); auto seq = p.add_literal(migraphx::literal{in_shape, input});
auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
p.add_instruction(migraphx::op::gru{hidden_size, p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
...@@ -2779,18 +2779,15 @@ TEST_CASE(gru_bidirectional) ...@@ -2779,18 +2779,15 @@ TEST_CASE(gru_bidirectional)
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{ std::vector<float> hs_data_gold{
-0.156667, -0.248473, 0.0255282, -0.24566, 0.211589, -0.156667, -0.248473, 0.0255282, -0.24566, 0.211589, 0.192707, 0.253025,
0.192707, 0.253025, -0.515283, -0.414174, 0.227127, -0.515283, -0.414174, 0.227127, 0.124773, 0.284532, -0.203929, -0.120517,
0.124773, 0.284532, -0.203929, -0.120517, -0.2794, -0.2794, 0.547635, 0.518549, 0.0447674, 0.258461, 0.0502881, -0.219516,
0.547635, 0.518549, 0.0447674, 0.258461, 0.0502881, 0.0927382, -0.0760062, -0.0906231, 0.237615, -0.215638, 0.0128074, -0.425813,
-0.219516, 0.0927382, -0.0760062, -0.0906231, 0.237615, -0.433378, 0.375383, -0.0381738, 0.117793, -0.180851, -0.0841245, -0.116649,
-0.215638, 0.0128074, -0.425813, -0.433378, 0.375383, 0.419469, 0.393515, -0.076395, 0.427436, -0.264071, -0.185829, 0.0483585,
-0.0381738, 0.117793, -0.180851, -0.0841245, -0.116649, 0.242955, 0.25233, 0.0148512, -0.304127, -0.0616653, -0.411568, -0.491748,
0.419469, 0.393515, -0.076395, 0.427436, -0.264071, 0.476508, -0.313413, -0.0361821, -0.173037, -0.235731, -0.163113, 0.349008,
-0.185829, 0.0483585, 0.242955, 0.25233, 0.0148512, 0.248674, -0.0295413, 0.291437, -0.165005};
-0.304127, -0.0616653, -0.411568, -0.491748, 0.476508,
-0.313413, -0.0361821, -0.173037, -0.235731, -0.163113,
0.349008, 0.248674, -0.0295413, 0.291437, -0.165005};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
} }
...@@ -2799,10 +2796,10 @@ TEST_CASE(gru_bidirectional) ...@@ -2799,10 +2796,10 @@ TEST_CASE(gru_bidirectional)
{ {
migraphx::program p; migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input}); auto seq = p.add_literal(migraphx::literal{in_shape, input});
auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(migraphx::op::gru{hidden_size, p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::bidirectional, migraphx::op::gru::bidirectional,
...@@ -2821,42 +2818,35 @@ TEST_CASE(gru_bidirectional) ...@@ -2821,42 +2818,35 @@ TEST_CASE(gru_bidirectional)
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{ std::vector<float> hs_data_gold{
0.248571, 0.0982155, 0.00808877, 0.0986508, 0.0969705, 0.248571, 0.0982155, 0.00808877, 0.0986508, 0.0969705, 0.434692, -0.141696,
0.434692, -0.141696, -0.164271, -0.121157, 0.863222, -0.164271, -0.121157, 0.863222, -0.0718357, 0.137711, 0.109221, -0.00207995,
-0.0718357, 0.137711, 0.109221, -0.00207995, 0.0331223, 0.0331223, 0.262705, 0.346587, 0.457158, 0.240744, 0.404261, 0.222779,
0.262705, 0.346587, 0.457158, 0.240744, 0.404261, 0.179757, -0.0845316, 0.0690347, 0.10204, 0.100155, -0.190286, -0.122062,
0.222779, 0.179757, -0.0845316, 0.0690347, 0.10204, -0.274379, 0.547281, -0.226753, -0.0397069, 0.120404, 0.171299, 0.259989,
0.100155, -0.190286, -0.122062, -0.274379, 0.547281, 0.0864604, 0.111322, 0.331784, 0.604653, 0.181017, 0.237426, 0.0911999,
-0.226753, -0.0397069, 0.120404, 0.171299, 0.259989, 0.233106, 0.32996, -0.17175, 0.0190231, -0.154805, -0.205631, -0.405354,
0.0864604, 0.111322, 0.331784, 0.604653, 0.181017, 0.519054, -0.380409, -0.0350301, -0.00633752, 0.403791, 0.181883, -0.0977917,
0.237426, 0.0911999, 0.233106, 0.32996, -0.17175, -0.0339407, 0.413089, 0.721238, 0.431879};
0.0190231, -0.154805, -0.205631, -0.405354, 0.519054,
-0.380409, -0.0350301, -0.00633752, 0.403791, 0.181883,
-0.0977917, -0.0339407, 0.413089, 0.721238, 0.431879};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
} }
// no activation function specified, so default is used. // no activation function specified, so default is used.
{ {
migraphx::program p; migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input}); auto seq = p.add_literal(migraphx::literal{in_shape, input});
auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto concat_hs = auto concat_hs = p.add_instruction(
p.add_instruction(migraphx::op::gru{hidden_size, migraphx::op::gru{hidden_size, {}, migraphx::op::gru::bidirectional, clip, 1},
{}, seq,
migraphx::op::gru::bidirectional, w,
clip, r,
1}, bias,
seq, und,
w, ih);
r,
bias,
und,
ih);
p.add_instruction(migraphx::op::gru_last_output{}, concat_hs); p.add_instruction(migraphx::op::gru_last_output{}, concat_hs);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}); auto hs_concat = p.eval({});
...@@ -2874,82 +2864,71 @@ TEST_CASE(gru_bidirectional) ...@@ -2874,82 +2864,71 @@ TEST_CASE(gru_bidirectional)
// 1 activation function (sigmoid) specified // 1 activation function (sigmoid) specified
{ {
migraphx::program p; migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input}); auto seq = p.add_literal(migraphx::literal{in_shape, input});
auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
p.add_instruction(migraphx::op::gru{hidden_size, p.add_instruction(
{migraphx::op::sigmoid{}}, migraphx::op::gru{
migraphx::op::gru::bidirectional, hidden_size, {migraphx::op::sigmoid{}}, migraphx::op::gru::bidirectional, clip, 0},
clip, seq,
0}, w,
seq, r,
w, bias,
r, und,
bias, ih);
und,
ih);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}); auto hs_concat = p.eval({});
std::vector<float> hs_data; std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{ std::vector<float> hs_data_gold{
0.325495, 0.469214, 0.164517, 0.585327, 0.328398, 0.325495, 0.469214, 0.164517, 0.585327, 0.328398, 0.457928, 0.065011, 0.35986,
0.457928, 0.065011, 0.35986, 0.545029, 0.859425, 0.545029, 0.859425, 0.427923, 0.667133, 0.41591, 0.540971, 0.365475, 0.482058,
0.427923, 0.667133, 0.41591, 0.540971, 0.365475, 0.565495, 0.556993, 0.607649, 0.543627, 0.428915, 0.537405, 0.306046, 0.518399,
0.482058, 0.565495, 0.556993, 0.607649, 0.543627, 0.403561, 0.410694, 0.301163, 0.407397, 0.471334, 0.726446, 0.309389, 0.612072,
0.428915, 0.537405, 0.306046, 0.518399, 0.403561, 0.360619, 0.590861, 0.366545, 0.367001, 0.433829, 0.501275, 0.72481, 0.512745,
0.410694, 0.301163, 0.407397, 0.471334, 0.726446, 0.463795, 0.539649, 0.487682, 0.554471, 0.395916, 0.430744, 0.415923, 0.424275,
0.309389, 0.612072, 0.360619, 0.590861, 0.366545, 0.409655, 0.698256, 0.126883, 0.554374, 0.216137, 0.671491, 0.263833, 0.0678646,
0.367001, 0.433829, 0.501275, 0.72481, 0.512745, 0.132732, 0.477083, 0.802206, 0.626802};
0.463795, 0.539649, 0.487682, 0.554471, 0.395916,
0.430744, 0.415923, 0.424275, 0.409655, 0.698256,
0.126883, 0.554374, 0.216137, 0.671491, 0.263833,
0.0678646, 0.132732, 0.477083, 0.802206, 0.626802};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
} }
// 1 activation function (tanh) specified // 1 activation function (tanh) specified
{ {
migraphx::program p; migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input}); auto seq = p.add_literal(migraphx::literal{in_shape, input});
auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
p.add_instruction(migraphx::op::gru{hidden_size, p.add_instruction(
{migraphx::op::tanh{}}, migraphx::op::gru{
migraphx::op::gru::bidirectional, hidden_size, {migraphx::op::tanh{}}, migraphx::op::gru::bidirectional, clip, 1},
clip, seq,
1}, w,
seq, r,
w, bias,
r, und,
bias, ih);
und,
ih);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}); auto hs_concat = p.eval({});
std::vector<float> hs_data; std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{ std::vector<float> hs_data_gold{
0.0919632, -0.398302, -0.0267752, -0.326771, 0.401983, 0.0919632, -0.398302, -0.0267752, -0.326771, 0.401983, 0.949841, 0.557779,
0.949841, 0.557779, -0.745259, -1.52726, 0.946066, -0.745259, -1.52726, 0.946066, 0.330446, 0.301982, -0.443763, -0.0655817,
0.330446, 0.301982, -0.443763, -0.0655817, -0.326473, -0.326473, 0.861394, 0.560799, -0.101768, 0.145142, 0.128956, -0.329758,
0.861394, 0.560799, -0.101768, 0.145142, 0.128956, 0.458253, -0.339208, 0.289109, 0.36728, -1.09574, -0.181394, -0.575781,
-0.329758, 0.458253, -0.339208, 0.289109, 0.36728, -0.823083, 0.804262, -0.0965933, 0.20405, -0.430215, 0.00884668, 0.0716857,
-1.09574, -0.181394, -0.575781, -0.823083, 0.804262, 0.844222, 0.516472, -0.191571, 0.596968, -0.545405, -0.336693, -0.0280516,
-0.0965933, 0.20405, -0.430215, 0.00884668, 0.0716857, 0.339058, 1.00367, 0.12655, -0.0984504, -0.174945, -0.5365, 0.183188,
0.844222, 0.516472, -0.191571, 0.596968, -0.545405, 0.66716, -0.704461, -0.393346, -0.627123, 0.210395, 0.0563026, 0.31419,
-0.336693, -0.0280516, 0.339058, 1.00367, 0.12655, 0.759629, 0.000258222, 0.350835, -0.682684};
-0.0984504, -0.174945, -0.5365, 0.183188, 0.66716,
-0.704461, -0.393346, -0.627123, 0.210395, 0.0563026,
0.31419, 0.759629, 0.000258222, 0.350835, -0.682684};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
} }
...@@ -2957,34 +2936,34 @@ TEST_CASE(gru_bidirectional) ...@@ -2957,34 +2936,34 @@ TEST_CASE(gru_bidirectional)
// 3 activation functions specified // 3 activation functions specified
{ {
migraphx::program p; migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input}); auto seq = p.add_literal(migraphx::literal{in_shape, input});
auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto concat_hs = p.add_instruction(migraphx::op::gru{hidden_size, auto concat_hs = p.add_instruction(
{migraphx::op::tanh{}, migraphx::op::sigmoid{}, migraphx::op::tanh{}}, migraphx::op::gru{hidden_size,
migraphx::op::gru::bidirectional, {migraphx::op::tanh{}, migraphx::op::sigmoid{}, migraphx::op::tanh{}},
clip, migraphx::op::gru::bidirectional,
1}, clip,
seq, 1},
w, seq,
r, w,
bias, r,
und, bias,
ih); und,
ih);
p.add_instruction(migraphx::op::gru_last_output{}, concat_hs); p.add_instruction(migraphx::op::gru_last_output{}, concat_hs);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}); auto hs_concat = p.eval({});
std::vector<float> hs_data; std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{ std::vector<float> hs_data_gold{0.351019, 0.474363, 0.570719, 0.717703, 0.468843,
0.351019, 0.474363, 0.570719, 0.717703, 0.468843, 1.15142, 0.457633, 0.300962, 0.361245, 0.666199,
1.15142, 0.457633, 0.300962, 0.361245, 0.666199, 0.330446, 0.301982, -0.443763, -0.0655817, -0.326473,
0.330446, 0.301982, -0.443763, -0.0655817, -0.326473, 0.861394, 0.560799, -0.101768, 0.145142, 0.128956};
0.861394, 0.560799, -0.101768, 0.145142, 0.128956};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
} }
...@@ -2998,8 +2977,10 @@ TEST_CASE(gru_bidirectional) ...@@ -2998,8 +2977,10 @@ TEST_CASE(gru_bidirectional)
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
p.add_instruction(migraphx::op::gru{hidden_size, p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, {migraphx::op::sigmoid{},
migraphx::op::sigmoid{}, migraphx::op::tanh{}}, migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{}},
migraphx::op::gru::bidirectional, migraphx::op::gru::bidirectional,
clip, clip,
1}, 1},
...@@ -3029,19 +3010,19 @@ TEST_CASE(gru_bidirectional) ...@@ -3029,19 +3010,19 @@ TEST_CASE(gru_bidirectional)
EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
} }
// seq length of 1 // seq length of 1
{ {
migraphx::program p; migraphx::program p;
seq_len = 1; seq_len = 1;
migraphx::shape in_shape_one{migraphx::shape::float_type, {seq_len, batch_size,input_size}}; migraphx::shape in_shape_one{migraphx::shape::float_type,
std::vector<float> input_one{-0.8432, -0.9887, 1.3041, -2.6430, -0.3306, -0.8504}; {seq_len, batch_size, input_size}};
auto seq = p.add_literal(migraphx::literal{in_shape_one, input_one}); std::vector<float> input_one{-0.8432, -0.9887, 1.3041, -2.6430, -0.3306, -0.8504};
auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto seq = p.add_literal(migraphx::literal{in_shape_one, input_one});
auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto und = p.add_instruction(migraphx::op::undefined{}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); auto und = p.add_instruction(migraphx::op::undefined{});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
p.add_instruction(migraphx::op::gru{hidden_size, p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::bidirectional, migraphx::op::gru::bidirectional,
...@@ -3059,11 +3040,10 @@ TEST_CASE(gru_bidirectional) ...@@ -3059,11 +3040,10 @@ TEST_CASE(gru_bidirectional)
std::vector<float> hs_data; std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{ std::vector<float> hs_data_gold{0.0352243, 0.0146756, 0.00570925, 0.152446, 0.208683,
0.0352243, 0.0146756, 0.00570925, 0.152446, 0.208683, 0.214342, -0.0454273, -0.135177, -0.0800739, 0.903659,
0.214342, -0.0454273, -0.135177, -0.0800739, 0.903659, -0.0271321, 0.624762, -0.117084, 0.509115, -0.0175078,
-0.0271321, 0.624762, -0.117084, 0.509115, -0.0175078, -0.144492, -0.0115366, 0.409153, 0.487015, 0.550755};
-0.144492, -0.0115366, 0.409153, 0.487015, 0.550755};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment