"test/srt/git@developer.sourcefind.cn:change/sglang.git" did not exist on "3f41b48c4079d0adaf17ce7adf308c6a0d947ad0"
Commit a87890be authored by Shucai Xiao

commit gru changes.

parent 6b1e5e63
@@ -1173,6 +1173,20 @@ struct rnn
     }
 };
 
+struct rnn_last_output
+{
+    std::string name() const { return "rnn_last_output"; }
+    shape compute_shape(std::vector<shape> inputs) const
+    {
+        check_shapes{inputs, *this}.has(1);
+        auto dims = inputs[0].lens();
+        // remove the first dimension, remaining dims are the output shape
+        dims.erase(dims.begin());
+        return {inputs[0].type(), dims};
+    }
+};
+
 struct gru
 {
     enum gru_direction_t
@@ -1217,20 +1231,6 @@ struct gru
     }
 };
 
-struct rnn_last_output
-{
-    std::string name() const { return "rnn_last_output"; }
-    shape compute_shape(std::vector<shape> inputs) const
-    {
-        check_shapes{inputs, *this}.has(1);
-        auto dims = inputs[0].lens();
-        // remove the first dimension, remaing are output shape
-        dims.erase(dims.begin());
-        return {inputs[0].type(), dims};
-    }
-};
-
 struct gru_last_output
 {
     std::string name() const { return "gru_last_output"; }
...
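Note: both rnn_last_output (moved above struct gru here) and gru_last_output compute their output shape by dropping the input's leading dimension. A minimal standalone sketch of that rule, assuming the hidden-state sequence is laid out as {seq_len, num_directions, batch, hidden} per the ONNX spec (plain C++, not the migraphx::shape API):

    #include <cassert>
    #include <cstddef>
    #include <vector>

    // Dropping the first dimension of the full hidden-state sequence
    // leaves the shape of the last time step, e.g. {5, 2, 3, 4} -> {2, 3, 4}.
    std::vector<std::size_t> last_output_dims(std::vector<std::size_t> dims)
    {
        dims.erase(dims.begin()); // remove seq_len; the rest is the output shape
        return dims;
    }

    int main()
    {
        assert((last_output_dims({5, 2, 3, 4}) == std::vector<std::size_t>{2, 3, 4}));
    }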
@@ -832,8 +832,8 @@ struct onnx_parser
         {
             // 4 activation functions are used in the bidirectional
             // scenario. No spec is provided in onnx::operator. We
-            // use the algorithm that: if 1 actv function is provides,
-            // repeat 1 four times. If 2 actv functins are provides,
+            // use the algorithm that: if 1 actv function is provided,
+            // repeat it four times. If 2 actv functions are provided,
             // assume forward and reverse use the same pair of actv
             // functions. For the case of 3 actv functions provided,
             // assume the 3rd one is repeated once and used by the
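A hedged sketch of the expansion rule this comment describes (expand_actv_names is an illustrative helper, not the parser's actual code; the 3-name case follows one plausible reading of the comment, which is truncated in this hunk):

    #include <string>
    #include <vector>

    // Expand 1-3 activation names to the 4 needed for a bidirectional GRU.
    std::vector<std::string> expand_actv_names(std::vector<std::string> names)
    {
        if(names.size() == 1) // one name: repeat it four times
            return {names[0], names[0], names[0], names[0]};
        if(names.size() == 2) // two names: forward and reverse share the pair
            return {names[0], names[1], names[0], names[1]};
        if(names.size() == 3) // three names: repeat the 3rd for the reverse pass
            return {names[0], names[1], names[2], names[2]};
        return names; // already four
    }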
@@ -869,12 +869,11 @@ struct onnx_parser
             }
         });
 
-        std::vector<operation> vec_actv_funcs;
-        for_each(vec_names.begin(), vec_names.end(), [&](auto& name) {
-            vec_actv_funcs.push_back(map_actv_funcs[name]);
+        std::vector<operation> vec_actv_funcs(vec_names.size());
+        std::transform(vec_names.begin(), vec_names.end(), vec_actv_funcs.begin(), [&](auto& name) {
+            return map_actv_funcs[name];
         });
 
-        // To be added later
         float clip = 0.0;
         if(contains(attributes, "clip"))
         {
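The refactor above swaps for_each plus push_back for a pre-sized vector and std::transform, which states the 1:1 name-to-operation mapping directly and avoids repeated reallocation. A self-contained illustration of the same pattern (lookup is a stand-in for map_actv_funcs):

    #include <algorithm>
    #include <map>
    #include <string>
    #include <vector>

    int main()
    {
        std::map<std::string, int> lookup{{"sigmoid", 0}, {"tanh", 1}};
        std::vector<std::string> names{"sigmoid", "tanh", "sigmoid", "tanh"};
        std::vector<int> funcs(names.size()); // pre-sized so transform writes in place
        std::transform(names.begin(), names.end(), funcs.begin(),
                       [&](const std::string& n) { return lookup[n]; });
    }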
@@ -887,18 +886,22 @@ struct onnx_parser
                 linear_before_reset = parse_value(attributes.at("linear_before_reset")).at<int>();
             }
 
-            std::vector<instruction_ref> result;
+            // append undefined operators to make 6 arguments
+            if(args.size() < 6)
+            {
+                auto ins = prog.add_instruction(op::undefined{});
+                args.insert(args.end(), 6 - args.size(), ins);
+            }
 
             // first output for concatenation of hidden states
             auto hidden_states = prog.add_instruction(
                 op::gru{hidden_size, vec_actv_funcs, dirct, clip, linear_before_reset},
                 std::move(args));
-            result.push_back(hidden_states);
 
             // second output for last gru output
             auto last_output = prog.add_instruction(op::gru_last_output{}, hidden_states);
-            result.push_back(last_output);
 
-            return result;
+            return {hidden_states, last_output};
         }
     void parse_from(std::istream& is)
...
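The parser change above pads the optional trailing ONNX GRU inputs (bias, sequence lengths, initial hidden state) with op::undefined so that args always has 6 entries and downstream passes can test by operator name rather than by argument count. A minimal sketch of the idea, with plain strings standing in for instructions:

    #include <cassert>
    #include <string>
    #include <vector>

    int main()
    {
        // only seq, W, R were provided; B, sequence_lens, initial_h omitted
        std::vector<std::string> args{"seq", "W", "R"};
        if(args.size() < 6)
            args.insert(args.end(), 6 - args.size(), "undefined");
        assert(args.size() == 6);
        assert(args[3] == "undefined"); // i.e. no bias input was given
    }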
@@ -10,7 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS {
 
 void rewrite_gru::apply(program& prog) const
 {
-    instruction_ref last_output = prog.end();
+    std::unordered_map<instruction_ref, instruction_ref> map_last_output;
     for(auto ins : iterator_for(prog))
     {
         if(ins->name() == "gru")
@@ -22,9 +22,9 @@ void rewrite_gru::apply(program& prog) const
             shape seq_shape = args[0]->get_shape();
             std::size_t hidden_size = args[2]->get_shape().lens()[2];
-            std::size_t batchs = seq_shape.lens()[1];
+            std::size_t batch_size = seq_shape.lens()[1];
             shape::type_t type = seq_shape.type();
-            migraphx::shape ih_shape{type, {1, batchs, hidden_size}};
+            migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
             std::vector<char> data(ih_shape.bytes(), 0);
 
             auto gru_op = any_cast<op::gru>(ins->get_operator());
@@ -42,7 +42,7 @@ void rewrite_gru::apply(program& prog) const
                 // bias
                 instruction_ref bias_forward, bias_reverse;
                 bias_forward = bias_reverse = prog.end();
-                if(args.size() >= 4)
+                if(args.size() >= 4 && args[3]->get_operator().name() != "undefined")
                 {
                     bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
                     bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
@@ -50,12 +50,10 @@ void rewrite_gru::apply(program& prog) const
 
                 // initial hidden state
                 instruction_ref ih_forward, ih_reverse;
-                if(args.size() == 6 ||
-                   (args.size() == 5 && args[4]->get_shape().lens().size() == 3))
+                if(args.size() == 6 && args[5]->get_operator().name() != "undefined")
                 {
-                    auto arg_ih = (args.size() == 6) ? args[5] : args[4];
-                    ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, arg_ih);
-                    ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, arg_ih);
+                    ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]);
+                    ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[5]);
                 }
                 else
                 {
@@ -87,7 +85,7 @@ void rewrite_gru::apply(program& prog) const
                                               gru_op.actv_funcs.at(2),
                                               gru_op.actv_funcs.at(3));
 
-                last_output =
+                auto last_output =
                     prog.insert_instruction(ins, op::concat{0}, ret_forward[1], ret_reverse[1]);
 
                 // add the dimension of num_direction
@@ -95,7 +93,8 @@ void rewrite_gru::apply(program& prog) const
                 ret_reverse[0] = prog.insert_instruction(ins, op::unsqueeze{{1}}, ret_reverse[0]);
 
                 // concat the forward and reverse output
-                prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
+                auto hidden_state = prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
+                map_last_output[hidden_state] = last_output;
             }
             else
             {
@@ -106,17 +105,16 @@ void rewrite_gru::apply(program& prog) const
 
                 // bias
                 instruction_ref bias = prog.end();
-                if(args.size() >= 4)
+                if(args.size() >= 4 && args[3]->get_operator().name() != "undefined")
                 {
                     bias = args[3];
                 }
 
                 // initial hidden state
                 instruction_ref ih;
-                if(args.size() == 6 ||
-                   (args.size() == 5 && args[4]->get_shape().lens().size() == 3))
+                if(args.size() == 6 && args[5]->get_operator().name() != "undefined")
                 {
-                    ih = args.size() == 6 ? args[5] : args[4];
+                    ih = args[5];
                 }
                 else
                 {
@@ -135,10 +133,11 @@ void rewrite_gru::apply(program& prog) const
                                           gru_op.actv_funcs.at(0),
                                           gru_op.actv_funcs.at(1));
 
-                last_output = ret[1];
+                auto last_output = ret[1];
 
                 // add the dimension of num_direction
-                prog.replace_instruction(ins, op::unsqueeze{{1}}, ret[0]);
+                auto hidden_state = prog.replace_instruction(ins, op::unsqueeze{{1}}, ret[0]);
+                map_last_output[hidden_state] = last_output;
             }
         }
@@ -148,11 +147,10 @@ void rewrite_gru::apply(program& prog) const
         // so we can just use it as the output here
         if(ins->name() == "gru_last_output")
         {
-            if(last_output != prog.end())
-            {
-                prog.replace_instruction(ins, op::identity{}, last_output);
-                last_output = prog.end();
-            }
+            auto inputs = ins->inputs();
+            assert(inputs.size() == 1);
+            assert(map_last_output.count(inputs[0]) > 0);
+            prog.replace_instruction(ins, map_last_output[inputs[0]]);
         }
     }
 }
...
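The rewrite pass now records, for every GRU it lowers, which instruction produces that GRU's last output, and resolves each gru_last_output through its own input rather than through a single shared variable; this keeps multiple GRUs in one program from clobbering each other's result. A hedged sketch of the bookkeeping (int stands in for instruction_ref; the real pass hashes iterators):

    #include <cassert>
    #include <unordered_map>

    int main()
    {
        using instruction_id = int; // stand-in for instruction_ref
        std::unordered_map<instruction_id, instruction_id> map_last_output;

        // while rewriting a gru: remember which instruction holds its last output
        instruction_id hidden_state = 1, last_output = 2;
        map_last_output[hidden_state] = last_output;

        // later, a gru_last_output whose single input is hidden_state resolves to it
        assert(map_last_output.count(hidden_state) > 0);
        assert(map_last_output[hidden_state] == last_output);
    }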