Commit 0fe4c56b authored by Shucai Xiao

temp, code backup.

parent b6a9b597
@@ -1109,6 +1109,20 @@ struct rnn
     }
 };
 
+struct rnn_last_output
+{
+    std::string name() const { return "rnn_last_output"; }
+
+    shape compute_shape(std::vector<shape> inputs) const
+    {
+        check_shapes{inputs, *this}.has(1);
+        auto dims = inputs[0].lens();
+        // remove the first dimension, the remaining dimensions are the output shape
+        dims.erase(dims.begin());
+        return {inputs[0].type(), dims};
+    }
+};
+
 } // namespace op
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
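As a quick illustration of the shape rule above: assuming the rnn operator's first output follows the ONNX layout for Y, {seq_length, num_directions, batch_size, hidden_size}, dropping the leading dimension gives the {num_directions, batch_size, hidden_size} shape that ONNX defines for the last hidden state Y_h. A minimal standalone sketch of that computation on plain dimension vectors (not MIGraphX code, the example dims are hypothetical):

    #include <cstddef>
    #include <iostream>
    #include <vector>

    int main()
    {
        // mimic rnn_last_output::compute_shape on a plain dimension vector;
        // the dims below stand in for a hypothetical Y of shape
        // {seq_length=5, num_directions=1, batch_size=2, hidden_size=4}
        std::vector<std::size_t> dims{5, 1, 2, 4};
        dims.erase(dims.begin()); // drop seq_length, keep the rest
        for(auto d : dims)
            std::cout << d << ' '; // prints: 1 2 4
        std::cout << '\n';
        return 0;
    }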
@@ -21,7 +21,7 @@ struct rewrite_rnn
    void apply(program& prog) const;
 
    private:
-   std::vector<instruction_ref> rnn_oper(bool is_forward,
+   std::vector<instruction_ref> rnn_cell(bool is_forward,
                                          program& prog,
                                          instruction_ref ins,
                                          instruction_ref input,
@@ -655,7 +655,7 @@ struct onnx_parser
        }
    }
 
-   instruction_ref
+   std::vector<instruction_ref>
    parse_rnn(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        migraphx::shape input_shape = args[0]->get_shape();
@@ -726,8 +726,17 @@ struct onnx_parser
            clip = parse_value(attributes.at("clip")).at<float>();
        }
 
-       return prog.add_instruction(op::rnn{hidden_size, vec_actv_funcs, dirct, clip},
-                                   std::move(args));
+       std::vector<instruction_ref> result;
+       // the first output is the concatenation of the hidden states
+       auto hidden_states = prog.add_instruction(op::rnn{hidden_size, vec_actv_funcs, dirct, clip},
+                                                 std::move(args));
+       result.push_back(hidden_states);
+
+       // the second output is the last hidden state
+       // auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);
+       // result.push_back(last_output);
+
+       return result;
    }
 
    void parse_from(std::istream& is)
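The return-type change above (instruction_ref to std::vector<instruction_ref>) is what lets one ONNX node map to several graph values; for RNN the vector is meant to hold the full hidden-state sequence first and the last hidden state second, matching the ONNX outputs Y and Y_h. A minimal sketch of that output-binding pattern, with hypothetical stand-in types rather than the onnx_parser API:

    #include <cstddef>
    #include <string>
    #include <unordered_map>
    #include <vector>

    using instruction_ref = int; // stand-in for the real MIGraphX type

    // bind each ONNX output name of a node to the matching parser result, in order
    std::unordered_map<std::string, instruction_ref>
    bind_outputs(const std::vector<std::string>& output_names,
                 const std::vector<instruction_ref>& results)
    {
        std::unordered_map<std::string, instruction_ref> bound;
        for(std::size_t i = 0; i < output_names.size() && i < results.size(); ++i)
            bound[output_names[i]] = results[i];
        return bound;
    }

    int main()
    {
        // for an ONNX RNN node: Y is the hidden-state sequence, Y_h the last state
        auto bound = bind_outputs({"Y", "Y_h"}, {/*hidden_states*/ 0, /*last_output*/ 1});
        return bound.size() == 2 ? 0 : 1;
    }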
@@ -10,114 +10,81 @@ inline namespace MIGRAPHX_INLINE_NS {
 void rewrite_rnn::apply(program& prog) const
 {
+    instruction_ref last_output = prog.end();
     for(auto ins : iterator_for(prog))
     {
-        if(ins->name() != "rnn")
+        // rewrite the rnn operator
+        if(ins->name() == "rnn")
         {
-            continue;
-        }
-
-        // could be 3 to 5 inputs (though onnx::rnn has 6 inputs,
-        // the 5th one is undefined and ignored by protobuf. so
-        // we need to process up to 5 inputs
+            // could be 3 to 6 inputs, but the 5th input is undefined in
+            // pytorch exported onnx, and it is ignored by protobuf. So
+            // for input arguments 5 and 6, we need to check the shape,
+            // then judge from the shape which input it is
             auto args = ins->inputs();
             shape seq_shape = args[0]->get_shape();
-            shape wgt_shape = args[1]->get_shape();
-            std::size_t hidden_size = wgt_shape.lens()[1];
+            std::size_t hidden_size = args[1]->get_shape().lens()[1];
             std::size_t batch_size = seq_shape.lens()[1];
             shape::type_t type = seq_shape.type();
-            migraphx::shape s{type, {batch_size, hidden_size}};
-            std::vector<char> data(s.bytes(), 0);
+            migraphx::shape ih_shape{type, {batch_size, hidden_size}};
+            std::vector<char> data(ih_shape.bytes(), 0);
             auto rnn_op = any_cast<op::rnn>(ins->get_operator());
             op::rnn::rnn_direction_t dicrt = rnn_op.direction;
             if(dicrt == op::rnn::rnn_direction_t::bidirectional)
             {
-                std::vector<int64_t> perm{1, 0};
-                // process input weight matrix
-                // forward
-                auto xw_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
-                auto sxw_forward = prog.insert_instruction(ins, op::squeeze{{0}}, xw_forward);
-                auto trans_xw_forward = prog.insert_instruction(ins, op::transpose{perm}, sxw_forward);
-
-                // reverse
-                auto xw_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]);
-                auto sxw_reverse = prog.insert_instruction(ins, op::squeeze{{0}}, xw_reverse);
-                auto trans_xw_reverse = prog.insert_instruction(ins, op::transpose{perm}, sxw_reverse);
-
-                // process hidden state weight matrix
-                auto hw_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]);
-                auto shw_forward = prog.insert_instruction(ins, op::squeeze{{0}}, hw_forward);
-                auto trans_hw_forward = prog.insert_instruction(ins, op::transpose{perm}, shw_forward);
-                auto hw_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);
-                auto shw_reverse = prog.insert_instruction(ins, op::squeeze{{0}}, hw_reverse);
-                auto trans_hw_reverse = prog.insert_instruction(ins, op::transpose{perm}, shw_reverse);
+                // input weight matrix
+                auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
+                auto w_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]);
+
+                // hidden state weight matrix
+                auto r_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]);
+                auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);
 
                 // process bias
                 instruction_ref bias_forward, bias_reverse;
                 bias_forward = bias_reverse = prog.end();
                 if(args.size() >= 4)
                 {
-                    // forward
-                    long h_size = static_cast<long>(hidden_size);
-                    auto b_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
-                    b_forward = prog.insert_instruction(ins, op::squeeze{{0}}, b_forward);
-                    auto wbf = prog.insert_instruction(ins, op::slice{{0}, {0}, {h_size}}, b_forward);
-                    auto rbf =
-                        prog.insert_instruction(ins, op::slice{{0}, {h_size}, {2 * h_size}}, b_forward);
-                    auto bf = prog.insert_instruction(ins, op::add{}, wbf, rbf);
-                    bias_forward = prog.insert_instruction(ins, op::broadcast{1, s}, bf);
-
-                    // backward
-                    auto b_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
-                    b_reverse = prog.insert_instruction(ins, op::squeeze{{0}}, b_reverse);
-                    auto wbr = prog.insert_instruction(ins, op::slice{{0}, {0}, {h_size}}, b_reverse);
-                    auto rbr =
-                        prog.insert_instruction(ins, op::slice{{0}, {h_size}, {2 * h_size}}, b_reverse);
-                    auto br = prog.insert_instruction(ins, op::add{}, wbr, rbr);
-                    bias_reverse = prog.insert_instruction(ins, op::broadcast{1, s}, br);
+                    bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
+                    bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
                 }
 
-                // process intial hidden state
+                // process the initial hidden state, it could be the 6th argument
+                // or the 5th one (if the sequence length argument is ignored)
                 instruction_ref ih_forward, ih_reverse;
-                if(args.size() >= 5)
+                if(args.size() == 6 || (args.size() == 5 && args[4]->get_shape().lens().size() == 3))
                 {
-                    // forward
-                    ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[4]);
-                    ih_forward = prog.insert_instruction(ins, op::squeeze{{0}}, ih_forward);
-
-                    // reverse
-                    ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[4]);
-                    ih_reverse = prog.insert_instruction(ins, op::squeeze{{0}}, ih_reverse);
+                    auto arg_ih = (args.size() == 6) ? args[5] : args[4];
+                    ih_forward  = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, arg_ih);
+                    ih_reverse  = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, arg_ih);
                 }
                 else
                 {
-                    ih_forward = prog.add_literal(migraphx::literal{s, data});
-                    ih_reverse = prog.add_literal(migraphx::literal{s, data});
+                    ih_forward = prog.add_literal(migraphx::literal{ih_shape, data});
+                    ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data});
                 }
 
-                auto ret_forward = rnn_oper(true,
+                auto ret_forward = rnn_cell(true,
                                             prog,
                                             ins,
                                             args[0],
-                                            trans_xw_forward,
-                                            trans_hw_forward,
-                                            ih_forward,
+                                            w_forward,
+                                            r_forward,
                                             bias_forward,
+                                            ih_forward,
                                             rnn_op.actv_funcs.at(0));
 
-                auto ret_reverse = rnn_oper(false,
+                auto ret_reverse = rnn_cell(false,
                                             prog,
                                             ins,
                                             args[0],
-                                            trans_xw_reverse,
-                                            trans_hw_reverse,
-                                            ih_reverse,
+                                            w_reverse,
+                                            r_reverse,
                                             bias_reverse,
+                                            ih_reverse,
                                             rnn_op.actv_funcs.at(1));
 
-                // auto final_output = prog.insert_instruction(ins, op::concat{0}, ret_forward[1],
+                last_output = prog.insert_instruction(ins, op::concat{0}, ret_forward[1], ret_reverse[1]);
 
                 // add the dimension of num_direction
                 ret_forward[0] = prog.insert_instruction(ins, op::unsqueeze{{1}}, ret_forward[0]);
@@ -129,109 +96,141 @@ void rewrite_rnn::apply(program& prog) const
             else
             {
                 bool is_forward = (dicrt == op::rnn::rnn_direction_t::forward) ? true : false;
-                std::vector<int64_t> perm{1, 0};
-                // process input weight matrix
-                auto sxw = prog.insert_instruction(ins, op::squeeze{{0}}, args[1]);
-                auto trans_xw = prog.insert_instruction(ins, op::transpose{perm}, sxw);
-
-                // process hidden state weight matrix
-                auto shw = prog.insert_instruction(ins, op::squeeze{{0}}, args[2]);
-                auto trans_hw = prog.insert_instruction(ins, op::transpose{perm}, shw);
+                // input weight matrix
+                auto w = args[1];
+
+                // hidden state weight matrix
+                auto r = args[2];
 
                 // process bias and initial hidden state
                 instruction_ref bias = prog.end();
                 if(args.size() >= 4)
                 {
-                    long h_size = static_cast<long>(hidden_size);
-                    auto bwr = prog.insert_instruction(ins, op::squeeze{{0}}, args[3]);
-                    auto wb = prog.insert_instruction(ins, op::slice{{0}, {0}, {h_size}}, bwr);
-                    auto rb = prog.insert_instruction(ins, op::slice{{0}, {h_size}, {2 * h_size}}, bwr);
-                    auto b = prog.insert_instruction(ins, op::add{}, wb, rb);
-                    bias = prog.insert_instruction(ins, op::broadcast{1, s}, b);
+                    bias = args[3];
                 }
 
                 // process intial hidden state
                 instruction_ref ih;
-                if(args.size() >= 5)
+                if(args.size() == 6 || (args.size() == 5 && args[4]->get_shape().lens().size() == 3))
                 {
-                    ih = prog.insert_instruction(ins, op::squeeze{{0}}, args[4]);
+                    ih = (args.size() == 6) ? args[5] : args[4];
                 }
                 else
                 {
-                    ih = prog.add_literal(migraphx::literal{s, data});
+                    ih = prog.add_literal(migraphx::literal{ih_shape, data});
                 }
 
-                auto ret = rnn_oper(is_forward,
+                auto ret = rnn_cell(is_forward,
                                     prog,
                                     ins,
                                     args[0],
-                                    trans_xw,
-                                    trans_hw,
-                                    ih,
+                                    w,
+                                    r,
                                     bias,
+                                    ih,
                                     rnn_op.actv_funcs.at(0));
+                last_output = ret[1];
 
                 // add the dimension of num_direction
                 prog.replace_instruction(ins, op::unsqueeze{{1}}, ret[0]);
             }
         }
 
+        // rewrite the rnn_last_output operator that comes right after the rnn
+        // operator. Intuitively, we could do a slice on its input to get the
+        // last output, but that value already exists from rewriting the rnn
+        // operator, so we can just use it as the output here
+        // if(ins->name() == "rnn_last_output")
+        // {
+        //     // if the rnn operator has been rewritten, last_output != prog.end()
+        //     if(last_output != prog.end())
+        //     {
+        //         prog.replace_instruction(ins, op::identity{}, last_output);
+        //         last_output = prog.end();
+        //     }
+        // }
+    }
 }
 
-std::vector<instruction_ref> rewrite_rnn::rnn_oper(bool is_forward,
+std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward,
                                                    program& prog,
                                                    instruction_ref ins,
                                                    instruction_ref input,
-                                                   instruction_ref wx,
-                                                   instruction_ref wh,
-                                                   instruction_ref ih,
+                                                   instruction_ref w,
+                                                   instruction_ref r,
                                                    instruction_ref bias,
+                                                   instruction_ref ih,
                                                    operation& actv_func) const
 {
-    instruction_ref hidden_out, final_out;
-    migraphx::shape input_shape = input->get_shape();
-    std::size_t seq_len = input_shape.lens()[0];
+    // squeeze and transpose w
+    std::vector<int64_t> perm{1, 0};
+    auto sw = prog.insert_instruction(ins, op::squeeze{{0}}, w);
+    auto tran_sw = prog.insert_instruction(sw, op::transpose{perm}, sw);
+
+    // squeeze and transpose r
+    auto sr = prog.insert_instruction(ins, op::squeeze{{0}}, r);
+    auto tran_sr = prog.insert_instruction(ins, op::transpose{perm}, sr);
+
+    // initial hidden state
+    auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
+
+    // bias
+    if(bias != prog.end())
+    {
+        long hs = r->get_shape().lens()[2];
+        auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
+        auto wb = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
+        auto rb = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
+        auto b = prog.insert_instruction(ins, op::add{}, wb, rb);
+        bias = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, b);
+    }
+
+    instruction_ref hidden_out, last_out;
+    std::size_t seq_len = input->get_shape().lens()[0];
     long seq_index = is_forward ? 0 : seq_len - 1;
     for(std::size_t i = 0; i < seq_len; i++)
    {
        auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, input);
        xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
-       auto x_w = prog.insert_instruction(ins, op::dot{}, xt, wx);
-       auto h_r = prog.insert_instruction(ins, op::dot{}, ih, wh);
-       auto x_h = prog.insert_instruction(ins, op::add{}, x_w, h_r);
-       instruction_ref before_actv;
+       auto xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_sw);
+       auto ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_sr);
+       auto xt_ht = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri);
+       instruction_ref ht;
        if(bias != prog.end())
        {
-           before_actv = prog.insert_instruction(ins, op::add{}, x_h, bias);
+           ht = prog.insert_instruction(ins, op::add{}, xt_ht, bias);
        }
        else
        {
-           before_actv = x_h;
+           ht = xt_ht;
        }
 
        // apply activation function
-       ih = prog.insert_instruction(ins, actv_func, before_actv);
+       ht = prog.insert_instruction(ins, actv_func, ht);
+       sih = ht;
 
        // add the dimension of sequence length
-       auto output = prog.insert_instruction(ins, op::unsqueeze{{0}}, ih);
-       final_out = output;
+       last_out = prog.insert_instruction(ins, op::unsqueeze{{0}}, ht);
        if(is_forward)
        {
            hidden_out = (seq_index == 0)
-                            ? output
-                            : prog.insert_instruction(ins, op::concat{0}, hidden_out, output);
+                            ? last_out
+                            : prog.insert_instruction(ins, op::concat{0}, hidden_out, last_out);
        }
        else
        {
            hidden_out = (seq_index == seq_len - 1)
-                            ? output
-                            : prog.insert_instruction(ins, op::concat{0}, output, hidden_out);
+                            ? last_out
+                            : prog.insert_instruction(ins, op::concat{0}, last_out, hidden_out);
        }
        seq_index = is_forward ? (seq_index + 1) : (seq_index - 1);
    }
 
    std::vector<instruction_ref> out_args;
    out_args.push_back(hidden_out);
-   out_args.push_back(final_out);
+   out_args.push_back(last_out);
 
    return out_args;
 }
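For reference, the per-step graph that rnn_cell builds (slice X_t, two dot products against the transposed weights, a bias add, then the activation) corresponds to the ONNX RNN cell update H_t = f(X_t * W^T + H_{t-1} * R^T + Wb + Rb). A minimal standalone sketch of one such step on plain row-major buffers, using tanh as the activation and hypothetical sizes (not MIGraphX code):

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // One RNN step: h_new = tanh(x*W^T + h*R^T + wb + rb), all matrices row-major.
    std::vector<float> rnn_step(const std::vector<float>& x,  // [input_size]
                                const std::vector<float>& h,  // [hidden_size]
                                const std::vector<float>& W,  // [hidden_size x input_size]
                                const std::vector<float>& R,  // [hidden_size x hidden_size]
                                const std::vector<float>& wb, // [hidden_size]
                                const std::vector<float>& rb, // [hidden_size]
                                std::size_t input_size,
                                std::size_t hidden_size)
    {
        std::vector<float> h_new(hidden_size, 0.0f);
        for(std::size_t j = 0; j < hidden_size; ++j)
        {
            float acc = wb[j] + rb[j];                 // Wb + Rb
            for(std::size_t k = 0; k < input_size; ++k)
                acc += x[k] * W[j * input_size + k];   // X_t * W^T
            for(std::size_t k = 0; k < hidden_size; ++k)
                acc += h[k] * R[j * hidden_size + k];  // H_{t-1} * R^T
            h_new[j] = std::tanh(acc);                 // activation
        }
        return h_new;
    }

    int main()
    {
        // hypothetical sizes: input_size = 3, hidden_size = 2
        std::vector<float> x{0.5f, -1.0f, 0.25f};
        std::vector<float> h{0.0f, 0.0f};
        std::vector<float> W{0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f};
        std::vector<float> R{0.1f, 0.0f, 0.0f, 0.1f};
        std::vector<float> wb{0.0f, 0.0f}, rb{0.0f, 0.0f};
        auto h1 = rnn_step(x, h, W, R, wb, rb, 3, 2); // next hidden state
        return h1.size() == 2 ? 0 : 1;
    }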