"...resnet50_tensorflow.git" did not exist on "55d41fd67d7c08e9a88cc38f41f2b32089bde134"
Commit 20f89fcc authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 0fe4c56b
@@ -729,12 +729,12 @@ struct onnx_parser
         std::vector<instruction_ref> result;
         // first output for the concatenation of hidden states
         auto hidden_states = prog.add_instruction(op::rnn{hidden_size, vec_actv_funcs, dirct, clip},
                                                   std::move(args));
         result.push_back(hidden_states);
         // second out for the last hidden state
-        //auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);
-        //result.push_back(last_output);
+        // auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);
+        // result.push_back(last_output);
         return result;
     }
...
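Note (not part of the diff): the two outputs built here mirror the ONNX RNN node's outputs; per the ONNX spec (assumed, not stated in the patch), their shapes are:

    // Y   (hidden_states): [seq_length, num_directions, batch_size, hidden_size]
    // Y_h (last_output)  : [num_directions, batch_size, hidden_size]

so the commented-out rnn_last_output op would simply expose the final time step that Y already contains.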
@@ -16,9 +16,9 @@ void rewrite_rnn::apply(program& prog) const
         // rewrite rnn operator
         if(ins->name() == "rnn")
         {
             // could be 3 to 6 inputs, but the 5th input is undefined in
             // pytorch exported onnx, and it is ignored by protobuf. So
             // for input arguments 5 and 6, we need to check the shape,
             // then based on the shape to judge the specific input info
             auto args = ins->inputs();
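Note (not part of the diff): the shape test the comment describes works because the two optional trailing inputs have different ranks in the ONNX RNN spec (assumed here): sequence_lens is rank 1 while initial_h is rank 3. That is exactly what the rank check in the hunks below keys on; a minimal sketch:

    // Rank alone disambiguates the 5th input when only five are present:
    // sequence_lens : rank 1, [batch_size]
    // initial_h     : rank 3, [num_directions, batch_size, hidden_size]
    bool fifth_is_initial_h = args.size() == 5 && args[4]->get_shape().lens().size() == 3;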
@@ -34,12 +34,12 @@ void rewrite_rnn::apply(program& prog) const
             if(dicrt == op::rnn::rnn_direction_t::bidirectional)
             {
                 // input weight matrix
                 auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
                 auto w_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]);
                 // hidden state weight matrix
                 auto r_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]);
                 auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);
                 // process bias
                 instruction_ref bias_forward, bias_reverse;
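Note (not part of the diff): in the bidirectional case both directions are stacked on axis 0, so the paired slices above peel off one direction each. Assumed shapes per the ONNX RNN spec:

    // W (args[1]) : [num_directions = 2, hidden_size, input_size]
    // R (args[2]) : [num_directions = 2, hidden_size, hidden_size]
    // slice{{0}, {0}, {1}} -> forward copy; slice{{0}, {1}, {2}} -> reverse copy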
@@ -53,11 +53,12 @@ void rewrite_rnn::apply(program& prog) const
                 // process intial hidden state, it could be the 6th argument
                 // or the 5th one (if the sequence len argument is ignored)
                 instruction_ref ih_forward, ih_reverse;
-                if(args.size() == 6 || (args.size() == 5 && args[4]->get_shape().lens().size() == 3))
+                if(args.size() == 6 ||
+                   (args.size() == 5 && args[4]->get_shape().lens().size() == 3))
                 {
                     auto arg_ih = (args.size() == 6) ? args[5] : args[4];
                     ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, arg_ih);
                     ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, arg_ih);
                 }
                 else
                 {
@@ -84,7 +85,8 @@ void rewrite_rnn::apply(program& prog) const
                                            ih_reverse,
                                            rnn_op.actv_funcs.at(1));
-                last_output = prog.insert_instruction(ins, op::concat{0}, ret_forward[1], ret_reverse[1]);
+                last_output =
+                    prog.insert_instruction(ins, op::concat{0}, ret_forward[1], ret_reverse[1]);
                 // add the dimension of num_direction
                 ret_forward[0] = prog.insert_instruction(ins, op::unsqueeze{{1}}, ret_forward[0]);
@@ -111,7 +113,8 @@ void rewrite_rnn::apply(program& prog) const
                 // process intial hidden state
                 instruction_ref ih;
-                if(args.size() == 6 || (args.size() == 5 && args[4]->get_shape().lens().size() == 3))
+                if(args.size() == 6 ||
+                   (args.size() == 5 && args[4]->get_shape().lens().size() == 3))
                 {
                     ih = (args.size() == 6) ? args[5] : args[4];
                 }
@@ -120,15 +123,8 @@ void rewrite_rnn::apply(program& prog) const
                     ih = prog.add_literal(migraphx::literal{ih_shape, data});
                 }
-                auto ret = rnn_cell(is_forward,
-                                    prog,
-                                    ins,
-                                    args[0],
-                                    w,
-                                    r,
-                                    bias,
-                                    ih,
-                                    rnn_op.actv_funcs.at(0));
+                auto ret = rnn_cell(
+                    is_forward, prog, ins, args[0], w, r, bias, ih, rnn_op.actv_funcs.at(0));
                 last_output = ret[1];
                 // add the dimension of num_direction
@@ -136,11 +132,11 @@ void rewrite_rnn::apply(program& prog) const
             }
         }
         // rewrite the rnn_last_output operator that right after the rnn
         // operator. Intuitively, we can do a slice on the input to get
         // the last output, but it is already existed in the rnn operator,
         // so we can just use it as the output here
-        //if (ins->name() == "rnn_last_output")
+        // if (ins->name() == "rnn_last_output")
         //{
         //    // if rnn operator is executed, the last_output != prog.end()
         //    if (last_output != prog.end())
@@ -164,31 +160,30 @@ std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward,
 {
     // squeeze and transpose w
     std::vector<int64_t> perm{1, 0};
     auto sw = prog.insert_instruction(ins, op::squeeze{{0}}, w);
     auto tran_sw = prog.insert_instruction(sw, op::transpose{perm}, sw);
     // squeeze and transpose r
     auto sr = prog.insert_instruction(ins, op::squeeze{{0}}, r);
     auto tran_sr = prog.insert_instruction(ins, op::transpose{perm}, sr);
     // initial hidden state
     auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
     // bias
-    if (bias != prog.end())
+    if(bias != prog.end())
     {
         long hs = r->get_shape().lens()[2];
         auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
         auto wb = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
-        auto rb =
-            prog.insert_instruction(ins, op::slice{{0}, {hs}, {2*hs}}, sbias);
-        auto b = prog.insert_instruction(ins, op::add{}, wb, rb);
-        bias = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, b);
+        auto rb = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
+        auto b = prog.insert_instruction(ins, op::add{}, wb, rb);
+        bias = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, b);
     }
     instruction_ref hidden_out, last_out;
     std::size_t seq_len = input->get_shape().lens()[0];
     long seq_index = is_forward ? 0 : seq_len - 1;
     for(std::size_t i = 0; i < seq_len; i++)
     {
         auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, input);
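Note (not part of the diff): the bias block above relies on the ONNX convention (assumed here) that B packs the input and recurrence biases back to back; since the cell only ever uses their sum, they are added once up front and broadcast to the hidden-state shape:

    // sbias : [2 * hidden_size]   (B after squeezing the num_directions axis)
    // wb    : [hidden_size]       (elements [0, hs))
    // rb    : [hidden_size]       (elements [hs, 2*hs))
    // bias  = broadcast(wb + rb) along axis 1 of sih's [batch_size, hidden_size]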
@@ -207,7 +202,7 @@ std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward,
         }
         // apply activation function
         ht = prog.insert_instruction(ins, actv_func, ht);
         sih = ht;
         // add the dimension of sequence length
...
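Note (not part of the diff): the loop in rnn_cell implements the standard RNN recurrence, Ht = f(Xt * W^T + H(t-1) * R^T + Wb + Rb), one time step per iteration. A self-contained single-step sketch in plain C++, with all names illustrative:

    #include <cmath>
    #include <vector>

    // One step of a single-direction RNN cell; bias is assumed to be wb + rb.
    std::vector<float> rnn_step(const std::vector<float>& xt,  // [input_size]
                                const std::vector<float>& ht,  // [hidden_size]
                                const std::vector<std::vector<float>>& w, // [hidden][input]
                                const std::vector<std::vector<float>>& r, // [hidden][hidden]
                                const std::vector<float>& bias)           // [hidden]
    {
        std::vector<float> out(ht.size());
        for(std::size_t i = 0; i < ht.size(); ++i)
        {
            float acc = bias[i];
            for(std::size_t j = 0; j < xt.size(); ++j)
                acc += w[i][j] * xt[j]; // Xt * W^T
            for(std::size_t j = 0; j < ht.size(); ++j)
                acc += r[i][j] * ht[j]; // H(t-1) * R^T
            out[i] = std::tanh(acc);    // ONNX's default RNN activation is tanh
        }
        return out;
    }

Each step's result becomes the next step's hidden state (sih = ht in the code above), walking forward or backward through the sequence depending on is_forward.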