"test/api/test_custom_op.cpp" did not exist on "cd165ebd76bebe37d0327a7f70b2cf32040c6b1a"
Commit 703d10e7 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

merge changes from develop branch

parents a5d03696 21d4395e
#ifndef MIGRAPHX_GUARD_OPERATORS_POW_HPP
#define MIGRAPHX_GUARD_OPERATORS_POW_HPP
#include <migraphx/op/binary.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct pow : binary<pow>
{
auto apply() const
{
return [](auto x, auto y) { return std::pow(x, y); };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_RSQRT_HPP
#define MIGRAPHX_GUARD_OPERATORS_RSQRT_HPP
#include <migraphx/op/unary.hpp>
#include <cmath>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct rsqrt : unary<rsqrt>
{
auto apply() const
{
return [](auto x) { return 1 / std::sqrt(x); };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_SQDIFF_HPP
#define MIGRAPHX_GUARD_OPERATORS_SQDIFF_HPP
#include <migraphx/op/binary.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct sqdiff : binary<sqdiff>
{
auto apply() const
{
return [](auto x, auto y) { return (x - y) * (x - y); };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_SQRT_HPP
#define MIGRAPHX_GUARD_OPERATORS_SQRT_HPP
#include <migraphx/op/unary.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct sqrt : unary<sqrt>
{
auto apply() const
{
return [](auto x) { return std::sqrt(x); };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -47,6 +47,7 @@ ...@@ -47,6 +47,7 @@
#include <migraphx/op/pooling.hpp> #include <migraphx/op/pooling.hpp>
#include <migraphx/op/quant_convolution.hpp> #include <migraphx/op/quant_convolution.hpp>
#include <migraphx/op/quant_dot.hpp> #include <migraphx/op/quant_dot.hpp>
#include <migraphx/op/pow.hpp>
#include <migraphx/op/reduce_sum.hpp> #include <migraphx/op/reduce_sum.hpp>
#include <migraphx/op/reduce_mean.hpp> #include <migraphx/op/reduce_mean.hpp>
#include <migraphx/op/relu.hpp> #include <migraphx/op/relu.hpp>
...@@ -54,12 +55,15 @@ ...@@ -54,12 +55,15 @@
#include <migraphx/op/rnn.hpp> #include <migraphx/op/rnn.hpp>
#include <migraphx/op/rnn_last_cell_output.hpp> #include <migraphx/op/rnn_last_cell_output.hpp>
#include <migraphx/op/rnn_last_output.hpp> #include <migraphx/op/rnn_last_output.hpp>
#include <migraphx/op/rsqrt.hpp>
#include <migraphx/op/scalar.hpp> #include <migraphx/op/scalar.hpp>
#include <migraphx/op/sigmoid.hpp> #include <migraphx/op/sigmoid.hpp>
#include <migraphx/op/sinh.hpp> #include <migraphx/op/sinh.hpp>
#include <migraphx/op/sin.hpp> #include <migraphx/op/sin.hpp>
#include <migraphx/op/slice.hpp> #include <migraphx/op/slice.hpp>
#include <migraphx/op/softmax.hpp> #include <migraphx/op/softmax.hpp>
#include <migraphx/op/sqrt.hpp>
#include <migraphx/op/sqdiff.hpp>
#include <migraphx/op/squeeze.hpp> #include <migraphx/op/squeeze.hpp>
#include <migraphx/op/sub.hpp> #include <migraphx/op/sub.hpp>
#include <migraphx/op/tanh.hpp> #include <migraphx/op/tanh.hpp>
......
...@@ -54,11 +54,13 @@ struct onnx_parser ...@@ -54,11 +54,13 @@ struct onnx_parser
add_generic_op("Asin", op::asin{}); add_generic_op("Asin", op::asin{});
add_generic_op("Acos", op::acos{}); add_generic_op("Acos", op::acos{});
add_generic_op("Atan", op::atan{}); add_generic_op("Atan", op::atan{});
add_generic_op("Sqrt", op::sqrt{});
add_binary_op("Add", op::add{}); add_binary_op("Add", op::add{});
add_binary_op("Div", op::div{}); add_binary_op("Div", op::div{});
add_binary_op("Mul", op::mul{}); add_binary_op("Mul", op::mul{});
add_binary_op("Sub", op::sub{}); add_binary_op("Sub", op::sub{});
add_binary_op("Pow", op::pow{});
add_variadic_op("Sum", op::add{}); add_variadic_op("Sum", op::add{});
add_variadic_op("Max", op::max{}); add_variadic_op("Max", op::max{});
...@@ -66,11 +68,13 @@ struct onnx_parser ...@@ -66,11 +68,13 @@ struct onnx_parser
add_mem_op("ArgMax", &onnx_parser::parse_argmax); add_mem_op("ArgMax", &onnx_parser::parse_argmax);
add_mem_op("ArgMin", &onnx_parser::parse_argmin); add_mem_op("ArgMin", &onnx_parser::parse_argmin);
add_mem_op("Cast", &onnx_parser::parse_cast);
add_mem_op("Clip", &onnx_parser::parse_clip); add_mem_op("Clip", &onnx_parser::parse_clip);
add_mem_op("LRN", &onnx_parser::parse_lrn); add_mem_op("LRN", &onnx_parser::parse_lrn);
add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler); add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler);
add_mem_op("LeakyRelu", &onnx_parser::parse_leaky_relu); add_mem_op("LeakyRelu", &onnx_parser::parse_leaky_relu);
add_mem_op("Elu", &onnx_parser::parse_elu); add_mem_op("Elu", &onnx_parser::parse_elu);
add_mem_op("Expand", &onnx_parser::parse_expand);
add_mem_op("Constant", &onnx_parser::parse_constant); add_mem_op("Constant", &onnx_parser::parse_constant);
add_mem_op("Conv", &onnx_parser::parse_conv); add_mem_op("Conv", &onnx_parser::parse_conv);
add_mem_op("MaxPool", &onnx_parser::parse_pooling); add_mem_op("MaxPool", &onnx_parser::parse_pooling);
...@@ -91,6 +95,7 @@ struct onnx_parser ...@@ -91,6 +95,7 @@ struct onnx_parser
add_mem_op("Gather", &onnx_parser::parse_gather); add_mem_op("Gather", &onnx_parser::parse_gather);
add_mem_op("Shape", &onnx_parser::parse_shape); add_mem_op("Shape", &onnx_parser::parse_shape);
add_mem_op("ConstantFill", &onnx_parser::parse_constant_fill); add_mem_op("ConstantFill", &onnx_parser::parse_constant_fill);
add_mem_op("ConstantOfShape", &onnx_parser::parse_constant_of_shape);
add_mem_op("Transpose", &onnx_parser::parse_transpose); add_mem_op("Transpose", &onnx_parser::parse_transpose);
add_mem_op("RNN", &onnx_parser::parse_rnn); add_mem_op("RNN", &onnx_parser::parse_rnn);
add_mem_op("GRU", &onnx_parser::parse_gru); add_mem_op("GRU", &onnx_parser::parse_gru);
...@@ -462,8 +467,7 @@ struct onnx_parser ...@@ -462,8 +467,7 @@ struct onnx_parser
if(args.size() == 2) if(args.size() == 2)
{ {
auto s = args[1]->eval(); auto s = args[1]->eval();
if(s.empty()) check_arg_empty(s, "Reshape: dynamic shape is not supported");
MIGRAPHX_THROW("Dynamic shape is not supported.");
s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); }); s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
} }
return prog.add_instruction(op, args[0]); return prog.add_instruction(op, args[0]);
...@@ -543,6 +547,12 @@ struct onnx_parser ...@@ -543,6 +547,12 @@ struct onnx_parser
const std::vector<instruction_ref>&) const std::vector<instruction_ref>&)
{ {
literal v = parse_value(attributes.at("value")); literal v = parse_value(attributes.at("value"));
// return empty literal
if(v.get_shape().elements() == 0)
{
return prog.add_literal(literal{});
}
auto dim_size = attributes.at("value").t().dims_size(); auto dim_size = attributes.at("value").t().dims_size();
// if dim_size is 0, it is a scalar // if dim_size is 0, it is a scalar
if(dim_size == 0) if(dim_size == 0)
...@@ -870,10 +880,7 @@ struct onnx_parser ...@@ -870,10 +880,7 @@ struct onnx_parser
} }
migraphx::argument in = args[0]->eval(); migraphx::argument in = args[0]->eval();
if(in.empty()) check_arg_empty(in, "ConstantFill: dynamic shape is not supported");
{
MIGRAPHX_THROW("ConstantFill: cannot handle dynamic shape as input");
}
std::vector<std::size_t> dims; std::vector<std::size_t> dims;
in.visit([&](auto input) { dims.assign(input.begin(), input.end()); }); in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
...@@ -901,6 +908,74 @@ struct onnx_parser ...@@ -901,6 +908,74 @@ struct onnx_parser
} }
} }
instruction_ref parse_constant_of_shape(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args)
{
literal l_val{};
if(contains(attributes, "value"))
{
l_val = parse_value(attributes.at("value"));
if(l_val.get_shape().elements() != 1)
{
MIGRAPHX_THROW("ConstantOfShape: attribute value can contain only 1 elements!");
}
}
else
{
l_val = literal({shape::float_type, {1}, {0}}, {0.0f});
}
// input is empty, output is a scalar
auto type = l_val.get_shape().type();
if(args.empty())
{
MIGRAPHX_THROW("ConstantOfShape : must have 1 input!");
}
else
{
migraphx::shape s;
// empty input tensor, output is a scalar
if(args[0]->get_shape().elements() == 0)
{
s = migraphx::shape{type, {1}, {0}};
}
else
{
migraphx::argument in = args[0]->eval();
check_arg_empty(in, "ConstantOfShape: dynamic shape is not supported");
std::vector<std::size_t> dims;
in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
s = migraphx::shape{type, dims};
}
literal l_out{};
l_val.visit([&](auto val) {
using val_type = std::remove_cv_t<typename decltype(val)::value_type>;
// l_val contains only one element
std::vector<val_type> out_vec(s.elements(), *val.begin());
l_out = literal(s, out_vec);
});
return prog.add_literal(l_out);
}
}
instruction_ref
parse_expand(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{
auto in_lens = args[0]->get_shape().lens();
migraphx::argument arg_s = args[1]->eval();
check_arg_empty(arg_s, "Expand: dynamic shape is not supported");
std::vector<std::size_t> dims;
arg_s.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
auto out_lens = compute_broadcasted_lens(in_lens, dims);
return prog.add_instruction(op::multibroadcast{out_lens}, args[0]);
}
std::vector<instruction_ref> std::vector<instruction_ref>
parse_rnn(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) parse_rnn(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{ {
...@@ -1323,6 +1398,19 @@ struct onnx_parser ...@@ -1323,6 +1398,19 @@ struct onnx_parser
} }
} }
instruction_ref
parse_cast(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
if(!contains(attributes, "to"))
{
MIGRAPHX_THROW("PARSE_CAST: missing to type attribute!");
}
int to_type = parse_value(attributes.at("to")).at<int>();
shape::type_t type = get_type(to_type);
return prog.add_instruction(op::convert{type}, std::move(args));
}
void parse_from(std::istream& is) void parse_from(std::istream& is)
{ {
onnx::ModelProto model; onnx::ModelProto model;
...@@ -1635,6 +1723,14 @@ struct onnx_parser ...@@ -1635,6 +1723,14 @@ struct onnx_parser
} }
} }
} }
void check_arg_empty(const argument& arg, const std::string& msg)
{
if(arg.empty())
{
MIGRAPHX_THROW(msg);
}
}
}; };
program parse_onnx(const std::string& name) program parse_onnx(const std::string& name)
......
...@@ -38,9 +38,14 @@ add_library(migraphx_device ...@@ -38,9 +38,14 @@ add_library(migraphx_device
device/gather.cpp device/gather.cpp
device/sub.cpp device/sub.cpp
device/pack.cpp device/pack.cpp
device/div.cpp
device/clip.cpp device/clip.cpp
device/reduce_sum.cpp device/reduce_sum.cpp
device/rsqrt.cpp
device/sqrt.cpp
device/reduce_mean.cpp device/reduce_mean.cpp
device/pow.cpp
device/sqdiff.cpp
) )
set_target_properties(migraphx_device PROPERTIES EXPORT_NAME device) set_target_properties(migraphx_device PROPERTIES EXPORT_NAME device)
rocm_clang_tidy_check(migraphx_device) rocm_clang_tidy_check(migraphx_device)
......
#include <migraphx/gpu/device/div.hpp>
#include <migraphx/gpu/device/nary.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void div(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2)
{
nary(stream, result, arg1, arg2)([](auto x, auto y) { return x / y; });
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/gpu/device/pow.hpp>
#include <migraphx/gpu/device/nary.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void pow(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2)
{
nary(stream, result, arg1, arg2)(
[](auto b, auto e) { return ::pow(to_hip_type(b), to_hip_type(e)); });
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -8,6 +8,7 @@ namespace device { ...@@ -8,6 +8,7 @@ namespace device {
void reduce_sum(hipStream_t stream, const argument& result, const argument& arg) void reduce_sum(hipStream_t stream, const argument& result, const argument& arg)
{ {
reduce(stream, result, arg, sum{}, 0, id{}, id{}); reduce(stream, result, arg, sum{}, 0, id{}, id{});
} }
......
#include <migraphx/gpu/device/rsqrt.hpp>
#include <migraphx/gpu/device/nary.hpp>
#include <migraphx/gpu/device/types.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void rsqrt(hipStream_t stream, const argument& result, const argument& arg)
{
nary(stream, result, arg)([](auto x) __device__ { return ::rsqrt(to_hip_type(x)); });
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/gpu/device/sqdiff.hpp>
#include <migraphx/gpu/device/nary.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void sqdiff(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2)
{
nary(stream, result, arg1, arg2)([](auto x, auto y) { return (x - y) * (x - y); });
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/gpu/device/sqrt.hpp>
#include <migraphx/gpu/device/nary.hpp>
#include <migraphx/gpu/device/types.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void sqrt(hipStream_t stream, const argument& result, const argument& arg)
{
nary(stream, result, arg)([](auto x) { return ::sqrt(to_hip_type(x)); });
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -8,7 +8,7 @@ namespace device { ...@@ -8,7 +8,7 @@ namespace device {
void sub(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2) void sub(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2)
{ {
nary(stream, result, arg1, arg2)([](auto x, auto y) { return y - x; }); nary(stream, result, arg1, arg2)([](auto x, auto y) { return x - y; });
} }
} // namespace device } // namespace device
......
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_DIV_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_DIV_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void div(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_POW_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_POW_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void pow(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_RSQRT_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_RSQRT_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void rsqrt(hipStream_t stream, const argument& result, const argument& arg);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_SQDIFF_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_SQDIFF_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void sqdiff(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_SQRT_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_SQRT_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void sqrt(hipStream_t stream, const argument& result, const argument& arg);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_DIV_HPP
#define MIGRAPHX_GUARD_RTGLIB_DIV_HPP
#include <migraphx/gpu/oper.hpp>
#include <migraphx/gpu/device/div.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct hip_div : binary_device<hip_div, device::div>
{
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment