Commit 19b75f4d authored by Shucai Xiao
Browse files

merge changes from branch_bert operators

parents e9d4d21e badacbcc
...@@ -81,15 +81,7 @@ template <class T> ...@@ -81,15 +81,7 @@ template <class T>
std::vector<T> generate_tensor_data(const migraphx::shape& s, unsigned long seed = 0) std::vector<T> generate_tensor_data(const migraphx::shape& s, unsigned long seed = 0)
{ {
std::vector<T> result(s.elements()); std::vector<T> result(s.elements());
shape::type_t type = s.type(); std::generate(result.begin(), result.end(), xorshf96_generator<T>{seed});
if(type == shape::int64_type or type == shape::int32_type)
{
std::generate(result.begin(), result.end(), [] { return 1; });
}
else
{
std::generate(result.begin(), result.end(), xorshf96_generator<T>{seed});
}
// std::generate(result.begin(), result.end(), [&]{ return seed % 7; }); // std::generate(result.begin(), result.end(), [&]{ return seed % 7; });
// std::generate(result.begin(), result.end(), []{ return 1; }); // std::generate(result.begin(), result.end(), []{ return 1; });
......
#ifndef MIGRAPHX_GUARD_OPERATORS_RSQRT_HPP
#define MIGRAPHX_GUARD_OPERATORS_RSQRT_HPP
#include <migraphx/op/unary.hpp>
#include <cmath>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
// Reciprocal square root operator: computes 1 / sqrt(x) elementwise.
// CRTP-derives from unary<> so shape computation and op registration
// follow the common unary-operator pattern.
struct rsqrt : unary<rsqrt>
{
auto apply() const
{
// NOTE: for integral x, std::sqrt promotes to double, so the division
// is floating-point (no integer truncation).
return [](auto x) { return 1 / std::sqrt(x); };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -53,6 +53,7 @@ ...@@ -53,6 +53,7 @@
#include <migraphx/op/rnn.hpp> #include <migraphx/op/rnn.hpp>
#include <migraphx/op/rnn_last_cell_output.hpp> #include <migraphx/op/rnn_last_cell_output.hpp>
#include <migraphx/op/rnn_last_output.hpp> #include <migraphx/op/rnn_last_output.hpp>
#include <migraphx/op/rsqrt.hpp>
#include <migraphx/op/scalar.hpp> #include <migraphx/op/scalar.hpp>
#include <migraphx/op/sigmoid.hpp> #include <migraphx/op/sigmoid.hpp>
#include <migraphx/op/sinh.hpp> #include <migraphx/op/sinh.hpp>
......
...@@ -86,8 +86,8 @@ struct onnx_parser ...@@ -86,8 +86,8 @@ struct onnx_parser
add_mem_op("Gemm", &onnx_parser::parse_gemm); add_mem_op("Gemm", &onnx_parser::parse_gemm);
add_mem_op("MatMul", &onnx_parser::parse_matmul); add_mem_op("MatMul", &onnx_parser::parse_matmul);
add_mem_op("BatchNormalization", &onnx_parser::parse_batchnorm); add_mem_op("BatchNormalization", &onnx_parser::parse_batchnorm);
add_mem_op("Softmax", &onnx_parser::parse_softmax); add_mem_op("Softmax", &onnx_parser::parse_softmax<op::softmax>);
add_mem_op("LogSoftmax", &onnx_parser::parse_logsoftmax); add_mem_op("LogSoftmax", &onnx_parser::parse_softmax<op::logsoftmax>);
add_mem_op("Squeeze", &onnx_parser::parse_squeeze); add_mem_op("Squeeze", &onnx_parser::parse_squeeze);
add_mem_op("Unsqueeze", &onnx_parser::parse_unsqueeze); add_mem_op("Unsqueeze", &onnx_parser::parse_unsqueeze);
add_mem_op("Slice", &onnx_parser::parse_slice); add_mem_op("Slice", &onnx_parser::parse_slice);
...@@ -261,6 +261,7 @@ struct onnx_parser ...@@ -261,6 +261,7 @@ struct onnx_parser
return prog.add_instruction(op, std::move(args)); return prog.add_instruction(op, std::move(args));
} }
template<class Op>
instruction_ref parse_softmax(const std::string&, instruction_ref parse_softmax(const std::string&,
const attribute_map& attributes, const attribute_map& attributes,
std::vector<instruction_ref> args) std::vector<instruction_ref> args)
...@@ -271,20 +272,7 @@ struct onnx_parser ...@@ -271,20 +272,7 @@ struct onnx_parser
axis = parse_value(attributes.at("axis")).at<int>(); axis = parse_value(attributes.at("axis")).at<int>();
} }
return prog.add_instruction(op::softmax{axis}, std::move(args)); return prog.add_instruction(Op{axis}, std::move(args));
}
instruction_ref parse_logsoftmax(const std::string&,
const attribute_map& attributes,
std::vector<instruction_ref> args)
{
int axis = 1;
if(contains(attributes, "axis"))
{
axis = parse_value(attributes.at("axis")).at<int>();
}
return prog.add_instruction(op::logsoftmax{axis}, std::move(args));
} }
instruction_ref parse_argmax(const std::string&, instruction_ref parse_argmax(const std::string&,
...@@ -478,8 +466,7 @@ struct onnx_parser ...@@ -478,8 +466,7 @@ struct onnx_parser
if(args.size() == 2) if(args.size() == 2)
{ {
auto s = args[1]->eval(); auto s = args[1]->eval();
if(s.empty()) check_arg_empty(s, "Reshape: dynamic shape is not supported");
MIGRAPHX_THROW("Dynamic shape is not supported.");
s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); }); s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
} }
...@@ -898,10 +885,7 @@ struct onnx_parser ...@@ -898,10 +885,7 @@ struct onnx_parser
} }
migraphx::argument in = args[0]->eval(); migraphx::argument in = args[0]->eval();
if(in.empty()) check_arg_empty(in, "ConstantFill: dynamic shape is not supported");
{
MIGRAPHX_THROW("ConstantFill: cannot handle dynamic shape as input");
}
std::vector<std::size_t> dims; std::vector<std::size_t> dims;
in.visit([&](auto input) { dims.assign(input.begin(), input.end()); }); in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
...@@ -952,7 +936,7 @@ struct onnx_parser ...@@ -952,7 +936,7 @@ struct onnx_parser
if(args.empty()) if(args.empty())
{ {
MIGRAPHX_THROW("Parse ConstantOfShape : must have 1 input!"); MIGRAPHX_THROW("ConstantOfShape : must have 1 input!");
} }
else else
{ {
...@@ -965,19 +949,22 @@ struct onnx_parser ...@@ -965,19 +949,22 @@ struct onnx_parser
else else
{ {
migraphx::argument in = args[0]->eval(); migraphx::argument in = args[0]->eval();
if(in.empty()) check_arg_empty(in, "ConstantOfShape: dynamic shape is not supported");
{
MIGRAPHX_THROW("Parse ConstantOfShape: cannot handle dynamic shape as input");
}
std::vector<std::size_t> dims; std::vector<std::size_t> dims;
in.visit([&](auto input) { dims.assign(input.begin(), input.end()); }); in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
s = migraphx::shape{type, dims}; s = migraphx::shape{type, dims};
} }
literal l_out; literal l_out{};
l_val.visit([&](auto val) { l_val.visit([&](auto val) {
// this #ifdef is to avoid a false cppcheck error, will remove later
// when a newer version of cppcheck is used
#ifdef CPPCHECK
using type = float;
#else
using type = std::remove_cv_t<typename decltype(val)::value_type>; using type = std::remove_cv_t<typename decltype(val)::value_type>;
#endif
// l_val contains only one element // l_val contains only one element
std::vector<type> out_vec(s.elements(), *val.begin()); std::vector<type> out_vec(s.elements(), *val.begin());
l_out = literal(s, out_vec); l_out = literal(s, out_vec);
...@@ -992,10 +979,7 @@ struct onnx_parser ...@@ -992,10 +979,7 @@ struct onnx_parser
{ {
auto in_lens = args[0]->get_shape().lens(); auto in_lens = args[0]->get_shape().lens();
migraphx::argument arg_s = args[1]->eval(); migraphx::argument arg_s = args[1]->eval();
if(arg_s.empty()) check_arg_empty(arg_s, "Expand: dynamic shape is not supported");
{
MIGRAPHX_THROW("Parse Expand: cannot handle dynamic shape as input");
}
std::vector<std::size_t> dims; std::vector<std::size_t> dims;
arg_s.visit([&](auto input) { dims.assign(input.begin(), input.end()); }); arg_s.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
auto out_lens = compute_broadcasted_lens(in_lens, dims); auto out_lens = compute_broadcasted_lens(in_lens, dims);
...@@ -1749,6 +1733,14 @@ struct onnx_parser ...@@ -1749,6 +1733,14 @@ struct onnx_parser
} }
} }
} }
// Throws with `msg` when `arg` holds no data, i.e. the instruction could
// not be evaluated to a constant at parse time (a dynamic shape input).
// Shared helper so Reshape/ConstantFill/ConstantOfShape/Expand all report
// this condition uniformly.
void check_arg_empty(const argument& arg, const std::string& msg)
{
if(arg.empty())
{
MIGRAPHX_THROW(msg);
}
}
}; };
program parse_onnx(const std::string& name) program parse_onnx(const std::string& name)
......
...@@ -40,6 +40,7 @@ add_library(migraphx_device ...@@ -40,6 +40,7 @@ add_library(migraphx_device
device/div.cpp device/div.cpp
device/clip.cpp device/clip.cpp
device/reduce_sum.cpp device/reduce_sum.cpp
device/rsqrt.cpp
device/sqrt.cpp device/sqrt.cpp
device/reduce_mean.cpp device/reduce_mean.cpp
device/pow.cpp device/pow.cpp
......
#include <migraphx/gpu/device/rsqrt.hpp>
#include <migraphx/gpu/device/nary.hpp>
#include <migraphx/gpu/device/types.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
// Elementwise reciprocal square root on the GPU: launches a pointwise
// kernel (via the nary helper) on `stream` that applies HIP's ::rsqrt
// to each element of `arg`, writing into `result`.
void rsqrt(hipStream_t stream, const argument& result, const argument& arg)
{
// to_hip_type converts the visited element to a type ::rsqrt accepts
// (e.g. half handling) — see device/types.hpp.
nary(stream, result, arg)([](auto x) __device__ { return ::rsqrt(to_hip_type(x)); });
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_RSQRT_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_RSQRT_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void rsqrt(hipStream_t stream, const argument& result, const argument& arg);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_RSQRT_HPP
#define MIGRAPHX_GUARD_RTGLIB_RSQRT_HPP
#include <migraphx/gpu/oper.hpp>
#include <migraphx/gpu/device/rsqrt.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
// GPU lowering of op::rsqrt: binds the generic unary_device adapter to the
// device::rsqrt kernel; the adapter supplies name/shape/compute plumbing,
// so no members are needed here.
struct hip_rsqrt : unary_device<hip_rsqrt, device::rsqrt>
{
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -52,6 +52,7 @@ ...@@ -52,6 +52,7 @@
#include <migraphx/gpu/convert.hpp> #include <migraphx/gpu/convert.hpp>
#include <migraphx/gpu/clip.hpp> #include <migraphx/gpu/clip.hpp>
#include <migraphx/gpu/reduce_sum.hpp> #include <migraphx/gpu/reduce_sum.hpp>
#include <migraphx/gpu/rsqrt.hpp>
#include <migraphx/gpu/sqrt.hpp> #include <migraphx/gpu/sqrt.hpp>
#include <migraphx/gpu/reduce_mean.hpp> #include <migraphx/gpu/reduce_mean.hpp>
#include <migraphx/gpu/pow.hpp> #include <migraphx/gpu/pow.hpp>
...@@ -107,6 +108,7 @@ struct miopen_apply ...@@ -107,6 +108,7 @@ struct miopen_apply
add_generic_op<hip_div>("div"); add_generic_op<hip_div>("div");
add_generic_op<hip_max>("max"); add_generic_op<hip_max>("max");
add_generic_op<hip_min>("min"); add_generic_op<hip_min>("min");
add_generic_op<hip_rsqrt>("rsqrt");
add_generic_op<hip_pow>("pow"); add_generic_op<hip_pow>("pow");
add_generic_op<hip_sqdiff>("sqdiff"); add_generic_op<hip_sqdiff>("sqdiff");
......
...@@ -154,6 +154,7 @@ struct tf_parser ...@@ -154,6 +154,7 @@ struct tf_parser
add_generic_op("Identity", op::identity{}); add_generic_op("Identity", op::identity{});
add_generic_op("Relu", op::relu{}); add_generic_op("Relu", op::relu{});
add_generic_op("Relu6", op::clip{6.0, 0.0}); add_generic_op("Relu6", op::clip{6.0, 0.0});
add_generic_op("Rsqrt", op::rsqrt{});
add_generic_op("Tanh", op::tanh{}); add_generic_op("Tanh", op::tanh{});
add_generic_op("StopGradient", op::identity{}); add_generic_op("StopGradient", op::identity{});
...@@ -179,6 +180,7 @@ struct tf_parser ...@@ -179,6 +180,7 @@ struct tf_parser
add_mem_op("Softmax", &tf_parser::parse_softmax); add_mem_op("Softmax", &tf_parser::parse_softmax);
add_mem_op("Squeeze", &tf_parser::parse_squeeze, false); add_mem_op("Squeeze", &tf_parser::parse_squeeze, false);
add_mem_op("StridedSlice", &tf_parser::parse_stridedslice); add_mem_op("StridedSlice", &tf_parser::parse_stridedslice);
add_mem_op("Transpose", &tf_parser::parse_transpose, false);
} }
template <class F> template <class F>
...@@ -769,6 +771,16 @@ struct tf_parser ...@@ -769,6 +771,16 @@ struct tf_parser
return to_nhwc(prog.add_instruction(op::squeeze{squeeze_axes}, l0)); return to_nhwc(prog.add_instruction(op::squeeze{squeeze_axes}, l0));
} }
// Parses a TF Transpose node. The permutation arrives as the second input
// (a constant int32 tensor), which must be evaluable at parse time.
// Returns the added transpose instruction applied to the first input.
instruction_ref
parse_transpose(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{
if(args.size() != 2)
{
MIGRAPHX_THROW("Transpose: expects 2 inputs (data, perm)");
}
auto perm_arg = args[1]->eval();
// Mirror the onnx parser's handling: a non-constant perm means a
// dynamic permutation, which cannot be folded into op::transpose.
if(perm_arg.empty())
{
MIGRAPHX_THROW("Transpose: dynamic permutation input is not supported");
}
auto perm = perm_arg.get<int32_t>().to_vector();
op::transpose op;
// op::transpose stores its axes as int64_t; widen the int32 perm values.
op.dims = std::vector<int64_t>(perm.begin(), perm.end());
return prog.add_instruction(op, args.front());
}
void parse_graph(const tensorflow::GraphDef& graph) void parse_graph(const tensorflow::GraphDef& graph)
{ {
nodes = get_nodes(graph, input_nodes); nodes = get_nodes(graph, input_nodes);
......
...@@ -1808,6 +1808,20 @@ TEST_CASE(reduce_sum_axis12) ...@@ -1808,6 +1808,20 @@ TEST_CASE(reduce_sum_axis12)
EXPECT(results_vector == gold); EXPECT(results_vector == gold);
} }
TEST_CASE(rsqrt_test)
{
// rsqrt(x) = 1 / sqrt(x); use exact powers of two so the expected
// values are exactly representable in float.
migraphx::program p;
migraphx::shape input_shape{migraphx::shape::float_type, {3}};
auto input = p.add_literal(migraphx::literal{input_shape, {4.0, 16.0, 64.0}});
p.add_instruction(migraphx::op::rsqrt{}, input);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> output;
result.visit([&](auto v) { output.assign(v.begin(), v.end()); });
std::vector<float> gold = {0.5, 0.25, 0.125};
EXPECT(migraphx::verify_range(output, gold));
}
TEST_CASE(reduce_mean_axis1) TEST_CASE(reduce_mean_axis1)
{ {
migraphx::program p; migraphx::program p;
......
...@@ -3570,6 +3570,19 @@ struct test_reduce_sum_half : verify_program<test_reduce_sum_half> ...@@ -3570,6 +3570,19 @@ struct test_reduce_sum_half : verify_program<test_reduce_sum_half>
}; };
}; };
// GPU-vs-reference verification of op::rsqrt on a 4-D float tensor.
struct test_rsqrt : verify_program<test_rsqrt>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {1, 3, 16, 16}};
auto x = p.add_parameter("x", s);
// Clamp the input into [1, FLT_MAX] first so rsqrt never sees zero or
// negative values, which would yield inf/NaN and break verification.
auto l0 = p.add_instruction(migraphx::op::clip{std::numeric_limits<float>::max(), 1.0}, x);
p.add_instruction(migraphx::op::rsqrt{}, l0);
return p;
} // fixed: removed stray ';' after the member-function body (-Wextra-semi)
};
struct test_reduce_mean : verify_program<test_reduce_mean> struct test_reduce_mean : verify_program<test_reduce_mean>
{ {
migraphx::program create_program() const migraphx::program create_program() const
......
...@@ -933,7 +933,7 @@ TEST_CASE(cast_test) ...@@ -933,7 +933,7 @@ TEST_CASE(cast_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(const_of_shape1) TEST_CASE(const_of_shape_float)
{ {
migraphx::program p; migraphx::program p;
migraphx::shape ss(migraphx::shape::int32_type, {3}); migraphx::shape ss(migraphx::shape::int32_type, {3});
...@@ -946,20 +946,20 @@ TEST_CASE(const_of_shape1) ...@@ -946,20 +946,20 @@ TEST_CASE(const_of_shape1)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(const_of_shape2) TEST_CASE(const_of_shape_int64)
{ {
migraphx::program p; migraphx::program p;
migraphx::shape ss(migraphx::shape::int32_type, {3}); migraphx::shape ss(migraphx::shape::int32_type, {3});
p.add_literal(migraphx::literal(ss, {2, 3, 4})); p.add_literal(migraphx::literal(ss, {2, 3, 4}));
migraphx::shape s(migraphx::shape::int64_type, {2, 3, 4}); migraphx::shape s(migraphx::shape::int64_type, {2, 3, 4});
std::vector<int64_t> vec(s.elements(), 10.0f); std::vector<int64_t> vec(s.elements(), 10);
p.add_literal(migraphx::literal(s, vec)); p.add_literal(migraphx::literal(s, vec));
auto prog = migraphx::parse_onnx("const_of_shape2.onnx"); auto prog = migraphx::parse_onnx("const_of_shape2.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(const_of_shape3) TEST_CASE(const_of_shape_no_value_attr)
{ {
migraphx::program p; migraphx::program p;
migraphx::shape ss(migraphx::shape::int32_type, {3}); migraphx::shape ss(migraphx::shape::int32_type, {3});
...@@ -972,7 +972,7 @@ TEST_CASE(const_of_shape3) ...@@ -972,7 +972,7 @@ TEST_CASE(const_of_shape3)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(const_of_shape4) TEST_CASE(const_of_shape_empty_input)
{ {
migraphx::program p; migraphx::program p;
p.add_literal(migraphx::literal()); p.add_literal(migraphx::literal());
......
:
0 Placeholder*
shape:*
dtype0

rsqrtRsqrt0*
T0"
\ No newline at end of file
...@@ -351,6 +351,16 @@ TEST_CASE(reshape_test) ...@@ -351,6 +351,16 @@ TEST_CASE(reshape_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(rsqrt_test)
{
// Build the expected program by hand and compare against the parsed
// TF protobuf (no NHWC->NCHW conversion, hence `false`).
migraphx::program p;
migraphx::shape input_shape{migraphx::shape::float_type, {1, 3, 16, 16}};
auto input = p.add_parameter("0", input_shape);
p.add_instruction(migraphx::op::rsqrt{}, input);
auto prog = optimize_tf("rsqrt_test.pb", false);
EXPECT(p == prog);
}
TEST_CASE(softmax_test) TEST_CASE(softmax_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -432,4 +442,16 @@ TEST_CASE(tanh_test) ...@@ -432,4 +442,16 @@ TEST_CASE(tanh_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(transpose_test)
{
// Expected program: the perm literal {0,2,3,1} plus a transpose with the
// same axes, matched against the parsed TF graph (no layout conversion).
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {1, 3, 16, 16}};
auto input = p.add_parameter("0", in_shape);
migraphx::shape perm_shape{migraphx::shape::int32_type, {4}};
p.add_literal(migraphx::literal{perm_shape, {0, 2, 3, 1}});
p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, input);
auto prog = optimize_tf("transpose_test.pb", false);
EXPECT(p == prog);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment