"...composable_kernel_onnxruntime.git" did not exist on "fee92fb636a7f1a6144a5358f22985502529160b"
Commit 0427ce2b authored by Shucai Xiao

clang format.

parent a2ea4ecd
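
The hunks below are purely mechanical reformatting. For orientation, a .clang-format along these lines would produce the changes seen in this diff; this is an illustrative guess, not the repository's actual configuration file:

    # Illustrative .clang-format (assumed, not part of this commit)
    BasedOnStyle: LLVM
    ColumnLimit: 100             # long insert_instruction calls wrap to a second line
    IndentWidth: 4
    BreakBeforeBraces: Allman    # braces on their own lines, as in the hunks
    SpaceBeforeParens: Never     # turns "if (x)" into "if(x)"
    AlignConsecutiveAssignments: true

The two visible kinds of change, removing the space after if and re-wrapping lines that exceed the column limit, both follow directly from settings of this kind.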
@@ -1062,28 +1062,28 @@ struct rnn
         bidirectional,
     };

     std::size_t hidden_size = 1;
     operation actv_func = tanh{};
     rnn_direction_t direction = forward;
     float clip = 0.0f;

     std::string name() const { return "rnn"; }
     shape compute_shape(std::vector<shape> inputs) const
     {
         auto in_dims = inputs[0].lens();
         auto hidden_dims = inputs[1].lens();
-        if (hidden_size != hidden_dims[1])
+        if(hidden_size != hidden_dims[1])
         {
             MIGRAPHX_THROW("RNN: hidden size mismatch in attribute and input");
         }

         std::size_t num_directions = 1;
-        if (direction == rnn_direction_t::bidirectional)
+        if(direction == rnn_direction_t::bidirectional)
         {
             num_directions = 2;
         }

-        if (num_directions != hidden_dims[0])
+        if(num_directions != hidden_dims[0])
         {
             MIGRAPHX_THROW("RNN: num_direction does not match the direction attribute");
         }
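
For readers tracing these checks: per the ONNX RNN spec, the first two inputs are the sequence tensor X and the input weight tensor W, laid out as

    X \in \mathbb{R}^{\text{seq\_len} \times \text{batch} \times \text{input\_size}}, \qquad
    W \in \mathbb{R}^{\text{num\_dirs} \times \text{hidden\_size} \times \text{input\_size}}

so hidden_dims[0] carries the direction count and hidden_dims[1] the hidden size, which is exactly what the two guards validate.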
@@ -1096,7 +1096,6 @@ struct rnn
     }
 };

-
 } // namespace op
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
@@ -7,7 +7,6 @@
 #include <migraphx/operators.hpp>
 #include <migraphx/config.hpp>

-
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

@@ -661,7 +661,7 @@ struct onnx_parser
         actv_func_map.insert(std::make_pair("relu", op::relu{}));
         actv_func_map.insert(std::make_pair("sigmoid", op::sigmoid{}));

-        if (actv_func_map.count(activation_func) == 0)
+        if(actv_func_map.count(activation_func) == 0)
         {
             MIGRAPHX_THROW("RNN: activation function " + activation_func + " not supported");
         }
@@ -690,7 +690,8 @@ struct onnx_parser
             clip = parse_value(attributes.at("clip")).at<float>();
         }

-        return prog.add_instruction(op::rnn{hidden_size, actv_func_map[activation_func], dirct, clip}, std::move(args));
+        return prog.add_instruction(
+            op::rnn{hidden_size, actv_func_map[activation_func], dirct, clip}, std::move(args));
     }

     void parse_from(std::istream& is)
@@ -18,22 +18,21 @@ void rewrite_rnn::apply(program& prog) const
         }

         // could be 3 to 5 inputs (though onnx::rnn has 6 inputs,
         // the 5th one is undefined and ignored by protobuf, so
         // we need to process up to 5 inputs)
         auto args = ins->inputs();

         shape seq_shape = args[0]->get_shape();
         shape wgt_shape = args[1]->get_shape();
         std::size_t hidden_size = wgt_shape.lens()[1];
         std::size_t batch_size = seq_shape.lens()[1];
         shape::type_t type = seq_shape.type();
         migraphx::shape s{type, {batch_size, hidden_size}};
         std::vector<char> data(s.bytes(), 0);

-
         auto rnn_op = any_cast<op::rnn>(ins->get_operator());
         op::rnn::rnn_direction_t dicrt = rnn_op.direction;
         if(dicrt == op::rnn::rnn_direction_t::bidirectional)
         {
             std::vector<int64_t> perm{1, 0};
             // process input weight matrix
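
The perm{1, 0} prepared here feeds the weight transposes (the trans_xw and trans_hw arguments used further down): ONNX stores each direction's weights as (hidden_size, input_size), so transposing once up front,

    X_t \in \mathbb{R}^{b \times i},\quad W^{\top} \in \mathbb{R}^{i \times h}
    \;\Longrightarrow\; X_t W^{\top} \in \mathbb{R}^{b \times h},

lets the sequence loop below use a plain op::dot at every step instead of transposing inside the loop (b = batch_size, i = input_size, h = hidden_size).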
@@ -65,17 +64,19 @@ void rewrite_rnn::apply(program& prog) const
             long h_size = static_cast<long>(hidden_size);
             auto b_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
             b_forward = prog.insert_instruction(ins, op::squeeze{{0}}, b_forward);
             auto wbf = prog.insert_instruction(ins, op::slice{{0}, {0}, {h_size}}, b_forward);
-            auto rbf = prog.insert_instruction(ins, op::slice{{0}, {h_size}, {2 * h_size}}, b_forward);
+            auto rbf =
+                prog.insert_instruction(ins, op::slice{{0}, {h_size}, {2 * h_size}}, b_forward);
             auto bf = prog.insert_instruction(ins, op::add{}, wbf, rbf);
             bias_forward = prog.insert_instruction(ins, op::broadcast{1, s}, bf);

             // backward
             auto b_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
             b_reverse = prog.insert_instruction(ins, op::squeeze{{0}}, b_reverse);
             auto wbr = prog.insert_instruction(ins, op::slice{{0}, {0}, {h_size}}, b_reverse);
-            auto rbr = prog.insert_instruction(ins, op::slice{{0}, {h_size}, {2 * h_size}}, b_reverse);
+            auto rbr =
+                prog.insert_instruction(ins, op::slice{{0}, {h_size}, {2 * h_size}}, b_reverse);
             auto br = prog.insert_instruction(ins, op::add{}, wbr, rbr);
             bias_reverse = prog.insert_instruction(ins, op::broadcast{1, s}, br);
         }
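
Both the forward and reverse branches do the same bias preparation. ONNX packs the input and recurrence biases into one tensor B of shape (num_directions, 2 * hidden_size); since the cell update only ever uses their sum, each branch pre-computes

    b = W_b + R_b \in \mathbb{R}^{h}, \qquad b \;\text{broadcast to}\; (\text{batch\_size}, \text{hidden\_size})

via the slice/slice/add/broadcast sequence, once per direction rather than once per time step.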
@@ -144,9 +145,9 @@ void rewrite_rnn::apply(program& prog) const
             long h_size = static_cast<long>(hidden_size);
             auto bwr = prog.insert_instruction(ins, op::squeeze{{0}}, args[3]);
             auto wb = prog.insert_instruction(ins, op::slice{{0}, {0}, {h_size}}, bwr);
             auto rb = prog.insert_instruction(ins, op::slice{{0}, {h_size}, {2 * h_size}}, bwr);
             auto b = prog.insert_instruction(ins, op::add{}, wb, rb);
             bias = prog.insert_instruction(ins, op::broadcast{1, s}, b);
         }

         // process initial hidden state
@@ -159,7 +160,8 @@ void rewrite_rnn::apply(program& prog) const
         {
             ih = prog.add_literal(migraphx::literal{s, data});
         }
-        auto ret = rnn_oper(is_forward, prog, ins, args[0], trans_xw, trans_hw, ih, bias, rnn_op.actv_func);
+        auto ret = rnn_oper(
+            is_forward, prog, ins, args[0], trans_xw, trans_hw, ih, bias, rnn_op.actv_func);

         // add the dimension of num_direction
         prog.replace_instruction(ins, op::unsqueeze{{1}}, ret[0]);
@@ -168,14 +170,14 @@ void rewrite_rnn::apply(program& prog) const
 }

 std::vector<instruction_ref> rewrite_rnn::rnn_oper(bool is_forward,
                                                    program& prog,
                                                    instruction_ref ins,
                                                    instruction_ref input,
                                                    instruction_ref wx,
                                                    instruction_ref wh,
                                                    instruction_ref ih,
                                                    instruction_ref bias,
                                                    operation& actv_func) const
 {
     instruction_ref hidden_out, final_out;
     migraphx::shape input_shape = input->get_shape();
@@ -183,8 +185,8 @@ std::vector<instruction_ref> rewrite_rnn::rnn_oper(bool is_forward,
     long seq_index = is_forward ? 0 : seq_len - 1;
     for(std::size_t i = 0; i < seq_len; i++)
     {
         auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, input);
         xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
         auto x_w = prog.insert_instruction(ins, op::dot{}, xt, wx);
         auto h_r = prog.insert_instruction(ins, op::dot{}, ih, wh);
         auto x_h = prog.insert_instruction(ins, op::add{}, x_w, h_r);
@@ -208,14 +210,14 @@ std::vector<instruction_ref> rewrite_rnn::rnn_oper(bool is_forward,
         if(is_forward)
         {
             hidden_out = (seq_index == 0)
                              ? output
                              : prog.insert_instruction(ins, op::concat{0}, hidden_out, output);
         }
         else
         {
             hidden_out = (seq_index == seq_len - 1)
                              ? output
                              : prog.insert_instruction(ins, op::concat{0}, output, hidden_out);
         }
         seq_index = is_forward ? (seq_index + 1) : (seq_index - 1);
     }
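
Taken together, each loop iteration computes one step of the standard ONNX RNN recurrence

    H_t = f\!\left(X_t W^{\top} + H_{t-1} R^{\top} + W_b + R_b\right)

with f the configured activation (tanh by default), and the concat order, appending for forward and prepending for reverse, keeps hidden_out in original sequence order regardless of scan direction.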