Commit a7408288 authored by Khalique

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into tf_pb

parents fa983162 c1fec2c4
......@@ -27,6 +27,13 @@ void eliminate_contiguous::apply(program& p) const
{
for(auto ins : iterator_for(p))
{
// skip the reshape operator for now, since there is a bug
// when a transpose is followed by a reshape
if(ins->name() == "reshape")
{
continue;
}
// Make a copy so we can modify it while we iterate
auto args = ins->inputs();
for(auto arg : ins->inputs())
......
......@@ -22,8 +22,8 @@ struct literal : raw_data<literal>
{
literal() {}
template <class U, class T = deduce<U>>
literal(U x) : buffer(make_shared_array<char>(sizeof(T))), m_shape(shape::get_type<T>{})
template <class U, class T = deduce<U>, shape::type_t ShapeType = shape::get_type<T>{}>
literal(U x) : buffer(make_shared_array<char>(sizeof(T))), m_shape(ShapeType)
{
static_assert(std::is_trivially_copyable<T>{}, "Literals can only be trivial types");
*(reinterpret_cast<T*>(buffer.get())) = x;
......
......@@ -757,43 +757,49 @@ struct gather
// negative axis means counting dimensions from back
int axis_index = (axis < 0) ? (n_dim + axis) : axis;
auto type = inputs[0].type();
lens[axis_index] = inputs[1].elements();
return {type, lens};
}
auto type = inputs[0].type();
lens.erase(lens.begin() + axis_index);
if(!inputs[1].scalar())
{
auto ind_lens = inputs[1].lens();
lens.insert(lens.begin() + axis_index, ind_lens.begin(), ind_lens.end());
}
template <class T>
void compute_index(const T& out_idx,
const int axis_index,
const std::vector<std::size_t>& vec_indices,
const std::size_t max_dim,
T& in_idx) const
{
in_idx = out_idx;
std::size_t idx = vec_indices.at(out_idx[axis_index]);
if(idx >= max_dim)
// for scalar output
if(lens.empty())
{
MIGRAPHX_THROW("Gather: indices are out of range in input tensor");
return {type};
}
in_idx[axis_index] = idx;
return {type, lens};
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
// negative axis means counting dimensions from back
int axis_index = (axis < 0) ? (output_shape.lens().size() + axis) : axis;
int axis_index =
(axis < 0) ? static_cast<int>(args[0].get_shape().lens().size() + axis) : axis;
// max dimension in axis
std::size_t max_dim = args[0].get_shape().lens()[axis_index];
std::vector<std::size_t> vec_indices;
args[1].visit([&](auto indices) { vec_indices.assign(indices.begin(), indices.end()); });
visit_all(result, args[0])([&](auto output, auto input) {
std::vector<std::size_t> in_idx;
shape_for_each(output.get_shape(), [&](const auto& idx) {
this->compute_index(idx, axis_index, vec_indices, max_dim, in_idx);
output(idx.begin(), idx.end()) = input(in_idx.begin(), in_idx.end());
visit_all(result, args[0])([&](auto output, auto data) {
args[1].visit([&](auto indices) {
if(output_shape.scalar())
{
output[0] = data[indices.front()];
}
else
{
auto out_lens = data.get_shape().lens();
out_lens[axis_index] = indices.get_shape().elements();
migraphx::shape out_comp_shape{data.get_shape().type(), out_lens};
shape_for_each(out_comp_shape, [&](const auto& out_idx) {
auto data_idx = out_idx;
data_idx[axis_index] = indices[data_idx[axis_index]];
output[out_comp_shape.index(out_idx.begin(), out_idx.end())] =
data(data_idx.begin(), data_idx.end());
});
}
});
});
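// Illustrative sketch (not part of this commit): the gather rule used by the
// compute() above, restated over plain std::vector for a 2-D data tensor,
// 1-D indices and axis == 1. The output lens are the data lens with the axis
// entry replaced by the number of indices, and out(i, j) = data(i, indices[j]).
#include <cstddef>
#include <vector>

std::vector<float> gather_axis1(const std::vector<float>& data,
                                std::size_t rows,
                                std::size_t cols,
                                const std::vector<std::size_t>& indices)
{
    // e.g. data shape {3, 4} with indices {0, 3} yields output shape {3, 2}
    std::vector<float> out(rows * indices.size());
    for(std::size_t i = 0; i < rows; ++i)
        for(std::size_t j = 0; j < indices.size(); ++j)
            out[i * indices.size() + j] = data[i * cols + indices[j]];
    return out;
}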
......@@ -820,10 +826,22 @@ struct dot
const shape& b = inputs.at(1);
auto t = a.type();
if(a.lens()[1] != b.lens()[0])
// according to the numpy.matmul() specification, inputs with more than
// two dimensions are acceptable as long as the batch dimensions (all but
// the last two) are the same in the two inputs
if(!std::equal(a.lens().rbegin() + 2, a.lens().rend(), b.lens().rbegin() + 2))
{
MIGRAPHX_THROW("DOT: dim values mismatch");
}
std::size_t dim_0 = a.lens().size() - 2;
std::size_t dim_1 = a.lens().size() - 1;
if(a.lens()[dim_1] != b.lens()[dim_0])
MIGRAPHX_THROW("Inner dimensions do not match: {" + to_string_range(a.lens()) +
"} x {" + to_string_range(b.lens()) + "}");
return {t, {a.lens()[0], b.lens()[1]}};
auto out_lens = a.lens();
out_lens[dim_1] = b.lens()[dim_1];
return {t, out_lens};
}
};
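// Illustrative sketch (not part of this commit): the output-shape rule the
// dot operator above applies to stacked (numpy.matmul-style) inputs. It
// assumes both inputs have the same rank >= 2, as the checks above do: the
// batch dimensions must match, the inner dimensions must agree, and the
// result copies a's lens with the last entry replaced by b's last entry,
// e.g. {2, 3, 4, 5} x {2, 3, 5, 6} -> {2, 3, 4, 6}.
#include <algorithm>
#include <cstddef>
#include <stdexcept>
#include <vector>

std::vector<std::size_t> matmul_out_lens(const std::vector<std::size_t>& a,
                                         const std::vector<std::size_t>& b)
{
    if(!std::equal(a.rbegin() + 2, a.rend(), b.rbegin() + 2))
        throw std::runtime_error("batch dimensions mismatch");
    if(a[a.size() - 1] != b[b.size() - 2])
        throw std::runtime_error("inner dimensions do not match");
    auto out   = a;
    out.back() = b.back();
    return out;
}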
......@@ -932,6 +950,22 @@ struct softmax
}
};
struct logsoftmax
{
int axis = 1;
std::string name() const { return "logsoftmax"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1);
if(axis < 0 || axis > inputs[0].lens().size())
{
MIGRAPHX_THROW("LogSoftMax: input axis value " + std::to_string(axis) +
" is out of range");
}
return inputs.at(0);
}
};
struct flatten
{
uint64_t axis = 0;
......@@ -1259,6 +1293,57 @@ struct gru
}
};
struct lstm
{
std::size_t hidden_size = 1;
std::vector<operation> actv_funcs{sigmoid{}, tanh{}, tanh{}};
rnn_direction direction = rnn_direction::forward;
float clip = 0.0f;
int input_forget = 0;
std::string name() const { return "lstm"; }
shape compute_shape(std::vector<shape> inputs) const
{
auto in_dims = inputs[0].lens();
auto hidden_dims = inputs[2].lens();
if(hidden_size != hidden_dims[2])
{
MIGRAPHX_THROW("LSTM: hidden size mismatch in attribute and input");
}
std::size_t num_directions = 1;
if(direction == rnn_direction::bidirectional)
{
num_directions = 2;
}
if(num_directions != hidden_dims[0])
{
MIGRAPHX_THROW("LSTM: num_direction does not match the direction attribute");
}
std::vector<std::size_t> out_dims(in_dims);
out_dims.insert(out_dims.begin() + 1, num_directions);
out_dims.back() = hidden_size;
return {inputs[0].type(), out_dims};
}
};
struct lstm_last_cell_output
{
std::string name() const { return "lstm_last_cell_output"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
auto dims = inputs[0].lens();
// remove the first dimension; the remaining dims are the output shape
dims.erase(dims.begin());
return {inputs[0].type(), dims};
}
};
struct undefined
{
std::string name() const { return "undefined"; }
......
......@@ -45,6 +45,18 @@ struct rewrite_rnn
const operation& actv_func2) const;
std::vector<operation> gru_actv_funcs(instruction_ref ins) const;
// for lstm operators
void apply_lstm(program& prog, instruction_ref ins) const;
std::vector<instruction_ref> lstm_cell(bool is_forward,
program& prog,
instruction_ref ins,
std::vector<instruction_ref> inputs,
const operation& actv_func1,
const operation& actv_func2,
const operation& actv_func3) const;
std::vector<operation> lstm_actv_funcs(instruction_ref ins) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -79,6 +79,7 @@ struct onnx_parser
add_mem_op("Gemm", &onnx_parser::parse_gemm);
add_mem_op("BatchNormalization", &onnx_parser::parse_batchnorm);
add_mem_op("Softmax", &onnx_parser::parse_softmax);
add_mem_op("LogSoftmax", &onnx_parser::parse_logsoftmax);
add_mem_op("Squeeze", &onnx_parser::parse_squeeze);
add_mem_op("Unsqueeze", &onnx_parser::parse_unsqueeze);
add_mem_op("Slice", &onnx_parser::parse_slice);
......@@ -89,6 +90,7 @@ struct onnx_parser
add_mem_op("Transpose", &onnx_parser::parse_transpose);
add_mem_op("RNN", &onnx_parser::parse_rnn);
add_mem_op("GRU", &onnx_parser::parse_gru);
add_mem_op("LSTM", &onnx_parser::parse_lstm);
add_mem_op("Pad", &onnx_parser::parse_pad);
// init the activation function map
......@@ -227,6 +229,19 @@ struct onnx_parser
return prog.add_instruction(op::reshape{{long(dims[0]), long(dims[1])}}, s);
}
instruction_ref parse_logsoftmax(const std::string&,
const attribute_map& attributes,
std::vector<instruction_ref> args)
{
int axis = 1;
if(contains(attributes, "axis"))
{
axis = parse_value(attributes.at("axis")).at<int>();
}
return prog.add_instruction(op::logsoftmax{axis}, std::move(args));
}
instruction_ref
parse_conv(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
......@@ -354,7 +369,9 @@ struct onnx_parser
}
if(args.size() == 2)
{
literal s = args[1]->get_literal();
auto s = args[1]->eval();
if(s.empty())
MIGRAPHX_THROW("Dynamic shape is not supported.");
s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
}
return prog.add_instruction(op, args[0]);
......@@ -433,7 +450,15 @@ struct onnx_parser
attribute_map attributes,
const std::vector<instruction_ref>&)
{
literal v = parse_value(attributes.at("value"));
literal v = parse_value(attributes.at("value"));
auto dim_size = attributes.at("value").t().dims_size();
// if dim_size is 0, it is a scalar
if(dim_size == 0)
{
migraphx::shape scalar_shape{v.get_shape().type()};
return prog.add_literal(migraphx::literal{scalar_shape, v.data()});
}
return prog.add_literal(v);
}
......@@ -460,7 +485,12 @@ struct onnx_parser
{
transb = parse_value(attributes.at("transB")).at<bool>();
}
std::vector<int64_t> perm = {1, 0};
std::vector<int64_t> perm(args[0]->get_shape().lens().size());
std::iota(perm.begin(), perm.end(), int64_t{0});
// swap the last two elements
std::swap(*perm.rbegin(), *(perm.rbegin() + 1));
auto l1 = (transa) ? prog.add_instruction(op::transpose{perm}, args[0]) : args[0];
auto l2 = (transb) ? prog.add_instruction(op::transpose{perm}, args[1]) : args[1];
if(args.size() == 3)
......@@ -480,6 +510,7 @@ struct onnx_parser
return add_broadcastable_binary_op(l3, l4, op::add{});
}
}
return prog.add_instruction(op::dot{alpha, beta}, l1, l2);
}
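// Illustrative sketch (not part of this commit): the permutation built in
// parse_gemm above for transA/transB -- an identity permutation over all
// dimensions with the last two entries swapped, so only the trailing matrix
// is transposed (assumes rank >= 2).
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <utility>
#include <vector>

std::vector<int64_t> last_two_swapped(std::size_t rank)
{
    std::vector<int64_t> perm(rank);
    std::iota(perm.begin(), perm.end(), int64_t{0});
    std::swap(*perm.rbegin(), *(perm.rbegin() + 1)); // {0, 1, 2, 3} -> {0, 1, 3, 2}
    return perm;
}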
......@@ -749,15 +780,17 @@ struct onnx_parser
{
auto names = attributes.at("activations").strings();
vec_names.clear();
for_each(names.begin(), names.end(), [&](auto& fn) { vec_names.push_back(fn); });
vec_names.resize(names.size());
std::copy(names.begin(), names.end(), vec_names.begin());
}
for_each(vec_names.begin(), vec_names.end(), [&](auto& fn) {
if(map_actv_funcs.count(fn) == 0)
{
MIGRAPHX_THROW("RNN: activation function " + std::string(fn) + " not supported");
}
auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
return (map_actv_funcs.count(name) == 0);
});
if(name_it != vec_names.end())
{
MIGRAPHX_THROW("RNN: activation function " + std::string(*name_it) + " not supported");
}
// the bidirectional case should have two activation functions:
// one for the forward pass and one for the reverse pass.
......@@ -839,8 +872,7 @@ struct onnx_parser
auto names = attributes.at("activations").strings();
vec_names.clear();
vec_names.resize(names.size());
std::transform(
names.begin(), names.end(), vec_names.begin(), [](auto& str) { return str; });
std::copy(names.begin(), names.end(), vec_names.begin());
}
// need 4 activation functions
......@@ -878,12 +910,13 @@ struct onnx_parser
}
}
for_each(vec_names.begin(), vec_names.end(), [&](auto& name) {
if(map_actv_funcs.count(name) == 0)
{
MIGRAPHX_THROW("GRU: activation function " + std::string(name) + " not supported");
}
auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
return (map_actv_funcs.count(name) == 0);
});
if(name_it != vec_names.end())
{
MIGRAPHX_THROW("GRU: activation function " + std::string(*name_it) + " not supported");
}
std::vector<operation> vec_actv_funcs(vec_names.size());
std::transform(vec_names.begin(), vec_names.end(), vec_actv_funcs.begin(), [&](auto& name) {
......@@ -920,6 +953,178 @@ struct onnx_parser
return {hidden_states, last_output};
}
std::vector<instruction_ref>
parse_lstm(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
migraphx::shape input_shape = args[0]->get_shape();
std::size_t hidden_size = args[2]->get_shape().lens()[2];
if(contains(attributes, "hidden_size"))
{
std::size_t hidden_size_att = parse_value(attributes.at("hidden_size")).at<int>();
if(hidden_size != hidden_size_att)
{
MIGRAPHX_THROW("LSTM: hidden size mismatch in input and attribute");
}
}
// process the direction attribute
std::string direction{"forward"};
if(contains(attributes, "direction"))
{
direction = attributes.at("direction").s();
}
op::rnn_direction dirct = op::rnn_direction::forward;
if(direction == "bidirectional")
{
dirct = op::rnn_direction::bidirectional;
}
else if(direction == "reverse")
{
dirct = op::rnn_direction::reverse;
}
else if(direction == "forward")
{
dirct = op::rnn_direction::forward;
}
else
{
MIGRAPHX_THROW("LSTM: incorrect direction attribute");
}
std::vector<std::string> vec_names = {"sigmoid", "tanh", "tanh"};
if(contains(attributes, "activations"))
{
auto names = attributes.at("activations").strings();
vec_names.clear();
vec_names.resize(names.size());
std::copy(names.begin(), names.end(), vec_names.begin());
}
// need 6 activation functions for the bidirectional case
if(dirct == op::rnn_direction::bidirectional)
{
// 6 activation functions are used in the bidirectional
// scenario. No spec is provided for the onnx operator, so we
// use the following algorithm: if 1 actv function is provided,
// repeat it six times. If 2 actv functions are provided,
// repeat the 2nd once, then repeat the resulting three once.
// If 3 actv funcs are provided, repeat all three once.
// A similar rule applies when 4, 5, or 6 actv functions
// are provided. This may need to change later.
switch(vec_names.size())
{
case 1:
vec_names = {vec_names.at(0),
vec_names.at(0),
vec_names.at(0),
vec_names.at(0),
vec_names.at(0),
vec_names.at(0)};
break;
case 2:
// repeat the 2nd actv func once, then repeat all three another time
vec_names = {vec_names.at(0),
vec_names.at(1),
vec_names.at(1),
vec_names.at(0),
vec_names.at(1),
vec_names.at(1)};
break;
case 3:
// repeat all three actv funcs once
vec_names = {vec_names.at(0),
vec_names.at(1),
vec_names.at(2),
vec_names.at(0),
vec_names.at(1),
vec_names.at(2)};
break;
case 4:
vec_names = {vec_names.at(0),
vec_names.at(1),
vec_names.at(2),
vec_names.at(3),
vec_names.at(3),
vec_names.at(3)};
break;
case 5:
vec_names = {vec_names.at(0),
vec_names.at(1),
vec_names.at(2),
vec_names.at(3),
vec_names.at(4),
vec_names.at(4)};
break;
default: break;
}
}
else
{
switch(vec_names.size())
{
case 1: vec_names = {vec_names.at(0), vec_names.at(0), vec_names.at(0)}; break;
case 2:
// repeat the 2nd actv func once, so we have 3 actv funcs
vec_names = {vec_names.at(0), vec_names.at(1), vec_names.at(1)};
break;
default: break;
}
}
auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
return (map_actv_funcs.count(name) == 0);
});
if(name_it != vec_names.end())
{
MIGRAPHX_THROW("LSTM: activation function " + std::string(*name_it) + " not supported");
}
std::vector<operation> vec_actv_funcs(vec_names.size());
std::transform(vec_names.begin(), vec_names.end(), vec_actv_funcs.begin(), [&](auto& name) {
return map_actv_funcs[name];
});
float clip = 0.0;
if(contains(attributes, "clip"))
{
clip = parse_value(attributes.at("clip")).at<float>();
}
int input_forget = 0;
if(contains(attributes, "input_forget"))
{
input_forget = parse_value(attributes.at("input_forget")).at<int>();
}
// append undefined operators so there are 8 arguments
if(args.size() < 8)
{
auto ins = prog.add_instruction(op::undefined{});
args.insert(args.end(), 8 - args.size(), ins);
}
// first output for concatenation of hidden states
auto hidden_states = prog.add_instruction(
op::lstm{hidden_size, vec_actv_funcs, dirct, clip, input_forget}, std::move(args));
// second output for last lstm output
auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);
// third output for last cell output
auto last_cell_output = prog.add_instruction(op::lstm_last_cell_output{}, hidden_states);
return {hidden_states, last_output, last_cell_output};
}
void parse_from(std::istream& is)
{
onnx::ModelProto model;
......@@ -960,9 +1165,9 @@ struct onnx_parser
instructions[name] = prog.add_parameter(name, s);
}
}
for(auto&& p : nodes)
for(auto&& output : graph.output())
{
this->parse_node(p.first);
this->parse_node(output.name());
}
}
......
......@@ -2,6 +2,7 @@
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/target.hpp>
#include <migraphx/env.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/time.hpp>
......
......@@ -16,11 +16,14 @@ void rewrite_rnn::apply(program& prog) const
{
apply_vanilla_rnn(prog, ins);
}
if(ins->name() == "gru")
else if(ins->name() == "gru")
{
apply_gru(prog, ins);
}
else if(ins->name() == "lstm")
{
apply_lstm(prog, ins);
}
}
}
......@@ -664,5 +667,507 @@ std::vector<operation> rewrite_rnn::gru_actv_funcs(instruction_ref ins) const
}
}
// for lstm operators
void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
{
assert(ins->name() == "lstm");
auto args = ins->inputs();
shape seq_shape = args[0]->get_shape();
std::size_t hidden_size = args[2]->get_shape().lens()[2];
std::size_t batch_size = seq_shape.lens()[1];
shape::type_t type = seq_shape.type();
migraphx::shape ihc_shape{type, {1, batch_size, hidden_size}};
std::vector<float> ihc_data(ihc_shape.elements(), 0.0);
migraphx::shape pph_shape{type, {1, 3 * hidden_size}};
std::vector<float> pph_data(pph_shape.elements(), 0.0);
auto actv_funcs = lstm_actv_funcs(ins);
auto lstm_op = any_cast<op::lstm>(ins->get_operator());
op::rnn_direction dirct = lstm_op.direction;
instruction_ref last_output{};
instruction_ref last_cell_output{};
if(dirct == op::rnn_direction::bidirectional)
{
// input weight matrix
auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
auto w_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]);
// hidden state weight matrix
auto r_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]);
auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);
// process bias
instruction_ref bias_forward = prog.end();
instruction_ref bias_reverse = prog.end();
if(args.size() >= 4 && args[3]->name() != "undefined")
{
bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
}
// process initial hidden state, it is the 6th argument
instruction_ref ih_forward{};
instruction_ref ih_reverse{};
if(args.size() >= 6 && args[5]->name() != "undefined")
{
ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]);
ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[5]);
}
else
{
ih_forward = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
ih_reverse = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
}
// process initial cell value
instruction_ref ic_forward{};
instruction_ref ic_reverse{};
if(args.size() >= 7 && args[6]->name() != "undefined")
{
ic_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[6]);
ic_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[6]);
}
else
{
ic_forward = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
ic_reverse = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
}
// process weight of the peephole
instruction_ref pph_forward = prog.end();
instruction_ref pph_reverse = prog.end();
if(args.size() == 8 && args[7]->name() != "undefined")
{
pph_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[7]);
pph_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[7]);
}
auto ret_forward = lstm_cell(
true,
prog,
ins,
{args[0], w_forward, r_forward, bias_forward, ih_forward, ic_forward, pph_forward},
actv_funcs.at(0),
actv_funcs.at(1),
actv_funcs.at(2));
auto ret_reverse = lstm_cell(
false,
prog,
ins,
{args[0], w_reverse, r_reverse, bias_reverse, ih_reverse, ic_reverse, pph_reverse},
actv_funcs.at(3),
actv_funcs.at(4),
actv_funcs.at(5));
auto concat_output =
prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output);
// last cell output
last_cell_output =
prog.insert_instruction(ins, op::concat{0}, ret_forward[2], ret_reverse[2]);
// the following logic is to ensure the last instruction is a concat
if(ret_forward[0] == prog.end())
{
prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
}
else
{
ret_forward[0] =
prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
ret_reverse[0] =
prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
}
}
else
{
bool is_forward = (dirct == op::rnn_direction::forward);
// weight matrices
auto w = args[1];
auto r = args[2];
// bias
instruction_ref bias = prog.end();
if(args.size() >= 4 && args[3]->name() != "undefined")
{
bias = args[3];
}
// initial hidden state
instruction_ref ih{};
if(args.size() >= 6 && args[5]->name() != "undefined")
{
ih = args[5];
}
else
{
ih = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
}
// initial cell value
instruction_ref ic{};
if(args.size() >= 7 && args[6]->name() != "undefined")
{
ic = args[6];
}
else
{
ic = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
}
// process weight of the peephole
instruction_ref pph = prog.end();
if(args.size() == 8 && args[7]->name() != "undefined")
{
pph = args[7];
}
auto ret = lstm_cell(is_forward,
prog,
ins,
{args[0], w, r, bias, ih, ic, pph},
actv_funcs.at(0),
actv_funcs.at(1),
actv_funcs.at(2));
last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);
last_cell_output = ret[2];
if(ret[0] == prog.end())
{
prog.replace_instruction(ins, op::concat{0}, ret[1]);
}
else
{
auto concat_arg0 = is_forward ? ret[0] : ret[1];
auto concat_arg1 = is_forward ? ret[1] : ret[0];
prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
}
}
// replace each corresponding rnn_last_output instruction
// with last_output, and each lstm_last_cell_output with
// last_cell_output. The while loops handle the case of
// multiple rnn_last_output and lstm_last_cell_output
// operators
auto last_output_it = ins->outputs().begin();
while(last_output_it != ins->outputs().end())
{
last_output_it = std::find_if(last_output_it, ins->outputs().end(), [](auto i) {
return i->name() == "rnn_last_output";
});
if(last_output_it != ins->outputs().end())
{
prog.replace_instruction(*last_output_it, last_output);
last_output_it++;
}
}
auto last_cell_output_it = ins->outputs().begin();
while(last_cell_output_it != ins->outputs().end())
{
last_cell_output_it = std::find_if(last_cell_output_it, ins->outputs().end(), [](auto i) {
return i->name() == "lstm_last_cell_output";
});
if(last_cell_output_it != ins->outputs().end())
{
prog.replace_instruction(*last_cell_output_it, last_cell_output);
last_cell_output_it++;
}
}
}
std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
program& prog,
instruction_ref ins,
std::vector<instruction_ref> inputs,
const operation& actv_func1,
const operation& actv_func2,
const operation& actv_func3) const
{
// must have 7 args in the input vector
assert(inputs.size() == 7);
auto seq = inputs.at(0);
auto w = inputs.at(1);
auto r = inputs.at(2);
auto bias = inputs.at(3);
auto ih = inputs.at(4);
auto ic = inputs.at(5);
auto pph = inputs.at(6);
instruction_ref hidden_states = prog.end();
instruction_ref last_output{};
instruction_ref last_cell_output{};
migraphx::shape seq_shape = seq->get_shape();
migraphx::shape r_shape = r->get_shape();
long seq_len = static_cast<long>(seq_shape.lens()[0]);
long hs = static_cast<long>(r_shape.lens()[2]);
std::vector<int64_t> perm{1, 0};
// w matrix
auto sw = prog.insert_instruction(ins, op::squeeze{{0}}, w);
auto wi = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sw);
auto tran_wi = prog.insert_instruction(ins, op::transpose{perm}, wi);
auto wo = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sw);
auto tran_wo = prog.insert_instruction(ins, op::transpose{perm}, wo);
auto wf = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sw);
auto tran_wf = prog.insert_instruction(ins, op::transpose{perm}, wf);
auto wc = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sw);
auto tran_wc = prog.insert_instruction(ins, op::transpose{perm}, wc);
// r matrix
auto sr = prog.insert_instruction(ins, op::squeeze{{0}}, r);
auto ri = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sr);
auto tran_ri = prog.insert_instruction(ins, op::transpose{perm}, ri);
auto ro = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sr);
auto tran_ro = prog.insert_instruction(ins, op::transpose{perm}, ro);
auto rf = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sr);
auto tran_rf = prog.insert_instruction(ins, op::transpose{perm}, rf);
auto rc = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sr);
auto tran_rc = prog.insert_instruction(ins, op::transpose{perm}, rc);
// initial hidden state
auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
// initial cell state
auto sic = prog.insert_instruction(ins, op::squeeze{{0}}, ic);
auto ic_shape = sic->get_shape();
// bias
instruction_ref bi_brcst{};
instruction_ref bo_brcst{};
instruction_ref bf_brcst{};
instruction_ref bc_brcst{};
if(bias != prog.end())
{
auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
auto bxi = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
auto bhi = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {5 * hs}}, sbias);
auto bi = prog.insert_instruction(ins, op::add{}, bxi, bhi);
bi_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, bi);
auto bxo = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
auto bho = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
auto bo = prog.insert_instruction(ins, op::add{}, bxo, bho);
bo_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, bo);
auto bxf = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sbias);
auto bhf = prog.insert_instruction(ins, op::slice{{0}, {6 * hs}, {7 * hs}}, sbias);
auto bf = prog.insert_instruction(ins, op::add{}, bxf, bhf);
bf_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, bf);
auto bxc = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sbias);
auto bhc = prog.insert_instruction(ins, op::slice{{0}, {7 * hs}, {8 * hs}}, sbias);
auto bc = prog.insert_instruction(ins, op::add{}, bxc, bhc);
bc_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, bc);
}
// peephole weights
instruction_ref pphi_brcst{};
instruction_ref ppho_brcst{};
instruction_ref pphf_brcst{};
if(pph != prog.end())
{
auto spph = prog.insert_instruction(ins, op::squeeze{{0}}, pph);
auto pphi = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, spph);
pphi_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, pphi);
pphi_brcst = prog.insert_instruction(ins, op::contiguous{}, pphi_brcst);
auto ppho = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, spph);
ppho_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, ppho);
ppho_brcst = prog.insert_instruction(ins, op::contiguous{}, ppho_brcst);
auto pphf = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, spph);
pphf_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, pphf);
pphf_brcst = prog.insert_instruction(ins, op::contiguous{}, pphf_brcst);
}
for(long i = 0; i < seq_len; ++i)
{
long seq_index = is_forward ? i : (seq_len - 1 - i);
auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, seq);
xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
// equation it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
auto xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_wi);
auto ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_ri);
auto it_before_actv = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri);
if(pph != prog.end())
{
auto pphi_ct = prog.insert_instruction(ins, op::mul{}, pphi_brcst, sic);
it_before_actv = prog.insert_instruction(ins, op::add{}, it_before_actv, pphi_ct);
}
if(bias != prog.end())
{
it_before_actv = prog.insert_instruction(ins, op::add{}, it_before_actv, bi_brcst);
}
auto it = prog.insert_instruction(ins, actv_func1, it_before_actv);
// equation ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
auto xt_wf = prog.insert_instruction(ins, op::dot{}, xt, tran_wf);
auto ht_rf = prog.insert_instruction(ins, op::dot{}, sih, tran_rf);
auto ft_before_actv = prog.insert_instruction(ins, op::add{}, xt_wf, ht_rf);
if(pph != prog.end())
{
auto pphf_ct = prog.insert_instruction(ins, op::mul{}, pphf_brcst, sic);
ft_before_actv = prog.insert_instruction(ins, op::add{}, ft_before_actv, pphf_ct);
}
if(bias != prog.end())
{
ft_before_actv = prog.insert_instruction(ins, op::add{}, ft_before_actv, bf_brcst);
}
auto ft = prog.insert_instruction(ins, actv_func1, ft_before_actv);
// equation ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
auto xt_wc = prog.insert_instruction(ins, op::dot{}, xt, tran_wc);
auto ht_rc = prog.insert_instruction(ins, op::dot{}, sih, tran_rc);
auto ct_before_actv = prog.insert_instruction(ins, op::add{}, xt_wc, ht_rc);
if(bias != prog.end())
{
ct_before_actv = prog.insert_instruction(ins, op::add{}, ct_before_actv, bc_brcst);
}
auto ct = prog.insert_instruction(ins, actv_func2, ct_before_actv);
// equation Ct = ft (.) Ct-1 + it (.) ct
auto ft_cell = prog.insert_instruction(ins, op::mul{}, ft, sic);
auto it_ct = prog.insert_instruction(ins, op::mul{}, it, ct);
auto cellt = prog.insert_instruction(ins, op::add{}, ft_cell, it_ct);
last_cell_output = cellt;
// ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
auto xt_wo = prog.insert_instruction(ins, op::dot{}, xt, tran_wo);
auto ht_ro = prog.insert_instruction(ins, op::dot{}, sih, tran_ro);
auto ot_before_actv = prog.insert_instruction(ins, op::add{}, xt_wo, ht_ro);
if(pph != prog.end())
{
auto ppho_cellt = prog.insert_instruction(ins, op::mul{}, ppho_brcst, cellt);
ot_before_actv = prog.insert_instruction(ins, op::add{}, ot_before_actv, ppho_cellt);
}
if(bias != prog.end())
{
ot_before_actv = prog.insert_instruction(ins, op::add{}, ot_before_actv, bo_brcst);
}
auto ot = prog.insert_instruction(ins, actv_func1, ot_before_actv);
// Ht = ot (.) h(Ct)
auto h_cellt = prog.insert_instruction(ins, actv_func3, cellt);
auto ht = prog.insert_instruction(ins, op::mul{}, ot, h_cellt);
sic = cellt;
sih = ht;
last_output = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, ht);
if(i < seq_len - 1)
{
if(i == 0)
{
hidden_states = last_output;
}
else
{
auto concat_arg0 = is_forward ? hidden_states : last_output;
auto concat_arg1 = is_forward ? last_output : hidden_states;
hidden_states =
prog.insert_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
}
}
}
last_cell_output = prog.insert_instruction(ins, op::unsqueeze{{0}}, last_cell_output);
return {hidden_states, last_output, last_cell_output};
}
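// Illustrative sketch (not part of this commit): the per-timestep equations
// that lstm_cell above expands into graph instructions, written out for a
// single scalar unit. w*/r* are the input/recurrence weights, b* the combined
// biases, p* the peephole weights; f is the gate activation (sigmoid by
// default) and g/h are tanh by default, matching actv_func1/2/3 above.
#include <cmath>

struct lstm_state
{
    double h; // hidden state Ht
    double c; // cell state Ct
};

lstm_state lstm_step(double x, lstm_state prev,
                     double wi, double ri, double bi, double pi,
                     double wf, double rf, double bf, double pf,
                     double wc, double rc, double bc,
                     double wo, double ro, double bo, double po)
{
    auto sigmoid = [](double v) { return 1.0 / (1.0 + std::exp(-v)); };
    double it = sigmoid(x * wi + prev.h * ri + pi * prev.c + bi); // input gate
    double ft = sigmoid(x * wf + prev.h * rf + pf * prev.c + bf); // forget gate
    double ct = std::tanh(x * wc + prev.h * rc + bc);             // cell candidate
    double Ct = ft * prev.c + it * ct;                            // Ct = ft (.) Ct-1 + it (.) ct
    double ot = sigmoid(x * wo + prev.h * ro + po * Ct + bo);     // output gate
    double Ht = ot * std::tanh(Ct);                               // Ht = ot (.) h(Ct)
    return {Ht, Ct};
}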
std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const
{
auto lstm_op = any_cast<op::lstm>(ins->get_operator());
// before rewriting the lstm operator, we need to ensure
// there are 6 actv funcs (3 for a single direction), even if
// the user did not specify any. If fewer are given, use the
// same algorithm as in parse_lstm to fill in the actv functions
const auto& actv_funcs = lstm_op.actv_funcs;
std::size_t num_actv_funcs = actv_funcs.size();
if(lstm_op.direction == op::rnn_direction::bidirectional)
{
switch(num_actv_funcs)
{
case 0:
return {op::sigmoid{}, op::tanh{}, op::tanh{}, op::sigmoid{}, op::tanh{}, op::tanh{}};
case 1:
return {actv_funcs.at(0),
actv_funcs.at(0),
actv_funcs.at(0),
actv_funcs.at(0),
actv_funcs.at(0),
actv_funcs.at(0)};
case 2:
return {actv_funcs.at(0),
actv_funcs.at(1),
actv_funcs.at(1),
actv_funcs.at(0),
actv_funcs.at(1),
actv_funcs.at(1)};
case 3:
return {actv_funcs.at(0),
actv_funcs.at(1),
actv_funcs.at(2),
actv_funcs.at(0),
actv_funcs.at(1),
actv_funcs.at(2)};
case 4:
return {actv_funcs.at(0),
actv_funcs.at(1),
actv_funcs.at(2),
actv_funcs.at(3),
actv_funcs.at(3),
actv_funcs.at(3)};
case 5:
return {actv_funcs.at(0),
actv_funcs.at(1),
actv_funcs.at(2),
actv_funcs.at(3),
actv_funcs.at(4),
actv_funcs.at(4)};
default: return actv_funcs;
}
}
else
{
switch(num_actv_funcs)
{
case 0: return {op::sigmoid{}, op::tanh{}, op::tanh{}};
case 1: return {actv_funcs.at(0), actv_funcs.at(0), actv_funcs.at(0)};
case 2: return {actv_funcs.at(0), actv_funcs.at(1), actv_funcs.at(1)};
default: return actv_funcs;
}
}
}
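// Illustrative sketch (not part of this commit): the fill-in rule that
// lstm_actv_funcs above (and parse_lstm) applies, restated over plain
// strings. A single direction needs 3 functions (f, g, h); bidirectional
// needs 6. With at most 3 supplied, the triple is completed by repeating the
// last entry and then reused for the reverse direction; with 4 or 5 supplied,
// the last one is repeated until there are 6.
#include <string>
#include <vector>

std::vector<std::string> fill_lstm_activations(std::vector<std::string> names,
                                               bool bidirectional)
{
    if(names.empty())
        names = {"sigmoid", "tanh", "tanh"};
    if(names.size() <= 3)
    {
        while(names.size() < 3)
            names.push_back(names.back()); // e.g. {"sigmoid", "tanh"} -> {"sigmoid", "tanh", "tanh"}
        if(bidirectional)
        {
            std::vector<std::string> triple(names.begin(), names.begin() + 3);
            names.insert(names.end(), triple.begin(), triple.end());
        }
    }
    else if(bidirectional)
    {
        while(names.size() < 6)
            names.push_back(names.back()); // e.g. 4 supplied -> repeat the 4th twice
    }
    return names;
}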
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -19,7 +19,7 @@ struct shape_impl
shape_impl() : m_type(shape::float_type), m_standard(false) {}
shape_impl(shape::type_t t) : m_type(t), m_lens({1}), m_strides({1}), m_standard(true) {}
shape_impl(shape::type_t t) : m_type(t), m_lens({1}), m_strides({0}), m_standard(true) {}
shape_impl(shape::type_t t, std::vector<std::size_t> l)
: m_type(t), m_lens(std::move(l)), m_standard(true)
{
......
#include <migraphx/cpu/gemm.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/requires.hpp>
#include <migraphx/shape_for_each.hpp>
#include <blaze/math/CustomMatrix.h>
namespace migraphx {
......@@ -14,10 +15,13 @@ template <class T>
static auto make_mat(tensor_view<T> x)
{
const auto& s = x.get_shape();
assert(s.lens().size() == 2);
// assert(s.lens().size() == 2);
std::size_t n_dims = s.lens().size();
std::size_t dim_0 = n_dims - 2;
std::size_t dim_1 = n_dims - 1;
if(s.transposed())
return matrix<T>{x.data(), s.lens()[1], s.lens()[0], s.strides()[1]};
return matrix<T>{x.data(), s.lens()[0], s.lens()[1], s.strides()[0]};
return matrix<T>{x.data(), s.lens()[dim_1], s.lens()[dim_0], s.strides()[dim_1]};
return matrix<T>{x.data(), s.lens()[dim_0], s.lens()[dim_1], s.strides()[dim_0]};
}
template <class T, class F>
......@@ -64,18 +68,24 @@ void migemm_impl(tensor_view<T> cmat,
float beta,
std::false_type)
{
auto m = cmat.get_shape().lens()[0];
auto n = cmat.get_shape().lens()[1];
auto k = amat.get_shape().lens()[1];
std::size_t n_dims = cmat.get_shape().lens().size();
std::size_t dim_0 = n_dims - 2;
std::size_t dim_1 = n_dims - 1;
auto k = amat.get_shape().lens()[dim_1];
assert(amat.get_shape().lens()[1] == bmat.get_shape().lens()[0]);
assert(m == amat.get_shape().lens()[0]);
assert(n == bmat.get_shape().lens()[1]);
assert(amat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_0]);
assert(cmat.get_shape().lens()[dim_0] == amat.get_shape().lens()[dim_0]);
assert(cmat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_1]);
dfor(m, n)([&](auto ii, auto jj) {
double s = cmat(ii, jj) * beta;
dfor(k)([&](auto kk) { s += amat(ii, kk) * bmat(kk, jj); });
cmat(ii, jj) = alpha * s;
shape_for_each(cmat.get_shape(), [&](const auto& c_idx) {
auto a_idx = c_idx;
auto b_idx = c_idx;
double s = 0.0;
dfor(k)([&](auto kk) {
a_idx[dim_1] = b_idx[dim_0] = kk;
s += amat(a_idx.begin(), a_idx.end()) * bmat(b_idx.begin(), b_idx.end());
});
cmat(c_idx.begin(), c_idx.end()) = alpha * s + cmat(c_idx.begin(), c_idx.end()) * beta;
});
}
......@@ -83,7 +93,18 @@ template <class T>
void migemm_impl(
tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, float alpha, float beta)
{
migemm_impl(cmat, amat, bmat, alpha, beta, is_fast_gemm_type<T>{});
auto lens = amat.get_shape().lens();
bool batch_mul =
std::accumulate(lens.begin(), lens.end(), std::size_t{1}, std::multiplies<std::size_t>()) ==
(*lens.rbegin()) * (*(lens.rbegin() + 1));
if(batch_mul)
{
migemm_impl(cmat, amat, bmat, alpha, beta, is_fast_gemm_type<T>{});
}
else
{
migemm_impl(cmat, amat, bmat, alpha, beta, std::false_type{});
}
}
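// Illustrative sketch (not part of this commit): the dispatch check used just
// above -- the fast 2-D gemm path applies only when every dimension except
// the last two is 1, i.e. the total element count equals the product of the
// last two lens (assumes rank >= 2), e.g. {1, 1, 4, 5} qualifies, {2, 4, 5}
// does not.
#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

bool is_effectively_2d(const std::vector<std::size_t>& lens)
{
    auto total = std::accumulate(
        lens.begin(), lens.end(), std::size_t{1}, std::multiplies<std::size_t>());
    return total == lens[lens.size() - 1] * lens[lens.size() - 2];
}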
void migemm(
......
......@@ -7,6 +7,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct pass;
namespace cpu {
struct target
......
......@@ -613,6 +613,75 @@ struct softmax2d
}
};
struct cpu_logsoftmax
{
op::logsoftmax op;
std::string name() const { return "cpu::logsoftmax"; }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
template <typename T>
std::size_t compute_batch_index(const T& idx, shape& batch_shape, int axis) const
{
if(axis == 0)
{
return 0;
}
else
{
std::vector<std::size_t> batch_idx(idx.begin(), idx.begin() + axis);
return batch_shape.index(batch_idx.begin(), batch_idx.end());
}
}
argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
auto lens = output_shape.lens();
std::vector<std::size_t> batch_lens{};
if(op.axis == 0)
{
batch_lens.push_back(1);
}
else
{
batch_lens.insert(batch_lens.begin(), lens.begin(), lens.begin() + op.axis);
}
shape batch_shape{migraphx::shape::uint32_type, batch_lens};
visit_all(result, args[0])([&](auto output, auto input) {
using value_type = typename decltype(input)::value_type;
std::vector<value_type> batch_max(batch_shape.elements(),
std::numeric_limits<value_type>::lowest());
shape_for_each(output_shape, [&](auto idx) {
auto index = this->compute_batch_index(idx, batch_shape, op.axis);
batch_max[index] = std::max(batch_max[index], input(idx.begin(), idx.end()));
});
shape_for_each(output_shape, [&](auto idx) {
auto index = this->compute_batch_index(idx, batch_shape, op.axis);
output(idx.begin(), idx.end()) = input(idx.begin(), idx.end()) - batch_max[index];
});
std::vector<value_type> batch_sum(batch_shape.elements(), value_type(0));
shape_for_each(output_shape, [&](auto idx) {
auto index = this->compute_batch_index(idx, batch_shape, op.axis);
batch_sum[index] += std::exp(output(idx.begin(), idx.end()));
});
for(std::size_t i = 0; i < batch_sum.size(); ++i)
{
batch_sum[i] = std::log(batch_sum[i]);
}
shape_for_each(output_shape, [&](auto idx) {
auto index = this->compute_batch_index(idx, batch_shape, op.axis);
output(idx.begin(), idx.end()) -= batch_sum[index];
});
});
return result;
}
};
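// Illustrative sketch (not part of this commit): the numerically stable
// log-softmax that cpu_logsoftmax above computes per batch slice -- subtract
// the max, then subtract log(sum(exp(.))) of the shifted values (assumes a
// non-empty row).
#include <algorithm>
#include <cmath>
#include <vector>

void log_softmax_row(std::vector<double>& row)
{
    double m   = *std::max_element(row.begin(), row.end());
    double sum = 0.0;
    for(auto& v : row)
    {
        v -= m; // shift by the max for numerical stability
        sum += std::exp(v);
    }
    double log_sum = std::log(sum);
    for(auto& v : row)
        v -= log_sum; // log_softmax(x) = x - max - log(sum(exp(x - max)))
}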
struct add_op
{
std::string name() const { return "add"; }
......@@ -723,6 +792,7 @@ struct cpu_apply
apply_map["pad"] = extend_op<cpu_pad, op::pad>();
apply_map["concat"] = extend_op<cpu_concat, op::concat>();
apply_map["gather"] = extend_op<cpu_gather, op::gather>();
apply_map["logsoftmax"] = extend_op<cpu_logsoftmax, op::logsoftmax>();
apply_map["leaky_relu"] = extend_op<cpu_unary<leaky_relu_op>, op::leaky_relu>();
apply_map["elu"] = extend_op<cpu_unary<elu_op>, op::elu>();
apply_map["identity"] = simple_op<cpu_unary<identity_op>>();
......
#include <migraphx/cpu/target.hpp>
#include <migraphx/cpu/lowering.hpp>
#include <migraphx/pass.hpp>
#include <migraphx/auto_contiguous.hpp>
#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/dead_code_elimination.hpp>
......
......@@ -26,6 +26,7 @@ add_library(migraphx_device
device/atan.cpp
device/add_relu.cpp
device/contiguous.cpp
device/logsoftmax.cpp
device/mul.cpp
device/concat.cpp
device/pad.cpp
......@@ -48,6 +49,7 @@ add_library(migraphx_gpu
pooling.cpp
convolution.cpp
softmax.cpp
logsoftmax.cpp
contiguous.cpp
concat.cpp
relu.cpp
......
#include <migraphx/gpu/abs.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <utility>
#include <migraphx/gpu/context.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
#include <migraphx/gpu/batchnorm.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <utility>
#include <migraphx/gpu/context.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
#include <migraphx/gpu/concat.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device/concat.hpp>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
#include <migraphx/gpu/contiguous.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <utility>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device/contiguous.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
#include <migraphx/gpu/convolution.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <utility>
#include <migraphx/gpu/context.hpp>
#include <migraphx/generate.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -16,20 +16,24 @@ argument gather(hipStream_t stream,
std::vector<migraphx::argument> args,
int axis)
{
int axis_index = (axis < 0) ? (axis + output_shape.lens().size()) : axis;
int axis_index = (axis < 0) ? (axis + args[0].get_shape().lens().size()) : axis;
visit_all(args.back(), args[0])([&](auto output, auto input) {
std::size_t nelements = output_shape.elements();
args[1].visit([&](auto indices) {
visit_tensor_size(output_shape.lens().size(), [&](auto ndim) {
const auto* indices_ptr = device_cast(indices.data());
auto* outptr = device_cast(output.data());
const auto* inptr = device_cast(input.data());
hip_tensor_descriptor<ndim> desc_input(input.get_shape());
hip_tensor_descriptor<ndim> desc_output(output.get_shape());
gs_launch(stream, nelements)([=](auto i) {
auto lens = desc_output.multi(i);
lens[axis_index] = indices_ptr[lens[axis_index]];
outptr[i] = inptr[desc_input.linear(lens)];
const auto* indices_ptr = device_cast(indices.data());
auto* out_ptr = device_cast(output.data());
const auto* in_ptr = device_cast(input.data());
auto& input_shape = args[0].get_shape();
auto lens = input_shape.lens();
lens[axis_index] = args[1].get_shape().elements();
migraphx::shape out_comp_shape{output_shape.type(), lens};
visit_tensor_size(out_comp_shape.lens().size(), [&](auto n_out_dim) {
hip_tensor_descriptor<n_out_dim> desc_input(input_shape);
hip_tensor_descriptor<n_out_dim> desc_output(out_comp_shape);
gs_launch(stream, nelements)([=](auto ii) {
auto in_idx = desc_output.multi(ii);
in_idx[axis_index] = indices_ptr[in_idx[axis_index]];
out_ptr[ii] = in_ptr[desc_input.linear(in_idx)];
});
});
});
......
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/logsoftmax.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/hip.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
argument logsoftmax(hipStream_t stream,
const migraphx::shape& output_shape,
std::vector<migraphx::argument> args,
int axis)
{
auto lens = output_shape.lens();
std::size_t batch_size = std::accumulate(
lens.begin(), lens.begin() + axis, std::size_t{1}, std::multiplies<std::size_t>());
std::size_t n_dims = std::accumulate(
lens.begin() + axis, lens.end(), std::size_t{1}, std::multiplies<std::size_t>());
migraphx::shape comp_shape{output_shape.type(), {batch_size, n_dims}};
visit_all(args.back(), args.front())([&](auto output, auto input) {
const auto* input_ptr = device_cast(input.data());
auto* output_ptr = device_cast(output.data());
// each thread is for one item in the batch
gs_launch(stream, batch_size)([=](auto i) {
std::size_t row_start = i * n_dims;
// get max
auto batch_max = input_ptr[row_start];
for(std::size_t j = 1; j < n_dims; ++j)
{
auto ind = row_start + j;
batch_max = std::max(to_hip_type(batch_max), to_hip_type(input_ptr[ind]));
}
for(std::size_t j = 0; j < n_dims; ++j)
{
auto ind = row_start + j;
output_ptr[ind] = input_ptr[ind] - batch_max;
}
auto batch_sum = ::exp(to_hip_type(output_ptr[row_start]));
for(std::size_t j = 1; j < n_dims; ++j)
{
auto ind = row_start + j;
batch_sum += ::exp(to_hip_type(output_ptr[ind]));
}
batch_sum = ::log(to_hip_type(batch_sum));
for(std::size_t j = 0; j < n_dims; ++j)
{
auto ind = row_start + j;
output_ptr[ind] -= batch_sum;
}
});
});
return args.back();
}
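// Illustrative sketch (not part of this commit): the flattening used above.
// For a packed standard layout, all dimensions before `axis` collapse into
// the batch count and the remaining ones into the per-batch row length, so
// each thread can walk one contiguous row; e.g. lens {2, 3, 4, 5} with
// axis == 2 gives batch_size = 6 and n_dims = 20.
#include <cstddef>
#include <functional>
#include <numeric>
#include <utility>
#include <vector>

std::pair<std::size_t, std::size_t> split_at_axis(const std::vector<std::size_t>& lens,
                                                  std::size_t axis)
{
    auto batch_size = std::accumulate(
        lens.begin(), lens.begin() + axis, std::size_t{1}, std::multiplies<std::size_t>());
    auto n_dims = std::accumulate(
        lens.begin() + axis, lens.end(), std::size_t{1}, std::multiplies<std::size_t>());
    return {batch_size, n_dims};
}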
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx