Commit e1876299 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

add cpu test for the rnn operator.

parent e551275d
......@@ -798,7 +798,7 @@ struct onnx_parser
// For RNN, LSTM, and GRU operators, one of the input arguments
// is prim::Undefined, and it is ignored by protobuf. We use a
// hack to ignore this argument for these three operators
std::string op_type = node.op_type();
const std::string op_type = node.op_type();
if((op_type == "RNN" || op_type == "LSTM" || op_type == "GRU") &&
input.empty() == true)
{
......
......@@ -1347,101 +1347,291 @@ TEST_CASE(min_test)
EXPECT(migraphx::verify_range(results_vector, gold));
}
/*
TEST_CASE(rnn_test)
TEST_CASE(rnn_forward)
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
std::vector<float> wf_data{0.4691, 0.3185, -0.2227 ,0.4423, -0.0609, -0.2803, 0.1744, 0.3146, 0.4049, -0.3973, -0.0890, -0.1636};
std::vector<float> rf_data{-0.0456, 0.1061, 0.1574, -0.4928, -0.4300, -0.1909, -0.0225, -0.2668, 0.1840, -0.4453, -0.4896, 0.1302, -0.0929, 0.3545, -0.4981, 0.0616};
std::vector<float> biasf_data{-0.4938, 0.4355, -0.3186, 0.2094, 0.1037, -0.1071, 0.4504, -0.3990};
std::vector<float> input(seq_len * batch_size * input_size, 0);
input[0] = input[1] = 1.0;
float clip = 0.0f;
{
std::vector<float> ih_data(num_dirct * batch_size * hidden_size, 0);
migraphx::program p;
size_t hidden_size = 8;
size_t input_size = 6;
size_t batch_size = 2;
size_t seq_len = 5;
migraphx::shape hidden_shape{migraphx::shape::float_type, {1, batch_size, hidden_size}};
migraphx::shape input_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
auto seq = p.add_literal(migraphx::literal{in_shape, input});
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
auto w = p.add_literal(migraphx::literal{w_shape, wf_data});
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
auto r = p.add_literal(migraphx::literal{r_shape, rf_data});
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
auto bias = p.add_literal(migraphx::literal{b_shape, biasf_data});
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::forward,
clip},
seq,
w,
r,
bias,
ih);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> input(input_shape.elements(), 0.0);
input[0] = input[1] = 1.0;
std::vector<float> init_hidden(hidden_shape.elements(), 0.0);
std::vector<float> hs_data_gold{0.37780784, 0.61055139, 0.55168478, -0.5888475, -0.37144644, 0.31708236, 0.13104209, -0.18736027,
0.03445704, 0.19167931, -0.3946827, -0.30889652, -0.22276389, 0.44193283, -0.16477929, -0.11893477};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
{
std::vector<float> ih_data(num_dirct * batch_size * hidden_size, 0);
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
auto seq = p.add_literal(migraphx::literal{in_shape, input});
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
auto w = p.add_literal(migraphx::literal{w_shape, wf_data});
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
auto r = p.add_literal(migraphx::literal{r_shape, rf_data});
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
auto bias = p.add_literal(migraphx::literal{b_shape, biasf_data});
auto out_hs = p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::forward,
clip},
seq,
w,
r,
bias,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.compile(migraphx::cpu::target{});
migraphx::program::parameter_map m;
m["input"] = migraphx::argument(input_shape, input.data());
auto resarg = p.eval(m);
std::vector<float> res;
resarg.visit([&](auto output) { res.assign(output.begin(), output.end()); } );
std::vector<float> res_gold{
0.596363, -0.274248, 0.714484, 0.282515, 0.0938349,
0.185406, 0.283227, -0.482086, 0.265265, -0.523217,
0.50433, 0.400934, -0.34513, 0.114924, 0.0392658,
-0.0976029, 0.364322, -0.567117, 0.538775, 0.314859,
-0.478676, 0.51778, -0.286718, -0.0478341, 0.339601,
-0.380976, 0.628219, 0.222791, -0.271949, 0.490674,
-0.234456, -0.224984, 0.456527, -0.454559, 0.546034,
-0.0389027, -0.307475, 0.561003, -0.245673, -0.0776644,
0.447162, -0.52013, 0.511913, 0.0324621, -0.380515,
0.500777, -0.225695, -0.0193589, 0.458955, -0.531746,
0.448536, -0.087655, -0.430165, 0.551379, -0.161603,
-0.0165391, 0.447551, -0.491717, 0.484796, -0.0699652,
-0.3941, 0.561967, -0.168543, -0.0661258, 0.465925,
-0.499277, 0.45216, -0.103005, -0.392837, 0.584424,
-0.189044, -0.0388068, 0.468369, -0.512927, 0.449144,
-0.0900977, -0.400401, 0.573534, -0.19617, -0.0208253};
EXPECT(migraphx::verify_range(res, res_gold));
auto last_output = p.eval({});
std::vector<float> last_output_data;
last_output.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); });
std::vector<float> last_output_data_gold{0.03445704, 0.19167931, -0.3946827, -0.30889652, -0.22276389, 0.44193283, -0.16477929, -0.11893477};
EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold));
}
}
TEST_CASE(rnn_reverse)
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
std::vector<float> wr_data{-0.0296, -0.1341, 0.1761, -0.2325, -0.0717, 0.1852, 0.2720, 0.1471, -0.1097, 0.3363, -0.0587, -0.2302};
std::vector<float> rr_data{ 0.2528, -0.2333, 0.3973, 0.1593, -0.0388, 0.1702, 0.3829, -0.0712, -0.1668, 0.3074, -0.2854, 0.4049, -0.3737, -0.1051, 0.4482, -0.2841};
std::vector<float> biasr_data{-0.3188, 0.1341, -0.4446, 0.1389, 0.3117, 0.3664, 0.2352, 0.2552};
std::vector<float> input(seq_len * batch_size * input_size, 0);
input[0] = input[1] = 1.0;
float clip = 0.0f;
{
std::vector<float> ih_data(num_dirct * batch_size * hidden_size, 0);
migraphx::program p;
size_t hidden_size = 6;
size_t input_size = 4;
size_t batch_size = 2;
size_t seq_len = 5;
migraphx::shape hidden_shape{migraphx::shape::float_type, {6, batch_size, hidden_size}};
migraphx::shape input_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
auto seq = p.add_literal(migraphx::literal{in_shape, input});
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
auto w = p.add_literal(migraphx::literal{w_shape, wr_data});
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
auto r = p.add_literal(migraphx::literal{r_shape, rr_data});
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
auto bias = p.add_literal(migraphx::literal{b_shape, biasr_data});
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::reverse,
clip},
seq,
w,
r,
bias,
ih);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> input(input_shape.elements(), 0.0);
input[0] = input[1] = 1.0;
std::vector<float> init_hidden(hidden_shape.elements(), 0.0);
std::vector<float> hs_data_gold{-0.29385301, 0.16796815, 0.51075965, 0.40258689, -0.13818839, 0.44124447, 0.14365635, 0.14803654,
-0.0070999, 0.46251031, -0.20639211, 0.37488942, -0.0070999, 0.46251031, -0.20639211, 0.37488942};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
{
std::vector<float> ih_data(num_dirct * batch_size * hidden_size, 0);
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
auto seq = p.add_literal(migraphx::literal{in_shape, input});
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
auto w = p.add_literal(migraphx::literal{w_shape, wr_data});
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
auto r = p.add_literal(migraphx::literal{r_shape, rr_data});
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
auto bias = p.add_literal(migraphx::literal{b_shape, biasr_data});
auto out_hs = p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::reverse,
clip},
seq,
w,
r,
bias,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.compile(migraphx::cpu::target{});
migraphx::program::parameter_map m;
m["input"] = migraphx::argument(input_shape, &input[0]);
auto resarg = p.eval(m);
std::vector<float> res;
resarg.visit([&](auto output) { res.assign(output.begin(), output.end()); } );
std::vector<float> res_gold{
-0.0890872, -0.0558751, 0.185233, 0.452857, 0.104082,
0.432953, 0.274236, 0.186055, -0.367716, 0.266761,
-0.28489, 0.498758, 0.0140574, -0.122377, 0.278067,
0.469699, 0.216743, 0.258926, 0.269785, 0.328379,
-0.576081, 0.11672, -0.452062, 0.603549, 0.472625,
0.120929, 0.350331, 0.502138, 0.103585, 0.128486,
0.0210318, 0.338759, -0.654448, 0.37656, -0.359715,
0.424365, 0.449677, 0.130903, 0.354359, 0.59317,
0.189543, 0.201865, 0.126288, 0.31099, -0.681538,
0.275407, -0.406133, 0.450767, 0.305638, 0.14942,
0.309857, 0.722745, 0.361199, -0.00963601, 0.397046,
0.264047, -0.539317, 0.0690505, -0.321901, 0.566638,
0.406511, 0.231472, 0.320225, 0.737927, 0.372938,
0.00762333, 0.349881, 0.280791, -0.541838, 0.128319,
-0.266702, 0.536205, 0.509004, 0.361068, 0.42431,
0.767474, 0.368881, 0.0753035, 0.141155, 0.219692,
-0.643801, 0.281643, -0.330984, 0.397033, 0.494424,
0.38013, 0.434627, 0.795404, 0.391589, 0.0102068,
0.166358, 0.226248, -0.608175, 0.302622, -0.349646,
0.375506, 0.546918, 0.22908, 0.40025, 0.806049,
0.424462, -0.0352604, 0.528827, -0.0372434, -0.573789,
-0.0541837, -0.194983, 0.552972, 0.553695, 0.263657,
0.432448, 0.815763, 0.412716, -0.0389366, 0.52391,
-0.0256845, -0.577296, -0.0570545, -0.219738, 0.561644};
EXPECT(migraphx::verify_range(res, res_gold));
auto last_output = p.eval({});
std::vector<float> last_output_data;
last_output.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); });
std::vector<float> last_output_data_gold{-0.29385301, 0.16796815, 0.51075965, 0.40258689, -0.13818839, 0.44124447, 0.14365635, 0.14803654};
EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold));
}
}
TEST_CASE(rnn_bidirectional)
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 2;
std::vector<float> wf_data{0.4691, 0.3185, -0.2227 ,0.4423, -0.0609, -0.2803, 0.1744, 0.3146, 0.4049, -0.3973, -0.0890, -0.1636};
std::vector<float> wr_data{-0.0296, -0.1341, 0.1761, -0.2325, -0.0717, 0.1852, 0.2720, 0.1471, -0.1097, 0.3363, -0.0587, -0.2302};
std::vector<float> rf_data{-0.0456, 0.1061, 0.1574, -0.4928, -0.4300, -0.1909, -0.0225, -0.2668, 0.1840, -0.4453, -0.4896, 0.1302, -0.0929, 0.3545, -0.4981, 0.0616};
std::vector<float> rr_data{ 0.2528, -0.2333, 0.3973, 0.1593, -0.0388, 0.1702, 0.3829, -0.0712, -0.1668, 0.3074, -0.2854, 0.4049, -0.3737, -0.1051, 0.4482, -0.2841};
std::vector<float> biasf_data{-0.4938, 0.4355, -0.3186, 0.2094, 0.1037, -0.1071, 0.4504, -0.3990};
std::vector<float> biasr_data{-0.3188, 0.1341, -0.4446, 0.1389, 0.3117, 0.3664, 0.2352, 0.2552};
std::vector<float> input(seq_len * batch_size * input_size, 0);
input[0] = input[1] = 1.0;
float clip = 0.0f;
{
std::vector<float> ih_data(num_dirct * batch_size * hidden_size, 0);
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
auto seq = p.add_literal(migraphx::literal{in_shape, input});
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto w_data = wf_data;
w_data.insert(w_data.end(), wr_data.begin(), wr_data.end());
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r_data = rf_data;
r_data.insert(r_data.end(), rr_data.begin(), rr_data.end());
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias_data = biasf_data;
bias_data.insert(bias_data.end(), biasr_data.begin(), biasr_data.end());
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::bidirectional,
clip},
seq,
w,
r,
bias,
ih);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{0.37780784, 0.61055139, 0.55168478, -0.5888475, -0.37144644, 0.31708236, 0.13104209, -0.18736027,
-0.29385301, 0.16796815, 0.51075965, 0.40258689, -0.13818839, 0.44124447, 0.14365635, 0.14803654,
0.03445704, 0.19167931, -0.3946827, -0.30889652, -0.22276389, 0.44193283, -0.16477929, -0.11893477,
-0.0070999, 0.46251031, -0.20639211, 0.37488942, -0.0070999, 0.46251031, -0.20639211, 0.37488942};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
{
std::vector<float> ih_data(num_dirct * batch_size * hidden_size, 0);
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
auto seq = p.add_literal(migraphx::literal{in_shape, input});
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto w_data = wf_data;
w_data.insert(w_data.end(), wr_data.begin(), wr_data.end());
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r_data = rf_data;
r_data.insert(r_data.end(), rr_data.begin(), rr_data.end());
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias_data = biasf_data;
bias_data.insert(bias_data.end(), biasr_data.begin(), biasr_data.end());
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto out_hs = p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::bidirectional,
clip},
seq,
w,
r,
bias,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.compile(migraphx::cpu::target{});
auto last_output = p.eval({});
std::vector<float> last_output_data;
last_output.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); });
std::vector<float> last_output_data_gold{0.03445704, 0.19167931, -0.3946827, -0.30889652, -0.22276389, 0.44193283, -0.16477929, -0.11893477,
-0.29385301, 0.16796815, 0.51075965, 0.40258689, -0.13818839, 0.44124447, 0.14365635, 0.14803654};
EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold));
}
}
*/
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment