Commit 1fbe8c48 authored by Shucai Xiao

merge rnn improvement.

parents 60b3056e 62044b86
@@ -640,7 +640,7 @@ struct as_shape
 struct gather
 {
-    mutable int axis = 0;
+    int axis = 0;
     std::string name() const { return "gather"; }
     shape compute_shape(std::vector<shape> inputs) const
@@ -654,43 +654,44 @@ struct gather
         }
         // negative axis means counting dimensions from back
-        if(axis < 0)
-        {
-            axis += n_dim;
-        }
+        int axis_index = (axis < 0) ? (n_dim + axis) : axis;
         auto type = inputs[0].type();
-        lens[axis] = inputs[1].elements();
+        lens[axis_index] = inputs[1].elements();
         return {type, lens};
     }

     template <class T>
     void compute_index(const T& out_idx,
+                       const int axis_index,
                        const std::vector<std::size_t>& vec_indices,
                        const std::size_t max_dim,
                        T& in_idx) const
     {
         in_idx = out_idx;
-        std::size_t idx = vec_indices.at(out_idx[axis]);
+        std::size_t idx = vec_indices.at(out_idx[axis_index]);
         if(idx >= max_dim)
         {
             MIGRAPHX_THROW("Gather: indices are out of range in input tensor");
         }
-        in_idx[axis] = idx;
+        in_idx[axis_index] = idx;
     }

     argument compute(const shape& output_shape, std::vector<argument> args) const
     {
         argument result{output_shape};
+        // negative axis means counting dimensions from back
+        int axis_index = (axis < 0) ? (output_shape.lens().size() + axis) : axis;
         // max dimension in axis
-        std::size_t max_dim = args[0].get_shape().lens()[axis];
+        std::size_t max_dim = args[0].get_shape().lens()[axis_index];
         std::vector<std::size_t> vec_indices;
         args[1].visit([&](auto indices) { vec_indices.assign(indices.begin(), indices.end()); });
         visit_all(result, args[0])([&](auto output, auto input) {
             std::vector<std::size_t> in_idx;
             shape_for_each(output.get_shape(), [&](const auto& idx) {
-                this->compute_index(idx, vec_indices, max_dim, in_idx);
+                this->compute_index(idx, axis_index, vec_indices, max_dim, in_idx);
                 output(idx.begin(), idx.end()) = input(in_idx.begin(), in_idx.end());
             });
         });
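
The normalization introduced above maps a negative axis to a dimension counted from the back: for a rank-n input, axis -1 refers to dimension n-1. A minimal standalone sketch of the same rule (normalize_axis is a hypothetical helper, not part of this diff):

#include <cstddef>
#include <stdexcept>

// A negative axis counts from the back, so -1 means the last dimension.
inline int normalize_axis(int axis, std::size_t n_dim)
{
    int axis_index = (axis < 0) ? static_cast<int>(n_dim) + axis : axis;
    if(axis_index < 0 || axis_index >= static_cast<int>(n_dim))
        throw std::out_of_range("gather: axis out of range");
    return axis_index;
}
// normalize_axis(-1, 4) == 3; normalize_axis(-4, 4) == 0; normalize_axis(-5, 4) throws.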
@@ -1152,6 +1153,20 @@ struct gru
     }
 };

+struct rnn_last_output
+{
+    std::string name() const { return "rnn_last_output"; }
+    shape compute_shape(std::vector<shape> inputs) const
+    {
+        check_shapes{inputs, *this}.has(1);
+        auto dims = inputs[0].lens();
+        // remove the first dimension; the remaining dimensions are the output shape
+        dims.erase(dims.begin());
+        return {inputs[0].type(), dims};
+    }
+};
+
 } // namespace op
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
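
A quick sanity check of rnn_last_output::compute_shape above: the rnn operator's first output is laid out as {seq_len, num_directions, batch_size, hidden_size}, so erasing the leading dimension leaves {num_directions, batch_size, hidden_size}, the shape of ONNX's Y_h output. A small sketch with illustrative numbers:

migraphx::shape y{migraphx::shape::float_type, {5, 2, 3, 4}}; // seq_len 5, bidirectional
auto dims = y.lens();     // {5, 2, 3, 4}
dims.erase(dims.begin()); // {2, 3, 4} == {num_directions, batch_size, hidden_size}
migraphx::shape y_h{y.type(), dims};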
...
@@ -21,7 +21,7 @@ struct rewrite_rnn
     void apply(program& prog) const;

     private:
-    std::vector<instruction_ref> rnn_oper(bool is_forward,
+    std::vector<instruction_ref> rnn_cell(bool is_forward,
                                           program& prog,
                                           instruction_ref ins,
                                           instruction_ref input,
...
@@ -24,7 +24,8 @@ struct onnx_parser
 {
     using attribute_map = std::unordered_map<std::string, onnx::AttributeProto>;
     using node_map      = std::unordered_map<std::string, onnx::NodeProto>;
-    using op_func = std::function<instruction_ref(attribute_map, std::vector<instruction_ref>)>;
+    using op_func =
+        std::function<std::vector<instruction_ref>(attribute_map, std::vector<instruction_ref>)>;

     node_map nodes;
     std::unordered_map<std::string, instruction_ref> instructions;
     program prog = program();
@@ -103,6 +104,15 @@ struct onnx_parser
     template <class F>
     void add_op(std::string name, F f)
+    {
+        ops.emplace(name, [=](auto&&... xs) {
+            return std::vector<instruction_ref>{f(std::forward<decltype(xs)>(xs)...)};
+        });
+    }
+
+    // Multi output op
+    template <class F>
+    void add_multi_op(std::string name, F f)
     {
         ops.emplace(name, f);
     }
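
With this split, every op_func returns a vector of instructions: parsers that produce a single instruction_ref are registered through add_op, which wraps the result in a one-element vector, while parsers that produce several outputs register through add_multi_op. A hedged usage sketch (the Identity registration is illustrative, and parse_rnn may well be wired up through add_mem_op in the real parser rather than a direct lambda):

// Single output: add_op's adapter wraps the instruction_ref into a vector.
add_op("Identity", [this](attribute_map, std::vector<instruction_ref> args) {
    return prog.add_instruction(op::identity{}, args);
});

// Multiple outputs: the parser returns the vector itself, e.g. parse_rnn
// below returns {hidden_states, last_output}.
add_multi_op("RNN", [this](attribute_map attrs, std::vector<instruction_ref> args) {
    return parse_rnn("RNN", std::move(attrs), std::move(args));
});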
@@ -110,7 +120,7 @@ struct onnx_parser
     template <class F>
     void add_mem_op(std::string name, F f)
     {
-        ops.emplace(name, [=](auto&&... xs) {
+        add_op(name, [=](auto&&... xs) {
             return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...);
         });
     }
@@ -118,17 +128,15 @@ struct onnx_parser
     template <class T>
     void add_binary_op(std::string name, T x)
     {
-        ops.emplace(name, [this, x](attribute_map attributes, std::vector<instruction_ref> args) {
+        add_op(name, [this, x](attribute_map attributes, std::vector<instruction_ref> args) {
            if(args.size() != 2)
                MIGRAPHX_THROW("binary operators should have 2 operands");
-            if(contains(attributes, "broadcast"))
+            if(contains(attributes, "broadcast") and contains(attributes, "axis"))
            {
                uint64_t broadcasted = parse_value(attributes.at("broadcast")).at<uint64_t>();
                if(broadcasted != 0)
                {
-                    uint64_t axis = (contains(attributes, "axis"))
-                                        ? parse_value(attributes.at("axis")).at<uint64_t>()
-                                        : 0;
+                    uint64_t axis = parse_value(attributes.at("axis")).at<uint64_t>();
                    auto l =
                        prog.add_instruction(op::broadcast{axis, args[0]->get_shape()}, args[1]);
                    return prog.add_instruction(x, args[0], l);
@@ -189,7 +197,7 @@ struct onnx_parser
     template <class T>
     void add_generic_op(std::string name, T x)
     {
-        ops.emplace(name, [this, x](attribute_map, std::vector<instruction_ref> args) {
+        add_op(name, [this, x](attribute_map, std::vector<instruction_ref> args) {
             return prog.add_instruction(x, args);
         });
     }
@@ -197,7 +205,7 @@ struct onnx_parser
     template <class T>
     void add_variadic_op(std::string name, T x)
     {
-        ops.emplace(name, [this, x](attribute_map, std::vector<instruction_ref> args) {
+        add_op(name, [this, x](attribute_map, std::vector<instruction_ref> args) {
             return std::accumulate(std::next(args.begin()),
                                    args.end(),
                                    args.front(),
@@ -648,7 +656,7 @@ struct onnx_parser
         }
     }

-    instruction_ref
+    std::vector<instruction_ref>
     parse_rnn(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
     {
         migraphx::shape input_shape = args[0]->get_shape();
@@ -718,8 +726,17 @@ struct onnx_parser
             clip = parse_value(attributes.at("clip")).at<float>();
         }

-        return prog.add_instruction(op::rnn{hidden_size, vec_actv_funcs, dirct, clip},
-                                    std::move(args));
+        std::vector<instruction_ref> result;
+        // first output: the concatenation of the hidden states of all timesteps
+        auto hidden_states = prog.add_instruction(
+            op::rnn{hidden_size, vec_actv_funcs, dirct, clip}, std::move(args));
+        result.push_back(hidden_states);
+
+        // second output: the last hidden state
+        auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);
+        result.push_back(last_output);
+
+        return result;
     }

     instruction_ref
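
For context, this mirrors the two outputs of the ONNX RNN operator: Y, the hidden state of every timestep with shape [seq_length, num_directions, batch_size, hidden_size], and Y_h, the last hidden state with shape [num_directions, batch_size, hidden_size]. For the forward direction Y_h is simply Y at the final timestep; for the reverse direction it is the state computed at t = 0, which is why the rewrite pass below concatenates the per-direction last outputs instead of slicing Y.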
@@ -842,7 +859,7 @@ struct onnx_parser
         }
         else
         {
-            throw std::runtime_error("Failed reading");
+            MIGRAPHX_THROW("Failed reading onnx file.");
         }
     }
@@ -872,7 +889,7 @@ struct onnx_parser
         }
         for(auto&& p : nodes)
         {
-            this->parse_node(get_name(p.second));
+            this->parse_node(p.first);
         }
     }
@@ -898,23 +915,37 @@ struct onnx_parser
             if(nodes.count(input) > 0)
             {
-                auto&& iname = get_name(nodes.at(input));
-                assert(name != iname);
-                this->parse_node(iname);
-                args.push_back(instructions.at(iname));
+                assert(name != input);
+                this->parse_node(input);
+                args.push_back(instructions.at(input));
             }
             else
             {
                 args.push_back(instructions.at(input));
             }
         }
+        std::vector<instruction_ref> result;
         if(ops.count(node.op_type()) == 0)
         {
-            instructions[name] = prog.add_instruction(unknown{node.op_type()}, args);
+            result.push_back(prog.add_instruction(unknown{node.op_type()}, args));
         }
         else
         {
-            instructions[name] = ops[node.op_type()](get_attributes(node), args);
+            result = ops[node.op_type()](get_attributes(node), args);
+        }
+
+        // Even nodes with no declared outputs produce an output in migraphx
+        if(node.output().empty() and result.size() == 1)
+        {
+            instructions[name] = result.front();
+        }
+        else
+        {
+            assert(node.output().size() >= result.size());
+            std::transform(result.begin(),
+                           result.end(),
+                           node.output().begin(),
+                           std::inserter(instructions, instructions.end()),
+                           [](auto&& x, auto&& y) { return std::make_pair(y, x); });
         }
     }
 }
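
To illustrate the transform above with the RNN case (the output names follow the ONNX RNN spec; the positional pairing is what the code does):

// A node declared as:  RNN(...) -> ("Y", "Y_h")
// parse_rnn returns:   result = {hidden_states, last_output}
// The std::transform pairs them positionally:
//   instructions["Y"]   = hidden_states; // all timesteps, concatenated
//   instructions["Y_h"] = last_output;   // the rnn_last_output instruction
// so a downstream node referencing "Y_h" resolves to the second output.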
@@ -929,25 +960,24 @@ struct onnx_parser
         return result;
     }

-    static std::string get_name(const onnx::NodeProto& node)
-    {
-        if(node.name().empty())
-        {
-            std::string generated = "migraphx_unnamed_node";
-            return std::accumulate(node.output().begin(),
-                                   node.output().end(),
-                                   generated,
-                                   [](auto x, auto y) { return x + "_" + y; });
-        }
-        return node.name();
-    }
-
     static node_map get_nodes(const onnx::GraphProto& graph)
     {
         std::unordered_map<std::string, onnx::NodeProto> result;
+        std::size_t n = 0;
         for(auto&& node : graph.node())
         {
-            result[get_name(node)] = node;
+            if(node.output().empty())
+            {
+                if(node.name().empty())
+                {
+                    result["migraphx_unamed_node_" + std::to_string(n)] = node;
+                    n++;
+                }
+                else
+                {
+                    result[node.name()] = node;
+                }
+            }
             for(auto&& output : node.output())
             {
                 result[output] = node;
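
In effect, get_nodes now keys each node by every one of its output names, which is what parse_node and the p.first loop above rely on; only a node with no outputs falls back to its own name, or to a generated name when it is also unnamed:

// A node n1 with outputs ("a", "b") is registered twice:
//   result["a"] = n1; result["b"] = n1;
// A no-output node keeps its own name, or becomes
// "migraphx_unamed_node_<n>" if the name field is empty.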
...
@@ -10,228 +10,222 @@ inline namespace MIGRAPHX_INLINE_NS {

 void rewrite_rnn::apply(program& prog) const
 {
+    instruction_ref last_output = prog.end();
     for(auto ins : iterator_for(prog))
     {
-        if(ins->name() != "rnn")
-        {
-            continue;
-        }
-
-        // could be 3 to 5 inputs (though onnx::rnn has 6 inputs,
-        // the 5th one is undefined and ignored by protobuf. so
-        // we need to process up to 5 inputs
-        auto args = ins->inputs();
-
-        shape seq_shape         = args[0]->get_shape();
-        shape wgt_shape         = args[1]->get_shape();
-        std::size_t hidden_size = wgt_shape.lens()[1];
-        std::size_t batch_size  = seq_shape.lens()[1];
-        shape::type_t type      = seq_shape.type();
-        migraphx::shape s{type, {batch_size, hidden_size}};
-        std::vector<char> data(s.bytes(), 0);
-
-        auto rnn_op                    = any_cast<op::rnn>(ins->get_operator());
-        op::rnn::rnn_direction_t dicrt = rnn_op.direction;
-        if(dicrt == op::rnn::rnn_direction_t::bidirectional)
-        {
-            std::vector<int64_t> perm{1, 0};
-            // process input weight matrix
-            // forward
-            auto xw_forward  = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
-            auto sxw_forward = prog.insert_instruction(ins, op::squeeze{{0}}, xw_forward);
-            auto trans_xw_forward = prog.insert_instruction(ins, op::transpose{perm}, sxw_forward);
-            // reverse
-            auto xw_reverse  = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]);
-            auto sxw_reverse = prog.insert_instruction(ins, op::squeeze{{0}}, xw_reverse);
-            auto trans_xw_reverse = prog.insert_instruction(ins, op::transpose{perm}, sxw_reverse);
-
-            // process hidden state weight matrix
-            auto hw_forward  = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]);
-            auto shw_forward = prog.insert_instruction(ins, op::squeeze{{0}}, hw_forward);
-            auto trans_hw_forward = prog.insert_instruction(ins, op::transpose{perm}, shw_forward);
-            auto hw_reverse  = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);
-            auto shw_reverse = prog.insert_instruction(ins, op::squeeze{{0}}, hw_reverse);
-            auto trans_hw_reverse = prog.insert_instruction(ins, op::transpose{perm}, shw_reverse);
-
-            // process bias
-            instruction_ref bias_forward, bias_reverse;
-            bias_forward = bias_reverse = prog.end();
-            if(args.size() >= 4)
-            {
-                // forward
-                long h_size    = static_cast<long>(hidden_size);
-                auto b_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
-                b_forward      = prog.insert_instruction(ins, op::squeeze{{0}}, b_forward);
-                auto wbf = prog.insert_instruction(ins, op::slice{{0}, {0}, {h_size}}, b_forward);
-                auto rbf =
-                    prog.insert_instruction(ins, op::slice{{0}, {h_size}, {2 * h_size}}, b_forward);
-                auto bf      = prog.insert_instruction(ins, op::add{}, wbf, rbf);
-                bias_forward = prog.insert_instruction(ins, op::broadcast{1, s}, bf);
-                // backward
-                auto b_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
-                b_reverse      = prog.insert_instruction(ins, op::squeeze{{0}}, b_reverse);
-                auto wbr = prog.insert_instruction(ins, op::slice{{0}, {0}, {h_size}}, b_reverse);
-                auto rbr =
-                    prog.insert_instruction(ins, op::slice{{0}, {h_size}, {2 * h_size}}, b_reverse);
-                auto br      = prog.insert_instruction(ins, op::add{}, wbr, rbr);
-                bias_reverse = prog.insert_instruction(ins, op::broadcast{1, s}, br);
-            }
-
-            // process intial hidden state
-            instruction_ref ih_forward, ih_reverse;
-            if(args.size() >= 5)
-            {
-                // forward
-                ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[4]);
-                ih_forward = prog.insert_instruction(ins, op::squeeze{{0}}, ih_forward);
-                // reverse
-                ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[4]);
-                ih_reverse = prog.insert_instruction(ins, op::squeeze{{0}}, ih_reverse);
-            }
-            else
-            {
-                ih_forward = prog.add_literal(migraphx::literal{s, data});
-                ih_reverse = prog.add_literal(migraphx::literal{s, data});
-            }
-
-            auto ret_forward = rnn_oper(true,
-                                        prog,
-                                        ins,
-                                        args[0],
-                                        trans_xw_forward,
-                                        trans_hw_forward,
-                                        ih_forward,
-                                        bias_forward,
-                                        rnn_op.actv_funcs.at(0));
-            auto ret_reverse = rnn_oper(false,
-                                        prog,
-                                        ins,
-                                        args[0],
-                                        trans_xw_reverse,
-                                        trans_hw_reverse,
-                                        ih_reverse,
-                                        bias_reverse,
-                                        rnn_op.actv_funcs.at(1));
-
-            // auto final_output = prog.insert_instruction(ins, op::concat{0}, ret_forward[1],
-            // add the dimension of num_direction
-            ret_forward[0] = prog.insert_instruction(ins, op::unsqueeze{{1}}, ret_forward[0]);
-            ret_reverse[0] = prog.insert_instruction(ins, op::unsqueeze{{1}}, ret_reverse[0]);
-
-            // concat the forward and reverse output
-            prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
-        }
-        else
-        {
-            bool is_forward = (dicrt == op::rnn::forward) ? true : false;
-            std::vector<int64_t> perm{1, 0};
-            // process input weight matrix
-            auto sxw      = prog.insert_instruction(ins, op::squeeze{{0}}, args[1]);
-            auto trans_xw = prog.insert_instruction(ins, op::transpose{perm}, sxw);
-            // process hidden state weight matrix
-            auto shw      = prog.insert_instruction(ins, op::squeeze{{0}}, args[2]);
-            auto trans_hw = prog.insert_instruction(ins, op::transpose{perm}, shw);
-            // process bias and initial hidden state
-            instruction_ref bias = prog.end();
-            if(args.size() >= 4)
-            {
-                long h_size = static_cast<long>(hidden_size);
-                auto bwr    = prog.insert_instruction(ins, op::squeeze{{0}}, args[3]);
-                auto wb     = prog.insert_instruction(ins, op::slice{{0}, {0}, {h_size}}, bwr);
-                auto rb = prog.insert_instruction(ins, op::slice{{0}, {h_size}, {2 * h_size}}, bwr);
-                auto b  = prog.insert_instruction(ins, op::add{}, wb, rb);
-                bias    = prog.insert_instruction(ins, op::broadcast{1, s}, b);
-            }
-
-            // process intial hidden state
-            instruction_ref ih;
-            if(args.size() >= 5)
-            {
-                ih = prog.insert_instruction(ins, op::squeeze{{0}}, args[4]);
-            }
-            else
-            {
-                ih = prog.add_literal(migraphx::literal{s, data});
-            }
-            auto ret = rnn_oper(is_forward,
-                                prog,
-                                ins,
-                                args[0],
-                                trans_xw,
-                                trans_hw,
-                                ih,
-                                bias,
-                                rnn_op.actv_funcs.at(0));
-            // add the dimension of num_direction
-            prog.replace_instruction(ins, op::unsqueeze{{1}}, ret[0]);
-        }
+        // rewrite the rnn operator
+        if(ins->name() == "rnn")
+        {
+            // onnx::rnn can have 3 to 6 inputs, but the 5th input (the sequence
+            // length) is undefined in pytorch-exported onnx and is dropped by
+            // protobuf. So for input arguments 5 and 6 we need to check the
+            // shape to determine which input is which.
+            auto args = ins->inputs();
+
+            shape seq_shape         = args[0]->get_shape();
+            std::size_t hidden_size = args[1]->get_shape().lens()[1];
+            std::size_t batch_size  = seq_shape.lens()[1];
+            shape::type_t type      = seq_shape.type();
+            migraphx::shape ih_shape{type, {batch_size, hidden_size}};
+            std::vector<char> data(ih_shape.bytes(), 0);
+
+            auto rnn_op                    = any_cast<op::rnn>(ins->get_operator());
+            op::rnn::rnn_direction_t dicrt = rnn_op.direction;
+            if(dicrt == op::rnn::rnn_direction_t::bidirectional)
+            {
+                // input weight matrix
+                auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
+                auto w_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]);
+
+                // hidden state weight matrix
+                auto r_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]);
+                auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);
+
+                // process bias
+                instruction_ref bias_forward, bias_reverse;
+                bias_forward = bias_reverse = prog.end();
+                if(args.size() >= 4)
+                {
+                    bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
+                    bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
+                }
+
+                // process initial hidden state; it could be the 6th argument
+                // or the 5th one (if the sequence length argument is ignored)
+                instruction_ref ih_forward, ih_reverse;
+                if(args.size() == 6 ||
+                   (args.size() == 5 && args[4]->get_shape().lens().size() == 3))
+                {
+                    auto arg_ih = (args.size() == 6) ? args[5] : args[4];
+                    ih_forward  = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, arg_ih);
+                    ih_reverse  = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, arg_ih);
+                }
+                else
+                {
+                    ih_forward = prog.add_literal(migraphx::literal{ih_shape, data});
+                    ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data});
+                }
+
+                auto ret_forward = rnn_cell(true,
+                                            prog,
+                                            ins,
+                                            args[0],
+                                            w_forward,
+                                            r_forward,
+                                            bias_forward,
+                                            ih_forward,
+                                            rnn_op.actv_funcs.at(0));
+                auto ret_reverse = rnn_cell(false,
+                                            prog,
+                                            ins,
+                                            args[0],
+                                            w_reverse,
+                                            r_reverse,
+                                            bias_reverse,
+                                            ih_reverse,
+                                            rnn_op.actv_funcs.at(1));
+
+                last_output =
+                    prog.insert_instruction(ins, op::concat{0}, ret_forward[1], ret_reverse[1]);
+
+                // add the dimension of num_direction
+                ret_forward[0] = prog.insert_instruction(ins, op::unsqueeze{{1}}, ret_forward[0]);
+                ret_reverse[0] = prog.insert_instruction(ins, op::unsqueeze{{1}}, ret_reverse[0]);
+
+                // concat the forward and reverse output
+                prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
+            }
+            else
+            {
+                bool is_forward = (dicrt == op::rnn::rnn_direction_t::forward) ? true : false;
+                // input weight matrix
+                auto w = args[1];
+                // hidden state weight matrix
+                auto r = args[2];
+
+                // process bias
+                instruction_ref bias = prog.end();
+                if(args.size() >= 4)
+                {
+                    bias = args[3];
+                }
+
+                // process initial hidden state
+                instruction_ref ih;
+                if(args.size() == 6 ||
+                   (args.size() == 5 && args[4]->get_shape().lens().size() == 3))
+                {
+                    ih = (args.size() == 6) ? args[5] : args[4];
+                }
+                else
+                {
+                    ih = prog.add_literal(migraphx::literal{ih_shape, data});
+                }
+
+                auto ret = rnn_cell(
+                    is_forward, prog, ins, args[0], w, r, bias, ih, rnn_op.actv_funcs.at(0));
+                last_output = ret[1];
+                // add the dimension of num_direction
+                prog.replace_instruction(ins, op::unsqueeze{{1}}, ret[0]);
+            }
+        }
+
+        // rewrite the rnn_last_output operator that comes right after the rnn
+        // operator. Intuitively, we could slice the rnn output to get the last
+        // output, but it is already computed by the rnn rewrite above, so we
+        // can just reuse it as the output here.
+        if(ins->name() == "rnn_last_output")
+        {
+            // if an rnn operator was just rewritten, last_output != prog.end()
+            if(last_output != prog.end())
+            {
+                prog.replace_instruction(ins, op::identity{}, last_output);
+                last_output = prog.end();
+            }
+        }
     }
 }
-std::vector<instruction_ref> rewrite_rnn::rnn_oper(bool is_forward,
+std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward,
                                                    program& prog,
                                                    instruction_ref ins,
                                                    instruction_ref input,
-                                                   instruction_ref wx,
-                                                   instruction_ref wh,
-                                                   instruction_ref ih,
+                                                   instruction_ref w,
+                                                   instruction_ref r,
                                                    instruction_ref bias,
+                                                   instruction_ref ih,
                                                    operation& actv_func) const
 {
-    instruction_ref hidden_out, final_out;
-    migraphx::shape input_shape = input->get_shape();
-    std::size_t seq_len         = input_shape.lens()[0];
-    long seq_index              = is_forward ? 0 : seq_len - 1;
+    // squeeze and transpose w
+    std::vector<int64_t> perm{1, 0};
+    auto sw      = prog.insert_instruction(ins, op::squeeze{{0}}, w);
+    auto tran_sw = prog.insert_instruction(ins, op::transpose{perm}, sw);
+
+    // squeeze and transpose r
+    auto sr      = prog.insert_instruction(ins, op::squeeze{{0}}, r);
+    auto tran_sr = prog.insert_instruction(ins, op::transpose{perm}, sr);
+
+    // initial hidden state
+    auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
+
+    // bias
+    if(bias != prog.end())
+    {
+        long hs    = r->get_shape().lens()[2];
+        auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
+        auto wb    = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
+        auto rb    = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
+        auto b     = prog.insert_instruction(ins, op::add{}, wb, rb);
+        bias       = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, b);
+    }
+
+    instruction_ref hidden_out, last_out;
+    std::size_t seq_len = input->get_shape().lens()[0];
+    long seq_index      = is_forward ? 0 : seq_len - 1;
     for(std::size_t i = 0; i < seq_len; i++)
     {
         auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, input);
         xt      = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
-        auto x_w = prog.insert_instruction(ins, op::dot{}, xt, wx);
-        auto h_r = prog.insert_instruction(ins, op::dot{}, ih, wh);
-        auto x_h = prog.insert_instruction(ins, op::add{}, x_w, h_r);
-        instruction_ref before_actv;
+        auto xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_sw);
+        auto ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_sr);
+        auto xt_ht = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri);
+        instruction_ref ht;
         if(bias != prog.end())
         {
-            before_actv = prog.insert_instruction(ins, op::add{}, x_h, bias);
+            ht = prog.insert_instruction(ins, op::add{}, xt_ht, bias);
         }
         else
         {
-            before_actv = x_h;
+            ht = xt_ht;
         }
         // apply activation function
-        ih = prog.insert_instruction(ins, actv_func, before_actv);
+        ht  = prog.insert_instruction(ins, actv_func, ht);
+        sih = ht;
         // add the dimension of sequence length
-        auto output = prog.insert_instruction(ins, op::unsqueeze{{0}}, ih);
-        final_out   = output;
+        last_out = prog.insert_instruction(ins, op::unsqueeze{{0}}, ht);
         if(is_forward)
         {
             hidden_out = (seq_index == 0)
-                             ? output
-                             : prog.insert_instruction(ins, op::concat{0}, hidden_out, output);
+                             ? last_out
+                             : prog.insert_instruction(ins, op::concat{0}, hidden_out, last_out);
         }
         else
         {
             hidden_out = (seq_index == seq_len - 1)
-                             ? output
-                             : prog.insert_instruction(ins, op::concat{0}, output, hidden_out);
+                             ? last_out
+                             : prog.insert_instruction(ins, op::concat{0}, last_out, hidden_out);
         }
         seq_index = is_forward ? (seq_index + 1) : (seq_index - 1);
     }

     std::vector<instruction_ref> out_args;
     out_args.push_back(hidden_out);
-    out_args.push_back(final_out);
+    out_args.push_back(last_out);

     return out_args;
 }
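
Taken together, the loop in rnn_cell implements the standard ONNX RNN recurrence, with the two bias halves combined once before the loop and broadcast over the batch:

    H_t = f(X_t W^T + H_{t-1} R^T + W_b + R_b)

Here f is actv_func, X_t is the timestep slice xt, W and R are the squeezed input and recurrence weights (whose transposes are tran_sw and tran_sr), and H_{t-1} is sih carried over from the previous iteration.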
...
@@ -14,8 +14,9 @@ namespace device {
 argument gather(hipStream_t stream,
                 const migraphx::shape& output_shape,
                 std::vector<migraphx::argument> args,
-                std::size_t axis)
+                int axis)
 {
+    int axis_index = (axis < 0) ? (axis + output_shape.lens().size()) : axis;
     visit_all(args.back(), args[0])([&](auto output, auto input) {
         std::size_t nelements = output_shape.elements();
         args[1].visit([&](auto indices) {
@@ -26,9 +27,9 @@ argument gather(hipStream_t stream,
             hip_tensor_descriptor<ndim> desc_input(input.get_shape());
             hip_tensor_descriptor<ndim> desc_output(output.get_shape());
             gs_launch(stream, nelements)([=](auto i) {
-                auto lens  = desc_output.multi(i);
-                lens[axis] = indices_ptr[lens[axis]];
-                outptr[i]  = inptr[desc_input.linear(lens)];
+                auto lens        = desc_output.multi(i);
+                lens[axis_index] = indices_ptr[lens[axis_index]];
+                outptr[i]        = inptr[desc_input.linear(lens)];
             });
         });
     });
...
@@ -13,7 +13,7 @@ namespace device {
 argument gather(hipStream_t stream,
                 const migraphx::shape& output_shape,
                 std::vector<migraphx::argument> args,
-                std::size_t axis);
+                int axis);

 } // namespace device
 } // namespace gpu
...
@@ -950,6 +950,22 @@ struct test_gather
     }
 };

+struct test_gather_neg_axis
+{
+    migraphx::program create_program() const
+    {
+        migraphx::program p;
+        migraphx::shape s{migraphx::shape::float_type, {3, 3}};
+        migraphx::shape s_indices{migraphx::shape::int32_type, {2, 2}};
+        std::vector<int> indices{1, 2, 2, 1};
+        auto a0  = p.add_parameter("data", s);
+        auto a1  = p.add_literal(migraphx::literal{s_indices, indices});
+        int axis = -1;
+        p.add_instruction(migraphx::op::gather{axis}, a0, a1);
+        return p;
+    }
+};
+
 void manual_identity()
 {
     migraphx::program p;
@@ -1090,4 +1106,6 @@ int main()
     verify_program<test_conv_bn_relu_pooling>();
     verify_program<test_conv_bn_relu_pooling2>();
     verify_program<test_slice>();
+    verify_program<test_gather>();
+    verify_program<test_gather_neg_axis>();
 }
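
For the negative-axis test above, axis = -1 on the rank-2 data normalizes to axis 1, so columns are gathered and the axis dimension is replaced by indices.elements() = 4, giving an output shape of {3, 4}. A worked expectation, assuming illustrative data values 0..8 (the test itself feeds the "data" parameter at runtime):

// data (3x3):   indices (2x2), flattened: {1, 2, 2, 1}
//   0 1 2
//   3 4 5       output {3, 4}: row i is data[i] indexed by {1, 2, 2, 1}
//   6 7 8
//   => 1 2 2 1
//      4 5 5 4
//      7 8 8 7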
@@ -224,12 +224,29 @@ TEST_CASE(gather)
                      indices);
     }

+    {
+        migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
+        migraphx::shape indices{migraphx::shape::int32_type, {2, 3}};
+        int axis = -4;
+        expect_shape(migraphx::shape{migraphx::shape::float_type, {6, 3, 4, 5}},
+                     migraphx::op::gather{axis},
+                     input,
+                     indices);
+    }
+
     {
         migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
         migraphx::shape indices{migraphx::shape::int32_type, {2, 3}};
         int axis = 4;
         throws_shape(migraphx::op::gather{axis}, input, indices);
     }
+
+    {
+        migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
+        migraphx::shape indices{migraphx::shape::int32_type, {2, 3}};
+        int axis = -5;
+        throws_shape(migraphx::op::gather{axis}, input, indices);
+    }
 }

 int main(int argc, const char* argv[]) { test::run(argc, argv); }