"vscode:/vscode.git/clone" did not exist on "e384b83f29ac30c422aff5833c413e6f7d3f7b08"
Commit a7408288 authored by Khalique's avatar Khalique
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into tf_pb

parents fa983162 c1fec2c4
......@@ -27,6 +27,13 @@ void eliminate_contiguous::apply(program& p) const
{
for(auto ins : iterator_for(p))
{
// skip the reshape operator for now, since there is a bug
// for the transpose followed by a reshape
if(ins->name() == "reshape")
{
continue;
}
// Make a copy so we can modify it while we iterate
auto args = ins->inputs();
for(auto arg : ins->inputs())
......
......@@ -22,8 +22,8 @@ struct literal : raw_data<literal>
{
literal() {}
template <class U, class T = deduce<U>>
literal(U x) : buffer(make_shared_array<char>(sizeof(T))), m_shape(shape::get_type<T>{})
template <class U, class T = deduce<U>, shape::type_t ShapeType = shape::get_type<T>{}>
literal(U x) : buffer(make_shared_array<char>(sizeof(T))), m_shape(ShapeType)
{
static_assert(std::is_trivially_copyable<T>{}, "Literals can only be trivial types");
*(reinterpret_cast<T*>(buffer.get())) = x;
......
......@@ -758,42 +758,48 @@ struct gather
int axis_index = (axis < 0) ? (n_dim + axis) : axis;
auto type = inputs[0].type();
lens[axis_index] = inputs[1].elements();
return {type, lens};
lens.erase(lens.begin() + axis_index);
if(!inputs[1].scalar())
{
auto ind_lens = inputs[1].lens();
lens.insert(lens.begin() + axis_index, ind_lens.begin(), ind_lens.end());
}
template <class T>
void compute_index(const T& out_idx,
const int axis_index,
const std::vector<std::size_t>& vec_indices,
const std::size_t max_dim,
T& in_idx) const
// for scalar output
if(lens.empty())
{
in_idx = out_idx;
std::size_t idx = vec_indices.at(out_idx[axis_index]);
if(idx >= max_dim)
{
MIGRAPHX_THROW("Gather: indices are out of range in input tensor");
return {type};
}
in_idx[axis_index] = idx;
return {type, lens};
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
// negative axis means counting dimensions from back
int axis_index = (axis < 0) ? (output_shape.lens().size() + axis) : axis;
int axis_index =
(axis < 0) ? static_cast<int>(args[0].get_shape().lens().size() + axis) : axis;
// max dimension in axis
std::size_t max_dim = args[0].get_shape().lens()[axis_index];
std::vector<std::size_t> vec_indices;
args[1].visit([&](auto indices) { vec_indices.assign(indices.begin(), indices.end()); });
visit_all(result, args[0])([&](auto output, auto input) {
std::vector<std::size_t> in_idx;
shape_for_each(output.get_shape(), [&](const auto& idx) {
this->compute_index(idx, axis_index, vec_indices, max_dim, in_idx);
output(idx.begin(), idx.end()) = input(in_idx.begin(), in_idx.end());
visit_all(result, args[0])([&](auto output, auto data) {
args[1].visit([&](auto indices) {
if(output_shape.scalar())
{
output[0] = data[indices.front()];
}
else
{
auto out_lens = data.get_shape().lens();
out_lens[axis_index] = indices.get_shape().elements();
migraphx::shape out_comp_shape{data.get_shape().type(), out_lens};
shape_for_each(out_comp_shape, [&](const auto& out_idx) {
auto data_idx = out_idx;
data_idx[axis_index] = indices[data_idx[axis_index]];
output[out_comp_shape.index(out_idx.begin(), out_idx.end())] =
data(data_idx.begin(), data_idx.end());
});
}
});
});
......@@ -820,10 +826,22 @@ struct dot
const shape& b = inputs.at(1);
auto t = a.type();
if(a.lens()[1] != b.lens()[0])
// according to the specification of the numpy.matmul()
// inputs with the shape dims more than 2 are acceptable
// as long as dim values are the same in the two inputs
if(!std::equal(a.lens().rbegin() + 2, a.lens().rend(), b.lens().rbegin() + 2))
{
MIGRAPHX_THROW("DOT: dim values mismatch");
}
std::size_t dim_0 = a.lens().size() - 2;
std::size_t dim_1 = a.lens().size() - 1;
if(a.lens()[dim_1] != b.lens()[dim_0])
MIGRAPHX_THROW("Inner dimensions do not match: {" + to_string_range(a.lens()) +
"} x {" + to_string_range(b.lens()) + "}");
return {t, {a.lens()[0], b.lens()[1]}};
auto out_lens = a.lens();
out_lens[dim_1] = b.lens()[dim_1];
return {t, out_lens};
}
};
......@@ -932,6 +950,22 @@ struct softmax
}
};
struct logsoftmax
{
int axis = 1;
std::string name() const { return "logsoftmax"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1);
if(axis < 0 || axis > inputs[0].lens().size())
{
MIGRAPHX_THROW("LogSoftMax: input axis value " + std::to_string(axis) +
" is out of range");
}
return inputs.at(0);
}
};
struct flatten
{
uint64_t axis = 0;
......@@ -1259,6 +1293,57 @@ struct gru
}
};
struct lstm
{
std::size_t hidden_size = 1;
std::vector<operation> actv_funcs{sigmoid{}, tanh{}, tanh{}};
rnn_direction direction = rnn_direction::forward;
float clip = 0.0f;
int input_forget = 0;
std::string name() const { return "lstm"; }
shape compute_shape(std::vector<shape> inputs) const
{
auto in_dims = inputs[0].lens();
auto hidden_dims = inputs[2].lens();
if(hidden_size != hidden_dims[2])
{
MIGRAPHX_THROW("LSTM: hidden size mismatch in attribute and input");
}
std::size_t num_directions = 1;
if(direction == rnn_direction::bidirectional)
{
num_directions = 2;
}
if(num_directions != hidden_dims[0])
{
MIGRAPHX_THROW("LSTM: num_direction does not match the direction attribute");
}
std::vector<std::size_t> out_dims(in_dims);
out_dims.insert(out_dims.begin() + 1, num_directions);
out_dims.back() = hidden_size;
return {inputs[0].type(), out_dims};
}
};
struct lstm_last_cell_output
{
std::string name() const { return "lstm_last_cell_output"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
auto dims = inputs[0].lens();
// remove the first dimension, remaing are output shape
dims.erase(dims.begin());
return {inputs[0].type(), dims};
}
};
struct undefined
{
std::string name() const { return "undefined"; }
......
......@@ -45,6 +45,18 @@ struct rewrite_rnn
const operation& actv_func2) const;
std::vector<operation> gru_actv_funcs(instruction_ref ins) const;
// for lstm operators
void apply_lstm(program& prog, instruction_ref ins) const;
std::vector<instruction_ref> lstm_cell(bool is_forward,
program& prog,
instruction_ref ins,
std::vector<instruction_ref> inputs,
const operation& actv_func1,
const operation& actv_func2,
const operation& actv_func3) const;
std::vector<operation> lstm_actv_funcs(instruction_ref ins) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -79,6 +79,7 @@ struct onnx_parser
add_mem_op("Gemm", &onnx_parser::parse_gemm);
add_mem_op("BatchNormalization", &onnx_parser::parse_batchnorm);
add_mem_op("Softmax", &onnx_parser::parse_softmax);
add_mem_op("LogSoftmax", &onnx_parser::parse_logsoftmax);
add_mem_op("Squeeze", &onnx_parser::parse_squeeze);
add_mem_op("Unsqueeze", &onnx_parser::parse_unsqueeze);
add_mem_op("Slice", &onnx_parser::parse_slice);
......@@ -89,6 +90,7 @@ struct onnx_parser
add_mem_op("Transpose", &onnx_parser::parse_transpose);
add_mem_op("RNN", &onnx_parser::parse_rnn);
add_mem_op("GRU", &onnx_parser::parse_gru);
add_mem_op("LSTM", &onnx_parser::parse_lstm);
add_mem_op("Pad", &onnx_parser::parse_pad);
// init the activation function map
......@@ -227,6 +229,19 @@ struct onnx_parser
return prog.add_instruction(op::reshape{{long(dims[0]), long(dims[1])}}, s);
}
instruction_ref parse_logsoftmax(const std::string&,
const attribute_map& attributes,
std::vector<instruction_ref> args)
{
int axis = 1;
if(contains(attributes, "axis"))
{
axis = parse_value(attributes.at("axis")).at<int>();
}
return prog.add_instruction(op::logsoftmax{axis}, std::move(args));
}
instruction_ref
parse_conv(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
......@@ -354,7 +369,9 @@ struct onnx_parser
}
if(args.size() == 2)
{
literal s = args[1]->get_literal();
auto s = args[1]->eval();
if(s.empty())
MIGRAPHX_THROW("Dynamic shape is not supported.");
s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
}
return prog.add_instruction(op, args[0]);
......@@ -434,6 +451,14 @@ struct onnx_parser
const std::vector<instruction_ref>&)
{
literal v = parse_value(attributes.at("value"));
auto dim_size = attributes.at("value").t().dims_size();
// if dim_size is 0, it is a scalar
if(dim_size == 0)
{
migraphx::shape scalar_shape{v.get_shape().type()};
return prog.add_literal(migraphx::literal{scalar_shape, v.data()});
}
return prog.add_literal(v);
}
......@@ -460,7 +485,12 @@ struct onnx_parser
{
transb = parse_value(attributes.at("transB")).at<bool>();
}
std::vector<int64_t> perm = {1, 0};
std::vector<int64_t> perm(args[0]->get_shape().lens().size());
std::iota(perm.begin(), perm.end(), int64_t{0});
// swap the last two elements
std::swap(*perm.rbegin(), *(perm.rbegin() + 1));
auto l1 = (transa) ? prog.add_instruction(op::transpose{perm}, args[0]) : args[0];
auto l2 = (transb) ? prog.add_instruction(op::transpose{perm}, args[1]) : args[1];
if(args.size() == 3)
......@@ -480,6 +510,7 @@ struct onnx_parser
return add_broadcastable_binary_op(l3, l4, op::add{});
}
}
return prog.add_instruction(op::dot{alpha, beta}, l1, l2);
}
......@@ -749,15 +780,17 @@ struct onnx_parser
{
auto names = attributes.at("activations").strings();
vec_names.clear();
for_each(names.begin(), names.end(), [&](auto& fn) { vec_names.push_back(fn); });
vec_names.resize(names.size());
std::copy(names.begin(), names.end(), vec_names.begin());
}
for_each(vec_names.begin(), vec_names.end(), [&](auto& fn) {
if(map_actv_funcs.count(fn) == 0)
auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
return (map_actv_funcs.count(name) == 0);
});
if(name_it != vec_names.end())
{
MIGRAPHX_THROW("RNN: activation function " + std::string(fn) + " not supported");
MIGRAPHX_THROW("RNN: activation function " + std::string(*name_it) + " not supported");
}
});
// bidirectional case should have two activation functions.
// one is for forward, and the other is for reverse.
......@@ -839,8 +872,7 @@ struct onnx_parser
auto names = attributes.at("activations").strings();
vec_names.clear();
vec_names.resize(names.size());
std::transform(
names.begin(), names.end(), vec_names.begin(), [](auto& str) { return str; });
std::copy(names.begin(), names.end(), vec_names.begin());
}
// need 4 activation functions
......@@ -878,12 +910,13 @@ struct onnx_parser
}
}
for_each(vec_names.begin(), vec_names.end(), [&](auto& name) {
if(map_actv_funcs.count(name) == 0)
auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
return (map_actv_funcs.count(name) == 0);
});
if(name_it != vec_names.end())
{
MIGRAPHX_THROW("GRU: activation function " + std::string(name) + " not supported");
MIGRAPHX_THROW("GRU: activation function " + std::string(*name_it) + " not supported");
}
});
std::vector<operation> vec_actv_funcs(vec_names.size());
std::transform(vec_names.begin(), vec_names.end(), vec_actv_funcs.begin(), [&](auto& name) {
......@@ -920,6 +953,178 @@ struct onnx_parser
return {hidden_states, last_output};
}
std::vector<instruction_ref>
parse_lstm(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
migraphx::shape input_shape = args[0]->get_shape();
std::size_t hidden_size = args[2]->get_shape().lens()[2];
if(contains(attributes, "hidden_size"))
{
std::size_t hidden_size_att = parse_value(attributes.at("hidden_size")).at<int>();
if(hidden_size != hidden_size_att)
{
MIGRAPHX_THROW("LSTM: hidden size mismatch in input and attribute");
}
}
// Handling of direction to be added later
std::string direction{"forward"};
if(contains(attributes, "direction"))
{
direction = attributes.at("direction").s();
}
op::rnn_direction dirct = op::rnn_direction::forward;
if(direction == "bidirectional")
{
dirct = op::rnn_direction::bidirectional;
}
else if(direction == "reverse")
{
dirct = op::rnn_direction::reverse;
}
else if(direction == "forward")
{
dirct = op::rnn_direction::forward;
}
else
{
MIGRAPHX_THROW("LSTM: incorrect direction attribute");
}
std::vector<std::string> vec_names = {"sigmoid", "tanh", "tanh"};
if(contains(attributes, "activations"))
{
auto names = attributes.at("activations").strings();
vec_names.clear();
vec_names.resize(names.size());
std::copy(names.begin(), names.end(), vec_names.begin());
}
// need 6 activation functions for bidirectional directions
if(dirct == op::rnn_direction::bidirectional)
{
// 6 activation functions are used in the bidirectional
// scenario. No spec is provided in onnx::operator. we
// use the algorithm that: if 1 actv function is provided,
// repeat 1st six times. If 2 actv functins are provided,
// repeat 2nd once, then repeat all three once
// if 3 actv funcs are provide, repeat all three once.
// the same algorithm is used for 4, 5, and 6 actv funcions
// provided. This may need change later
switch(vec_names.size())
{
case 1:
vec_names = {vec_names.at(0),
vec_names.at(0),
vec_names.at(0),
vec_names.at(0),
vec_names.at(0),
vec_names.at(0)};
break;
case 2:
// repeat the 2nd actv func once, then repeat all three another time
vec_names = {vec_names.at(0),
vec_names.at(1),
vec_names.at(1),
vec_names.at(0),
vec_names.at(1),
vec_names.at(1)};
break;
case 3:
// repeat all three actv funcs once
vec_names = {vec_names.at(0),
vec_names.at(1),
vec_names.at(2),
vec_names.at(0),
vec_names.at(1),
vec_names.at(2)};
break;
case 4:
vec_names = {vec_names.at(0),
vec_names.at(1),
vec_names.at(2),
vec_names.at(3),
vec_names.at(3),
vec_names.at(3)};
break;
case 5:
vec_names = {vec_names.at(0),
vec_names.at(1),
vec_names.at(2),
vec_names.at(3),
vec_names.at(4),
vec_names.at(4)};
break;
default: break;
}
}
else
{
switch(vec_names.size())
{
case 1: vec_names = {vec_names.at(0), vec_names.at(0), vec_names.at(0)}; break;
case 2:
// repeat the 2nd actv func once, so we have 3 actv funcs
vec_names = {vec_names.at(0), vec_names.at(1), vec_names.at(1)};
break;
default: break;
}
}
auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
return (map_actv_funcs.count(name) == 0);
});
if(name_it != vec_names.end())
{
MIGRAPHX_THROW("LSTM: activation function " + std::string(*name_it) + " not supported");
}
std::vector<operation> vec_actv_funcs(vec_names.size());
std::transform(vec_names.begin(), vec_names.end(), vec_actv_funcs.begin(), [&](auto& name) {
return map_actv_funcs[name];
});
float clip = 0.0;
if(contains(attributes, "clip"))
{
clip = parse_value(attributes.at("clip")).at<float>();
}
int input_forget = 0;
if(contains(attributes, "input_forget"))
{
input_forget = parse_value(attributes.at("input_forget")).at<int>();
}
// append undefined opeator to make 6 arguments
if(args.size() < 8)
{
auto ins = prog.add_instruction(op::undefined{});
args.insert(args.end(), 8 - args.size(), ins);
}
// first output for concatenation of hidden states
auto hidden_states = prog.add_instruction(
op::lstm{hidden_size, vec_actv_funcs, dirct, clip, input_forget}, std::move(args));
// second output for last lstm output
auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);
// third output for last cell output
auto last_cell_output = prog.add_instruction(op::lstm_last_cell_output{}, hidden_states);
return {hidden_states, last_output, last_cell_output};
}
void parse_from(std::istream& is)
{
onnx::ModelProto model;
......@@ -960,9 +1165,9 @@ struct onnx_parser
instructions[name] = prog.add_parameter(name, s);
}
}
for(auto&& p : nodes)
for(auto&& output : graph.output())
{
this->parse_node(p.first);
this->parse_node(output.name());
}
}
......
......@@ -2,6 +2,7 @@
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/target.hpp>
#include <migraphx/env.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/time.hpp>
......
This diff is collapsed.
......@@ -19,7 +19,7 @@ struct shape_impl
shape_impl() : m_type(shape::float_type), m_standard(false) {}
shape_impl(shape::type_t t) : m_type(t), m_lens({1}), m_strides({1}), m_standard(true) {}
shape_impl(shape::type_t t) : m_type(t), m_lens({1}), m_strides({0}), m_standard(true) {}
shape_impl(shape::type_t t, std::vector<std::size_t> l)
: m_type(t), m_lens(std::move(l)), m_standard(true)
{
......
#include <migraphx/cpu/gemm.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/requires.hpp>
#include <migraphx/shape_for_each.hpp>
#include <blaze/math/CustomMatrix.h>
namespace migraphx {
......@@ -14,10 +15,13 @@ template <class T>
static auto make_mat(tensor_view<T> x)
{
const auto& s = x.get_shape();
assert(s.lens().size() == 2);
// assert(s.lens().size() == 2);
std::size_t n_dims = s.lens().size();
std::size_t dim_0 = n_dims - 2;
std::size_t dim_1 = n_dims - 1;
if(s.transposed())
return matrix<T>{x.data(), s.lens()[1], s.lens()[0], s.strides()[1]};
return matrix<T>{x.data(), s.lens()[0], s.lens()[1], s.strides()[0]};
return matrix<T>{x.data(), s.lens()[dim_1], s.lens()[dim_0], s.strides()[dim_1]};
return matrix<T>{x.data(), s.lens()[dim_0], s.lens()[dim_1], s.strides()[dim_0]};
}
template <class T, class F>
......@@ -64,18 +68,24 @@ void migemm_impl(tensor_view<T> cmat,
float beta,
std::false_type)
{
auto m = cmat.get_shape().lens()[0];
auto n = cmat.get_shape().lens()[1];
auto k = amat.get_shape().lens()[1];
std::size_t n_dims = cmat.get_shape().lens().size();
std::size_t dim_0 = n_dims - 2;
std::size_t dim_1 = n_dims - 1;
auto k = amat.get_shape().lens()[dim_1];
assert(amat.get_shape().lens()[1] == bmat.get_shape().lens()[0]);
assert(m == amat.get_shape().lens()[0]);
assert(n == bmat.get_shape().lens()[1]);
assert(amat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_0]);
assert(cmat.get_shape().lens()[dim_0] == amat.get_shape().lens()[dim_0]);
assert(cmat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_1]);
dfor(m, n)([&](auto ii, auto jj) {
double s = cmat(ii, jj) * beta;
dfor(k)([&](auto kk) { s += amat(ii, kk) * bmat(kk, jj); });
cmat(ii, jj) = alpha * s;
shape_for_each(cmat.get_shape(), [&](const auto& c_idx) {
auto a_idx = c_idx;
auto b_idx = c_idx;
double s = 0.0;
dfor(k)([&](auto kk) {
a_idx[dim_1] = b_idx[dim_0] = kk;
s += amat(a_idx.begin(), a_idx.end()) * bmat(b_idx.begin(), b_idx.end());
});
cmat(c_idx.begin(), c_idx.end()) = alpha * s + cmat(c_idx.begin(), c_idx.end()) * beta;
});
}
......@@ -83,7 +93,18 @@ template <class T>
void migemm_impl(
tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, float alpha, float beta)
{
auto lens = amat.get_shape().lens();
bool batch_mul =
std::accumulate(lens.begin(), lens.end(), std::size_t{1}, std::multiplies<std::size_t>()) ==
(*lens.rbegin()) * (*(lens.rbegin() + 1));
if(batch_mul)
{
migemm_impl(cmat, amat, bmat, alpha, beta, is_fast_gemm_type<T>{});
}
else
{
migemm_impl(cmat, amat, bmat, alpha, beta, std::false_type{});
}
}
void migemm(
......
......@@ -7,6 +7,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct pass;
namespace cpu {
struct target
......
......@@ -613,6 +613,75 @@ struct softmax2d
}
};
struct cpu_logsoftmax
{
op::logsoftmax op;
std::string name() const { return "cpu::logsoftmax"; }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
template <typename T>
std::size_t compute_batch_index(const T& idx, shape& batch_shape, int axis) const
{
if(axis == 0)
{
return 0;
}
else
{
std::vector<std::size_t> batch_idx(idx.begin(), idx.begin() + axis);
return batch_shape.index(batch_idx.begin(), batch_idx.end());
}
}
argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
auto lens = output_shape.lens();
std::vector<std::size_t> batch_lens{};
if(op.axis == 0)
{
batch_lens.push_back(1);
}
else
{
batch_lens.insert(batch_lens.begin(), lens.begin(), lens.begin() + op.axis);
}
shape batch_shape{migraphx::shape::uint32_type, batch_lens};
visit_all(result, args[0])([&](auto output, auto input) {
using value_type = typename decltype(input)::value_type;
std::vector<value_type> batch_max(batch_shape.elements(),
std::numeric_limits<value_type>::lowest());
shape_for_each(output_shape, [&](auto idx) {
auto index = this->compute_batch_index(idx, batch_shape, op.axis);
batch_max[index] = std::max(batch_max[index], input(idx.begin(), idx.end()));
});
shape_for_each(output_shape, [&](auto idx) {
auto index = this->compute_batch_index(idx, batch_shape, op.axis);
output(idx.begin(), idx.end()) = input(idx.begin(), idx.end()) - batch_max[index];
});
std::vector<value_type> batch_sum(batch_shape.elements(), value_type(0));
shape_for_each(output_shape, [&](auto idx) {
auto index = this->compute_batch_index(idx, batch_shape, op.axis);
batch_sum[index] += std::exp(output(idx.begin(), idx.end()));
});
for(std::size_t i = 0; i < batch_sum.size(); ++i)
{
batch_sum[i] = std::log(batch_sum[i]);
}
shape_for_each(output_shape, [&](auto idx) {
auto index = this->compute_batch_index(idx, batch_shape, op.axis);
output(idx.begin(), idx.end()) -= batch_sum[index];
});
});
return result;
}
};
struct add_op
{
std::string name() const { return "add"; }
......@@ -723,6 +792,7 @@ struct cpu_apply
apply_map["pad"] = extend_op<cpu_pad, op::pad>();
apply_map["concat"] = extend_op<cpu_concat, op::concat>();
apply_map["gather"] = extend_op<cpu_gather, op::gather>();
apply_map["logsoftmax"] = extend_op<cpu_logsoftmax, op::logsoftmax>();
apply_map["leaky_relu"] = extend_op<cpu_unary<leaky_relu_op>, op::leaky_relu>();
apply_map["elu"] = extend_op<cpu_unary<elu_op>, op::elu>();
apply_map["identity"] = simple_op<cpu_unary<identity_op>>();
......
#include <migraphx/cpu/target.hpp>
#include <migraphx/cpu/lowering.hpp>
#include <migraphx/pass.hpp>
#include <migraphx/auto_contiguous.hpp>
#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/dead_code_elimination.hpp>
......
......@@ -26,6 +26,7 @@ add_library(migraphx_device
device/atan.cpp
device/add_relu.cpp
device/contiguous.cpp
device/logsoftmax.cpp
device/mul.cpp
device/concat.cpp
device/pad.cpp
......@@ -48,6 +49,7 @@ add_library(migraphx_gpu
pooling.cpp
convolution.cpp
softmax.cpp
logsoftmax.cpp
contiguous.cpp
concat.cpp
relu.cpp
......
#include <migraphx/gpu/abs.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <utility>
#include <migraphx/gpu/context.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
#include <migraphx/gpu/batchnorm.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <utility>
#include <migraphx/gpu/context.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
#include <migraphx/gpu/concat.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device/concat.hpp>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
#include <migraphx/gpu/contiguous.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <utility>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device/contiguous.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
#include <migraphx/gpu/convolution.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <utility>
#include <migraphx/gpu/context.hpp>
#include <migraphx/generate.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -16,20 +16,24 @@ argument gather(hipStream_t stream,
std::vector<migraphx::argument> args,
int axis)
{
int axis_index = (axis < 0) ? (axis + output_shape.lens().size()) : axis;
int axis_index = (axis < 0) ? (axis + args[0].get_shape().lens().size()) : axis;
visit_all(args.back(), args[0])([&](auto output, auto input) {
std::size_t nelements = output_shape.elements();
args[1].visit([&](auto indices) {
visit_tensor_size(output_shape.lens().size(), [&](auto ndim) {
const auto* indices_ptr = device_cast(indices.data());
auto* outptr = device_cast(output.data());
const auto* inptr = device_cast(input.data());
hip_tensor_descriptor<ndim> desc_input(input.get_shape());
hip_tensor_descriptor<ndim> desc_output(output.get_shape());
gs_launch(stream, nelements)([=](auto i) {
auto lens = desc_output.multi(i);
lens[axis_index] = indices_ptr[lens[axis_index]];
outptr[i] = inptr[desc_input.linear(lens)];
auto* out_ptr = device_cast(output.data());
const auto* in_ptr = device_cast(input.data());
auto& input_shape = args[0].get_shape();
auto lens = input_shape.lens();
lens[axis_index] = args[1].get_shape().elements();
migraphx::shape out_comp_shape{output_shape.type(), lens};
visit_tensor_size(out_comp_shape.lens().size(), [&](auto n_out_dim) {
hip_tensor_descriptor<n_out_dim> desc_input(input_shape);
hip_tensor_descriptor<n_out_dim> desc_output(out_comp_shape);
gs_launch(stream, nelements)([=](auto ii) {
auto in_idx = desc_output.multi(ii);
in_idx[axis_index] = indices_ptr[in_idx[axis_index]];
out_ptr[ii] = in_ptr[desc_input.linear(in_idx)];
});
});
});
......
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/logsoftmax.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/hip.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
argument logsoftmax(hipStream_t stream,
const migraphx::shape& output_shape,
std::vector<migraphx::argument> args,
int axis)
{
auto lens = output_shape.lens();
std::size_t batch_size = std::accumulate(
lens.begin(), lens.begin() + axis, std::size_t{1}, std::multiplies<std::size_t>());
std::size_t n_dims = std::accumulate(
lens.begin() + axis, lens.end(), std::size_t{1}, std::multiplies<std::size_t>());
migraphx::shape comp_shape{output_shape.type(), {batch_size, n_dims}};
visit_all(args.back(), args.front())([&](auto output, auto input) {
const auto* input_ptr = device_cast(input.data());
auto* output_ptr = device_cast(output.data());
// each thread is for one item in the batch
gs_launch(stream, batch_size)([=](auto i) {
std::size_t row_start = i * n_dims;
// get max
auto batch_max = input_ptr[row_start];
for(std::size_t j = 1; j < n_dims; ++j)
{
auto ind = row_start + j;
batch_max = std::max(to_hip_type(batch_max), to_hip_type(input_ptr[ind]));
}
for(std::size_t j = 0; j < n_dims; ++j)
{
auto ind = row_start + j;
output_ptr[ind] = input_ptr[ind] - batch_max;
}
auto batch_sum = ::exp(to_hip_type(output_ptr[row_start]));
for(std::size_t j = 1; j < n_dims; ++j)
{
auto ind = row_start + j;
batch_sum += ::exp(to_hip_type(output_ptr[ind]));
}
batch_sum = ::log(to_hip_type(batch_sum));
for(std::size_t j = 0; j < n_dims; ++j)
{
auto ind = row_start + j;
output_ptr[ind] -= batch_sum;
}
});
});
return args.back();
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment