Commit 60b3056e authored by Shucai Xiao's avatar Shucai Xiao
Browse files

merge rnn operator changes.

parents 250a0243 128b0b65
......@@ -1075,7 +1075,7 @@ struct rnn
};
std::size_t hidden_size = 1;
operation actv_func{tanh{}};
std::vector<operation> actv_funcs{tanh{}};
rnn_direction_t direction = forward;
float clip = 0.0f;
......
......@@ -663,17 +663,6 @@ struct onnx_parser
MIGRAPHX_THROW("RNN: hidden size attribute missing");
}
std::string activation_func = {"tanh"};
if(contains(attributes, "activations"))
{
activation_func = attributes.at("activations").strings(0);
}
if(map_actv_funcs.count(activation_func) == 0)
{
MIGRAPHX_THROW("RNN: activation function " + activation_func + " not supported");
}
// Handling of direction to be added later
std::string direction{"forward"};
if(contains(attributes, "direction"))
......@@ -691,6 +680,37 @@ struct onnx_parser
dirct = op::rnn::reverse;
}
std::vector<std::string> vec_names{"tanh"};
if(contains(attributes, "activations"))
{
auto names = attributes.at("activations").strings();
vec_names.clear();
for_each(names.begin(), names.end(), [&](auto& fn) { vec_names.push_back(fn); });
}
for_each(vec_names.begin(), vec_names.end(), [&](auto& fn) {
if(map_actv_funcs.count(fn) == 0)
{
MIGRAPHX_THROW("RNN: activation function " + fn + " not supported");
}
});
// bidirectional should have two activation functions
// if only one actv function is provides, we use it in both
// forward and reverse direction
if(dirct == op::rnn::bidirectional)
{
if(vec_names.size() == 1)
{
vec_names.push_back(vec_names.at(0));
}
}
std::vector<operation> vec_actv_funcs;
for_each(vec_names.begin(), vec_names.end(), [&](auto& fn) {
vec_actv_funcs.push_back(map_actv_funcs[fn]);
});
// To be added later
float clip = 0.0;
if(contains(attributes, "clip"))
......@@ -698,8 +718,8 @@ struct onnx_parser
clip = parse_value(attributes.at("clip")).at<float>();
}
return prog.add_instruction(
op::rnn{hidden_size, map_actv_funcs[activation_func], dirct, clip}, std::move(args));
return prog.add_instruction(op::rnn{hidden_size, vec_actv_funcs, dirct, clip},
std::move(args));
}
instruction_ref
......
......@@ -106,7 +106,7 @@ void rewrite_rnn::apply(program& prog) const
trans_hw_forward,
ih_forward,
bias_forward,
rnn_op.actv_func);
rnn_op.actv_funcs.at(0));
auto ret_reverse = rnn_oper(false,
prog,
ins,
......@@ -115,7 +115,7 @@ void rewrite_rnn::apply(program& prog) const
trans_hw_reverse,
ih_reverse,
bias_reverse,
rnn_op.actv_func);
rnn_op.actv_funcs.at(1));
// auto final_output = prog.insert_instruction(ins, op::concat{0}, ret_forward[1],
......@@ -128,7 +128,7 @@ void rewrite_rnn::apply(program& prog) const
}
else
{
bool is_forward = (dicrt == op::rnn::rnn_direction_t::forward) ? true : false;
bool is_forward = (dicrt == op::rnn::forward) ? true : false;
std::vector<int64_t> perm{1, 0};
// process input weight matrix
auto sxw = prog.insert_instruction(ins, op::squeeze{{0}}, args[1]);
......@@ -160,8 +160,15 @@ void rewrite_rnn::apply(program& prog) const
{
ih = prog.add_literal(migraphx::literal{s, data});
}
auto ret = rnn_oper(
is_forward, prog, ins, args[0], trans_xw, trans_hw, ih, bias, rnn_op.actv_func);
auto ret = rnn_oper(is_forward,
prog,
ins,
args[0],
trans_xw,
trans_hw,
ih,
bias,
rnn_op.actv_funcs.at(0));
// add the dimension of num_direction
prog.replace_instruction(ins, op::unsqueeze{{1}}, ret[0]);
......
......@@ -5,6 +5,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/cpu/target.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/onnx.hpp>
#include "test.hpp"
float sigmoid(float x) { return 1 / (1 + expf(-x)); }
......@@ -1346,4 +1347,101 @@ TEST_CASE(min_test)
EXPECT(migraphx::verify_range(results_vector, gold));
}
/*
TEST_CASE(rnn_test)
{
{
migraphx::program p;
size_t hidden_size = 8;
size_t input_size = 6;
size_t batch_size = 2;
size_t seq_len = 5;
migraphx::shape hidden_shape{migraphx::shape::float_type, {1, batch_size, hidden_size}};
migraphx::shape input_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
std::vector<float> input(input_shape.elements(), 0.0);
input[0] = input[1] = 1.0;
std::vector<float> init_hidden(hidden_shape.elements(), 0.0);
p.compile(migraphx::cpu::target{});
migraphx::program::parameter_map m;
m["input"] = migraphx::argument(input_shape, input.data());
auto resarg = p.eval(m);
std::vector<float> res;
resarg.visit([&](auto output) { res.assign(output.begin(), output.end()); } );
std::vector<float> res_gold{
0.596363, -0.274248, 0.714484, 0.282515, 0.0938349,
0.185406, 0.283227, -0.482086, 0.265265, -0.523217,
0.50433, 0.400934, -0.34513, 0.114924, 0.0392658,
-0.0976029, 0.364322, -0.567117, 0.538775, 0.314859,
-0.478676, 0.51778, -0.286718, -0.0478341, 0.339601,
-0.380976, 0.628219, 0.222791, -0.271949, 0.490674,
-0.234456, -0.224984, 0.456527, -0.454559, 0.546034,
-0.0389027, -0.307475, 0.561003, -0.245673, -0.0776644,
0.447162, -0.52013, 0.511913, 0.0324621, -0.380515,
0.500777, -0.225695, -0.0193589, 0.458955, -0.531746,
0.448536, -0.087655, -0.430165, 0.551379, -0.161603,
-0.0165391, 0.447551, -0.491717, 0.484796, -0.0699652,
-0.3941, 0.561967, -0.168543, -0.0661258, 0.465925,
-0.499277, 0.45216, -0.103005, -0.392837, 0.584424,
-0.189044, -0.0388068, 0.468369, -0.512927, 0.449144,
-0.0900977, -0.400401, 0.573534, -0.19617, -0.0208253};
EXPECT(migraphx::verify_range(res, res_gold));
}
{
migraphx::program p;
size_t hidden_size = 6;
size_t input_size = 4;
size_t batch_size = 2;
size_t seq_len = 5;
migraphx::shape hidden_shape{migraphx::shape::float_type, {6, batch_size, hidden_size}};
migraphx::shape input_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
std::vector<float> input(input_shape.elements(), 0.0);
input[0] = input[1] = 1.0;
std::vector<float> init_hidden(hidden_shape.elements(), 0.0);
p.compile(migraphx::cpu::target{});
migraphx::program::parameter_map m;
m["input"] = migraphx::argument(input_shape, &input[0]);
auto resarg = p.eval(m);
std::vector<float> res;
resarg.visit([&](auto output) { res.assign(output.begin(), output.end()); } );
std::vector<float> res_gold{
-0.0890872, -0.0558751, 0.185233, 0.452857, 0.104082,
0.432953, 0.274236, 0.186055, -0.367716, 0.266761,
-0.28489, 0.498758, 0.0140574, -0.122377, 0.278067,
0.469699, 0.216743, 0.258926, 0.269785, 0.328379,
-0.576081, 0.11672, -0.452062, 0.603549, 0.472625,
0.120929, 0.350331, 0.502138, 0.103585, 0.128486,
0.0210318, 0.338759, -0.654448, 0.37656, -0.359715,
0.424365, 0.449677, 0.130903, 0.354359, 0.59317,
0.189543, 0.201865, 0.126288, 0.31099, -0.681538,
0.275407, -0.406133, 0.450767, 0.305638, 0.14942,
0.309857, 0.722745, 0.361199, -0.00963601, 0.397046,
0.264047, -0.539317, 0.0690505, -0.321901, 0.566638,
0.406511, 0.231472, 0.320225, 0.737927, 0.372938,
0.00762333, 0.349881, 0.280791, -0.541838, 0.128319,
-0.266702, 0.536205, 0.509004, 0.361068, 0.42431,
0.767474, 0.368881, 0.0753035, 0.141155, 0.219692,
-0.643801, 0.281643, -0.330984, 0.397033, 0.494424,
0.38013, 0.434627, 0.795404, 0.391589, 0.0102068,
0.166358, 0.226248, -0.608175, 0.302622, -0.349646,
0.375506, 0.546918, 0.22908, 0.40025, 0.806049,
0.424462, -0.0352604, 0.528827, -0.0372434, -0.573789,
-0.0541837, -0.194983, 0.552972, 0.553695, 0.263657,
0.432448, 0.815763, 0.412716, -0.0389366, 0.52391,
-0.0256845, -0.577296, -0.0570545, -0.219738, 0.561644};
EXPECT(migraphx::verify_range(res, res_gold));
}
}
*/
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment