Commit 784dc2aa authored by Paul's avatar Paul
Browse files

Merge branch 'develop' into multiply-add

parents 70641651 dcbc9255
......@@ -81,6 +81,7 @@ rocm_enable_clang_tidy(
-modernize-use-override
-modernize-pass-by-value
-modernize-use-default-member-init
-modernize-use-trailing-return-type
-modernize-use-transparent-functors
-performance-type-promotion-in-math-fn
-readability-braces-around-statements
......
......@@ -20,6 +20,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
clang-format-5.0 \
clang-tidy-5.0 \
cmake \
comgr \
curl \
doxygen \
g++-7 \
......@@ -32,14 +33,16 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
libncurses5-dev \
libnuma-dev \
libpthread-stubs0-dev \
libssl-dev \
python \
python-dev \
python-pip \
rocm-device-libs \
rocm-opencl \
rocm-opencl-dev \
rocminfo \
software-properties-common \
wget && \
wget \
zlib1g-dev && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
......@@ -50,7 +53,7 @@ RUN pip install cget
RUN pip install https://github.com/pfultz2/rclone/archive/master.tar.gz
# Install hcc
RUN rclone -b roc-2.3.x -c fd93baed7dcc4fe8019b5fdc90213bfe7c298245 https://github.com/RadeonOpenCompute/hcc.git /hcc
RUN rclone -b roc-2.6.x -c 0f4c96b7851af2663a7f3ac16ecfb76c7c78a5bf https://github.com/RadeonOpenCompute/hcc.git /hcc
RUN cget -p $PREFIX install hcc,/hcc
# Use hcc
......
#ifndef MIGRAPHX_GUARD_OPERATORS_POW_HPP
#define MIGRAPHX_GUARD_OPERATORS_POW_HPP
#include <migraphx/op/binary.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct pow : binary<pow>
{
auto apply() const
{
return [](auto x, auto y) { return std::pow(x, y); };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_MEAN_HPP
#define MIGRAPHX_GUARD_OPERATORS_MEAN_HPP
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/config.hpp>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct reduce_mean
{
std::vector<std::int64_t> axes{};
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axes, "axes"));
}
std::string name() const { return "reduce_mean"; }
std::vector<int64_t> tune_axes(std::size_t n_dim) const
{
auto tuned_axes = axes;
if(tuned_axes.empty())
{
tuned_axes.resize(n_dim);
std::iota(tuned_axes.begin(), tuned_axes.end(), 0);
}
else
{
for(auto& axis : tuned_axes)
{
int64_t s_dim = static_cast<int64_t>(n_dim);
if(axis >= s_dim or axis < -s_dim)
{
MIGRAPHX_THROW("REDUCE_MEAN: axis out of range");
}
if(axis < 0)
{
axis += n_dim;
}
}
}
return tuned_axes;
}
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
auto s = inputs.at(0);
auto lens = s.lens();
auto tuned_axes = tune_axes(lens.size());
for(auto axis : tuned_axes)
{
lens[axis] = 1;
}
return {s.type(), lens};
}
template <class T>
void calc_mean(tensor_view<T>& input,
shape& batch_shape,
std::vector<int64_t>& tuned_axes,
std::vector<std::size_t>& out_idx,
tensor_view<T>& output) const
{
auto data_idx = out_idx;
T val = T{0};
shape_for_each(batch_shape, [&](auto b_idx) {
for(auto axis : tuned_axes)
{
data_idx[axis] = b_idx[axis];
}
val += input(data_idx.begin(), data_idx.end());
});
output(out_idx.begin(), out_idx.end()) = val / batch_shape.elements();
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
auto arg_lens = args.front().get_shape().lens();
auto tuned_axes = tune_axes(arg_lens.size());
std::vector<std::size_t> batch_lens(output_shape.lens().size(), 1);
for(auto axis : tuned_axes)
{
batch_lens[axis] = arg_lens[axis];
}
shape batch_shape{output_shape.type(), batch_lens};
visit_all(result, args[0])([&](auto output, auto input) {
par_for(output_shape.elements(), [&](auto i) {
auto out_idx = output_shape.multi(i);
this->calc_mean(input, batch_shape, tuned_axes, out_idx, output);
});
});
return result;
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -4,6 +4,7 @@
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/config.hpp>
#include <vector>
......@@ -13,7 +14,7 @@ namespace op {
struct reduce_sum
{
std::vector<std::size_t> axes;
std::vector<int64_t> axes{};
template <class Self, class F>
static auto reflect(Self& self, F f)
......@@ -23,25 +24,82 @@ struct reduce_sum
std::string name() const { return "reduce_sum"; }
std::vector<int64_t> tune_axes(std::size_t n_dim) const
{
auto tuned_axes = axes;
if(tuned_axes.empty())
{
tuned_axes.resize(n_dim);
std::iota(tuned_axes.begin(), tuned_axes.end(), 0);
}
else
{
for(auto& axis : tuned_axes)
{
int64_t s_dim = static_cast<int64_t>(n_dim);
if(axis >= s_dim or axis < -s_dim)
{
MIGRAPHX_THROW("REDUCE_MEAN: axis out of range");
}
if(axis < 0)
{
axis += n_dim;
}
}
}
return tuned_axes;
}
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
auto s = inputs.at(0);
auto lens = s.lens();
for(auto axis : axes)
auto tuned_axes = tune_axes(lens.size());
for(auto axis : tuned_axes)
{
lens[axis] = 1;
}
return {s.type(), lens};
}
template <class T>
void calc_sum(tensor_view<T>& input,
shape& batch_shape,
std::vector<int64_t>& tuned_axes,
std::vector<std::size_t>& out_idx,
tensor_view<T>& output) const
{
auto data_idx = out_idx;
T val = T{0};
shape_for_each(batch_shape, [&](auto b_idx) {
for(auto axis : tuned_axes)
{
data_idx[axis] = b_idx[axis];
}
val += input(data_idx.begin(), data_idx.end());
});
output(out_idx.begin(), out_idx.end()) = val;
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
auto arg_lens = args.front().get_shape().lens();
std::vector<int64_t> tuned_axes = tune_axes(arg_lens.size());
std::vector<std::size_t> batch_lens(output_shape.lens().size(), 1);
for(auto axis : tuned_axes)
{
batch_lens[axis] = arg_lens[axis];
}
shape batch_shape{output_shape.type(), batch_lens};
visit_all(result, args[0])([&](auto output, auto input) {
shape_for_each(input.get_shape(), [&](auto&& in_idx) {
auto out_idx = in_idx;
for(auto axis : axes)
out_idx[axis] = 0;
output(out_idx.begin(), out_idx.end()) += input(in_idx.begin(), in_idx.end());
par_for(output_shape.elements(), [&](auto i) {
auto out_idx = output_shape.multi(i);
this->calc_sum(input, batch_shape, tuned_axes, out_idx, output);
});
});
......
#ifndef MIGRAPHX_GUARD_OPERATORS_RSQRT_HPP
#define MIGRAPHX_GUARD_OPERATORS_RSQRT_HPP
#include <migraphx/op/unary.hpp>
#include <cmath>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct rsqrt : unary<rsqrt>
{
auto apply() const
{
return [](auto x) { return 1 / std::sqrt(x); };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_SQDIFF_HPP
#define MIGRAPHX_GUARD_OPERATORS_SQDIFF_HPP
#include <migraphx/op/binary.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct sqdiff : binary<sqdiff>
{
auto apply() const
{
return [](auto x, auto y) { return (x - y) * (x - y); };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_SQRT_HPP
#define MIGRAPHX_GUARD_OPERATORS_SQRT_HPP
#include <migraphx/op/unary.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct sqrt : unary<sqrt>
{
auto apply() const
{
return [](auto x) { return std::sqrt(x); };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -45,18 +45,23 @@
#include <migraphx/op/outline.hpp>
#include <migraphx/op/pad.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/pow.hpp>
#include <migraphx/op/reduce_sum.hpp>
#include <migraphx/op/reduce_mean.hpp>
#include <migraphx/op/relu.hpp>
#include <migraphx/op/reshape.hpp>
#include <migraphx/op/rnn.hpp>
#include <migraphx/op/rnn_last_cell_output.hpp>
#include <migraphx/op/rnn_last_output.hpp>
#include <migraphx/op/rsqrt.hpp>
#include <migraphx/op/scalar.hpp>
#include <migraphx/op/sigmoid.hpp>
#include <migraphx/op/sinh.hpp>
#include <migraphx/op/sin.hpp>
#include <migraphx/op/slice.hpp>
#include <migraphx/op/softmax.hpp>
#include <migraphx/op/sqrt.hpp>
#include <migraphx/op/sqdiff.hpp>
#include <migraphx/op/squeeze.hpp>
#include <migraphx/op/sub.hpp>
#include <migraphx/op/tanh.hpp>
......
......@@ -54,11 +54,13 @@ struct onnx_parser
add_generic_op("Asin", op::asin{});
add_generic_op("Acos", op::acos{});
add_generic_op("Atan", op::atan{});
add_generic_op("Sqrt", op::sqrt{});
add_binary_op("Add", op::add{});
add_binary_op("Div", op::div{});
add_binary_op("Mul", op::mul{});
add_binary_op("Sub", op::sub{});
add_binary_op("Pow", op::pow{});
add_variadic_op("Sum", op::add{});
add_variadic_op("Max", op::max{});
......@@ -66,11 +68,13 @@ struct onnx_parser
add_mem_op("ArgMax", &onnx_parser::parse_argmax);
add_mem_op("ArgMin", &onnx_parser::parse_argmin);
add_mem_op("Cast", &onnx_parser::parse_cast);
add_mem_op("Clip", &onnx_parser::parse_clip);
add_mem_op("LRN", &onnx_parser::parse_lrn);
add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler);
add_mem_op("LeakyRelu", &onnx_parser::parse_leaky_relu);
add_mem_op("Elu", &onnx_parser::parse_elu);
add_mem_op("Expand", &onnx_parser::parse_expand);
add_mem_op("Constant", &onnx_parser::parse_constant);
add_mem_op("Conv", &onnx_parser::parse_conv);
add_mem_op("MaxPool", &onnx_parser::parse_pooling);
......@@ -91,12 +95,14 @@ struct onnx_parser
add_mem_op("Gather", &onnx_parser::parse_gather);
add_mem_op("Shape", &onnx_parser::parse_shape);
add_mem_op("ConstantFill", &onnx_parser::parse_constant_fill);
add_mem_op("ConstantOfShape", &onnx_parser::parse_constant_of_shape);
add_mem_op("Transpose", &onnx_parser::parse_transpose);
add_mem_op("RNN", &onnx_parser::parse_rnn);
add_mem_op("GRU", &onnx_parser::parse_gru);
add_mem_op("LSTM", &onnx_parser::parse_lstm);
add_mem_op("Pad", &onnx_parser::parse_pad);
add_mem_op("ReduceSum", &onnx_parser::parse_reduce_sum);
add_mem_op("ReduceSum", &onnx_parser::parse_reduce_oper<op::reduce_sum>);
add_mem_op("ReduceMean", &onnx_parser::parse_reduce_oper<op::reduce_mean>);
// init the activation function map
init_actv_func();
......@@ -461,8 +467,7 @@ struct onnx_parser
if(args.size() == 2)
{
auto s = args[1]->eval();
if(s.empty())
MIGRAPHX_THROW("Dynamic shape is not supported.");
check_arg_empty(s, "Reshape: dynamic shape is not supported");
s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
}
return prog.add_instruction(op, args[0]);
......@@ -542,6 +547,12 @@ struct onnx_parser
const std::vector<instruction_ref>&)
{
literal v = parse_value(attributes.at("value"));
// return empty literal
if(v.get_shape().elements() == 0)
{
return prog.add_literal(literal{});
}
auto dim_size = attributes.at("value").t().dims_size();
// if dim_size is 0, it is a scalar
if(dim_size == 0)
......@@ -869,10 +880,7 @@ struct onnx_parser
}
migraphx::argument in = args[0]->eval();
if(in.empty())
{
MIGRAPHX_THROW("ConstantFill: cannot handle dynamic shape as input");
}
check_arg_empty(in, "ConstantFill: dynamic shape is not supported");
std::vector<std::size_t> dims;
in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
......@@ -900,6 +908,74 @@ struct onnx_parser
}
}
instruction_ref parse_constant_of_shape(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args)
{
literal l_val{};
if(contains(attributes, "value"))
{
l_val = parse_value(attributes.at("value"));
if(l_val.get_shape().elements() != 1)
{
MIGRAPHX_THROW("ConstantOfShape: attribute value can contain only 1 elements!");
}
}
else
{
l_val = literal({shape::float_type, {1}, {0}}, {0.0f});
}
// input is empty, output is a scalar
auto type = l_val.get_shape().type();
if(args.empty())
{
MIGRAPHX_THROW("ConstantOfShape : must have 1 input!");
}
else
{
migraphx::shape s;
// empty input tensor, output is a scalar
if(args[0]->get_shape().elements() == 0)
{
s = migraphx::shape{type, {1}, {0}};
}
else
{
migraphx::argument in = args[0]->eval();
check_arg_empty(in, "ConstantOfShape: dynamic shape is not supported");
std::vector<std::size_t> dims;
in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
s = migraphx::shape{type, dims};
}
literal l_out{};
l_val.visit([&](auto val) {
using val_type = std::remove_cv_t<typename decltype(val)::value_type>;
// l_val contains only one element
std::vector<val_type> out_vec(s.elements(), *val.begin());
l_out = literal(s, out_vec);
});
return prog.add_literal(l_out);
}
}
instruction_ref
parse_expand(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{
auto in_lens = args[0]->get_shape().lens();
migraphx::argument arg_s = args[1]->eval();
check_arg_empty(arg_s, "Expand: dynamic shape is not supported");
std::vector<std::size_t> dims;
arg_s.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
auto out_lens = compute_broadcasted_lens(in_lens, dims);
return prog.add_instruction(op::multibroadcast{out_lens}, args[0]);
}
std::vector<instruction_ref>
parse_rnn(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
......@@ -1288,20 +1364,21 @@ struct onnx_parser
return {hidden_states, last_output, last_cell_output};
}
instruction_ref parse_reduce_sum(const std::string&,
template <class T>
instruction_ref parse_reduce_oper(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args)
{
std::size_t n_dim = args.front()->get_shape().lens().size();
// default to reduce over all dimensions
std::vector<std::size_t> axes(n_dim);
std::vector<int64_t> axes(n_dim);
std::iota(axes.begin(), axes.end(), 0);
if(contains(attributes, "axes"))
{
axes.clear();
auto&& attr_axes = attributes["axes"].ints();
axes = std::vector<std::size_t>(attr_axes.begin(), attr_axes.end());
axes = std::vector<int64_t>(attr_axes.begin(), attr_axes.end());
}
int keep_dims = 1;
......@@ -1312,16 +1389,28 @@ struct onnx_parser
if(keep_dims == 1)
{
return prog.add_instruction(op::reduce_sum{axes}, std::move(args));
return prog.add_instruction(T{axes}, std::move(args));
}
else
{
auto ins = prog.add_instruction(op::reduce_sum{axes}, std::move(args));
std::vector<int64_t> squeeze_axes{axes.begin(), axes.end()};
return prog.add_instruction(op::squeeze{squeeze_axes}, ins);
auto ins = prog.add_instruction(T{axes}, std::move(args));
return prog.add_instruction(op::squeeze{axes}, ins);
}
}
instruction_ref
parse_cast(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
if(!contains(attributes, "to"))
{
MIGRAPHX_THROW("PARSE_CAST: missing to type attribute!");
}
int to_type = parse_value(attributes.at("to")).at<int>();
shape::type_t type = get_type(to_type);
return prog.add_instruction(op::convert{type}, std::move(args));
}
void parse_from(std::istream& is)
{
onnx::ModelProto model;
......@@ -1468,16 +1557,16 @@ struct onnx_parser
{
switch(attr.type())
{
case onnx::AttributeProto::UNDEFINED: return {};
case onnx::AttributeProto::FLOAT: return literal{attr.f()};
case onnx::AttributeProto::INT: return literal{attr.i()};
case onnx::AttributeProto::STRING: return {};
case onnx::AttributeProto::TENSOR: return parse_tensor(attr.t());
case onnx::AttributeProto::GRAPH: return {};
case onnx::AttributeProto::FLOATS: return from_repeated(shape::float_type, attr.floats());
case onnx::AttributeProto::INTS: return from_repeated(shape::int64_type, attr.ints());
case onnx::AttributeProto::STRINGS: return {};
case onnx::AttributeProto::TENSORS: return {};
case onnx::AttributeProto::UNDEFINED:
case onnx::AttributeProto::GRAPH:
case onnx::AttributeProto::STRING:
case onnx::AttributeProto::STRINGS:
case onnx::AttributeProto::TENSORS:
case onnx::AttributeProto::GRAPHS: return {};
}
MIGRAPHX_THROW("Invalid attribute type");
......@@ -1491,47 +1580,41 @@ struct onnx_parser
const std::string& s = t.raw_data();
switch(t.data_type())
{
case onnx::TensorProto::UNDEFINED: throw std::runtime_error("");
case onnx::TensorProto::FLOAT: return create_literal(shape::float_type, dims, s.data());
case onnx::TensorProto::UINT8: throw std::runtime_error("");
case onnx::TensorProto::INT8: return create_literal(shape::int32_type, dims, s.data());
case onnx::TensorProto::UINT16:
return create_literal(shape::int32_type, dims, s.data());
case onnx::TensorProto::INT16: return create_literal(shape::int32_type, dims, s.data());
case onnx::TensorProto::INT32: return create_literal(shape::int32_type, dims, s.data());
case onnx::TensorProto::INT64: return create_literal(shape::int64_type, dims, s.data());
case onnx::TensorProto::STRING: throw std::runtime_error("");
case onnx::TensorProto::BOOL: return create_literal(shape::int32_type, dims, s.data());
case onnx::TensorProto::FLOAT16:
return create_literal(shape::half_type, dims, s.data());
case onnx::TensorProto::DOUBLE:
return create_literal(shape::double_type, dims, s.data());
case onnx::TensorProto::UINT32: throw std::runtime_error("");
case onnx::TensorProto::UINT64: throw std::runtime_error("");
case onnx::TensorProto::COMPLEX64: throw std::runtime_error("");
case onnx::TensorProto::INT64: return create_literal(shape::int64_type, dims, s.data());
case onnx::TensorProto::INT8:
case onnx::TensorProto::UINT16:
case onnx::TensorProto::INT16:
case onnx::TensorProto::INT32:
case onnx::TensorProto::BOOL: return create_literal(shape::int32_type, dims, s.data());
case onnx::TensorProto::UINT8:
case onnx::TensorProto::STRING:
case onnx::TensorProto::UNDEFINED:
case onnx::TensorProto::UINT32:
case onnx::TensorProto::UINT64:
case onnx::TensorProto::COMPLEX64:
case onnx::TensorProto::COMPLEX128: throw std::runtime_error("");
}
MIGRAPHX_THROW("Invalid tensor type");
}
switch(t.data_type())
{
case onnx::TensorProto::UNDEFINED: throw std::runtime_error("");
case onnx::TensorProto::FLOAT:
return create_literal(shape::float_type, dims, t.float_data());
case onnx::TensorProto::UINT8: throw std::runtime_error("");
case onnx::TensorProto::INT8:
return create_literal(shape::int32_type, dims, t.int32_data());
case onnx::TensorProto::UINT16:
return create_literal(shape::int32_type, dims, t.int32_data());
case onnx::TensorProto::INT16:
return create_literal(shape::int32_type, dims, t.int32_data());
case onnx::TensorProto::INT32:
case onnx::TensorProto::BOOL:
return create_literal(shape::int32_type, dims, t.int32_data());
case onnx::TensorProto::INT64:
return create_literal(shape::int64_type, dims, t.int64_data());
case onnx::TensorProto::STRING: throw std::runtime_error("");
case onnx::TensorProto::BOOL:
return create_literal(shape::int32_type, dims, t.int32_data());
case onnx::TensorProto::DOUBLE:
return create_literal(shape::double_type, dims, t.double_data());
case onnx::TensorProto::FLOAT:
return create_literal(shape::float_type, dims, t.float_data());
case onnx::TensorProto::FLOAT16:
{
std::vector<uint16_t> data_uint16(t.int32_data().begin(), t.int32_data().end());
......@@ -1542,11 +1625,12 @@ struct onnx_parser
[](uint16_t raw_val) { return *reinterpret_cast<half*>(&raw_val); });
return create_literal(shape::half_type, dims, data_half);
}
case onnx::TensorProto::DOUBLE:
return create_literal(shape::double_type, dims, t.double_data());
case onnx::TensorProto::UINT32: throw std::runtime_error("");
case onnx::TensorProto::UINT64: throw std::runtime_error("");
case onnx::TensorProto::COMPLEX64: throw std::runtime_error("");
case onnx::TensorProto::UNDEFINED:
case onnx::TensorProto::UINT8:
case onnx::TensorProto::STRING:
case onnx::TensorProto::UINT32:
case onnx::TensorProto::UINT64:
case onnx::TensorProto::COMPLEX64:
case onnx::TensorProto::COMPLEX128: throw std::runtime_error("");
}
MIGRAPHX_THROW("Invalid tensor type");
......@@ -1574,28 +1658,23 @@ struct onnx_parser
shape::type_t shape_type{};
switch(t.tensor_type().elem_type())
{
case onnx::TensorProto::UNDEFINED:
break; // throw std::runtime_error("Unsupported type UNDEFINED");
case onnx::TensorProto::FLOAT: shape_type = shape::float_type; break;
case onnx::TensorProto::UINT8:
break; // throw std::runtime_error("Unsupported type UINT8");
case onnx::TensorProto::INT8: shape_type = shape::int8_type; break;
case onnx::TensorProto::UINT16: shape_type = shape::uint16_type; break;
case onnx::TensorProto::INT16: shape_type = shape::int16_type; break;
case onnx::TensorProto::INT32: shape_type = shape::int32_type; break;
case onnx::TensorProto::INT64: shape_type = shape::int64_type; break;
case onnx::TensorProto::STRING:
break; // throw std::runtime_error("Unsupported type STRING");
case onnx::TensorProto::BOOL:
break; // throw std::runtime_error("Unsupported type BOOL");
case onnx::TensorProto::FLOAT16: shape_type = shape::half_type; break;
case onnx::TensorProto::DOUBLE: shape_type = shape::double_type; break;
case onnx::TensorProto::UINT32: shape_type = shape::uint32_type; break;
case onnx::TensorProto::UINT64: shape_type = shape::uint64_type; break;
case onnx::TensorProto::UINT8:
case onnx::TensorProto::STRING:
case onnx::TensorProto::BOOL:
case onnx::TensorProto::UNDEFINED:
case onnx::TensorProto::COMPLEX64:
break; // throw std::runtime_error("Unsupported type COMPLEX64");
case onnx::TensorProto::COMPLEX128:
break; // throw std::runtime_error("Unsupported type COMPLEX128");
break; // throw std::runtime_error("Unsupported type");
}
std::vector<std::size_t> dims;
auto&& tensor_dims = t.tensor_type().shape().dim();
......@@ -1634,6 +1713,14 @@ struct onnx_parser
}
}
}
void check_arg_empty(const argument& arg, const std::string& msg)
{
if(arg.empty())
{
MIGRAPHX_THROW(msg);
}
}
};
program parse_onnx(const std::string& name)
......
......@@ -37,8 +37,14 @@ add_library(migraphx_device
device/pad.cpp
device/gather.cpp
device/sub.cpp
device/div.cpp
device/clip.cpp
device/reduce_sum.cpp
device/rsqrt.cpp
device/sqrt.cpp
device/reduce_mean.cpp
device/pow.cpp
device/sqdiff.cpp
)
set_target_properties(migraphx_device PROPERTIES EXPORT_NAME device)
rocm_clang_tidy_check(migraphx_device)
......@@ -77,6 +83,7 @@ add_library(migraphx_gpu
adjust_allocation.cpp
clip.cpp
reduce_sum.cpp
reduce_mean.cpp
)
set_target_properties(migraphx_gpu PROPERTIES EXPORT_NAME gpu)
rocm_clang_tidy_check(migraphx_gpu)
......
#include <migraphx/gpu/device/div.hpp>
#include <migraphx/gpu/device/nary.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void div(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2)
{
nary(stream, result, arg1, arg2)([](auto x, auto y) { return x / y; });
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -28,6 +28,16 @@ struct id
}
};
struct mean
{
size_t item_num = 1;
template <class T>
MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x) const
{
return static_cast<T>(x / item_num);
}
};
struct max
{
template <class T, class U>
......
#include <migraphx/gpu/device/pow.hpp>
#include <migraphx/gpu/device/nary.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void pow(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2)
{
nary(stream, result, arg1, arg2)(
[](auto b, auto e) { return ::pow(to_hip_type(b), to_hip_type(e)); });
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/gpu/device/reduce_mean.hpp>
#include <migraphx/gpu/device/reduce.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void reduce_mean(hipStream_t stream, const argument& result, const argument& arg)
{
std::size_t item_num = arg.get_shape().elements() / result.get_shape().elements();
reduce(stream, result, arg, sum{}, 0, id{}, mean{item_num});
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -8,6 +8,7 @@ namespace device {
void reduce_sum(hipStream_t stream, const argument& result, const argument& arg)
{
reduce(stream, result, arg, sum{}, 0, id{}, id{});
}
......
#include <migraphx/gpu/device/rsqrt.hpp>
#include <migraphx/gpu/device/nary.hpp>
#include <migraphx/gpu/device/types.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void rsqrt(hipStream_t stream, const argument& result, const argument& arg)
{
nary(stream, result, arg)([](auto x) __device__ { return ::rsqrt(to_hip_type(x)); });
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/gpu/device/sqdiff.hpp>
#include <migraphx/gpu/device/nary.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void sqdiff(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2)
{
nary(stream, result, arg1, arg2)([](auto x, auto y) { return (x - y) * (x - y); });
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment