"...composable_kernel_rocm.git" did not exist on "e5c55c1996ee5e1e7c9460e2ecc44ea5bb61a94b"
Unverified Commit a023ec19 authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Add additional simple operators (MatMulInteger, ConvInteger, Asinh, Acosh, and Atanh) (#431)



* Add initial api

* Formatting

* Add more api

* Formatting

* add more operators (asinh, acosh, atanh, MatMulInteger, ConvInteger)

* clang format

* add unit tests for new operators

* clang format

* Add auto api generation

* Formatting

* Fix some compilation errors

* Change handle struct

* Formatting

* Fix remaining compilation errors

* Formatting

* Simplify using ctype

* Formatting

* Initial c++ generation

* Formatting

* Add C++header

* Formatting

* Add test

* Formatting

* Add initial tests

* Formatting

* Try to fix formatting

* Cleanup formatting

* Formatting

* Fix constructors on the same line

* Fix tests

* Formatting

* Fix tidy issues

* Fix tidy issues

* Fix naming issue

* Add onnx API to parse buffer

* Formatting

* Add arguments api

* Formatting

* Fix verify parameters

* Fix cppcheck issues

* Formatting

* Add method to get output shapes and bytes

* Formatting

* Try formatting

* Formatting

* Improve the test coverage

* Formatting

* Add print method

* Formatting

* Fix cppcheck issue

* Fix package dependency

* Add nolint

* Try fix formatting

* Formatting

* formatting

* formatting

* Fix formatting
Co-authored-by: default avatarPaul Fultz II <pfultz2@yahoo.com>
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
Co-authored-by: default avatarkahmed10 <15948690+kahmed10@users.noreply.github.com>
parent b949da7f
#ifndef MIGRAPHX_GUARD_OPERATORS_ACOSH_HPP
#define MIGRAPHX_GUARD_OPERATORS_ACOSH_HPP

// Include <cmath> explicitly: std::acosh is used below and must not rely on a
// transitive include from unary.hpp.
#include <cmath>

#include <migraphx/op/unary.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

/// Elementwise inverse hyperbolic cosine operator.
/// The unary<> CRTP base supplies shape computation and evaluation; this class
/// only provides the per-element scalar function.
/// Mathematical domain is x >= 1; out-of-domain inputs follow std::acosh
/// semantics (NaN).
struct acosh : unary<acosh>
{
    // Returns the scalar function applied to each element by the base class.
    auto apply() const
    {
        return [](auto x) { return std::acosh(x); };
    }
};

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_ASINH_HPP
#define MIGRAPHX_GUARD_OPERATORS_ASINH_HPP

// Include <cmath> explicitly: std::asinh is used below and must not rely on a
// transitive include from unary.hpp.
#include <cmath>

#include <migraphx/op/unary.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

/// Elementwise inverse hyperbolic sine operator.
/// The unary<> CRTP base supplies shape computation and evaluation; this class
/// only provides the per-element scalar function. Defined for all real inputs.
struct asinh : unary<asinh>
{
    // Returns the scalar function applied to each element by the base class.
    auto apply() const
    {
        return [](auto x) { return std::asinh(x); };
    }
};

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_ATANH_HPP
#define MIGRAPHX_GUARD_OPERATORS_ATANH_HPP

// Include <cmath> explicitly: std::atanh is used below and must not rely on a
// transitive include from unary.hpp.
#include <cmath>

#include <migraphx/op/unary.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

/// Elementwise inverse hyperbolic tangent operator.
/// The unary<> CRTP base supplies shape computation and evaluation; this class
/// only provides the per-element scalar function.
/// Mathematical domain is -1 < x < 1; out-of-domain inputs follow std::atanh
/// semantics (NaN / +-infinity at the boundaries).
struct atanh : unary<atanh>
{
    // Returns the scalar function applied to each element by the base class.
    auto apply() const
    {
        return [](auto x) { return std::atanh(x); };
    }
};

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
...@@ -4,12 +4,15 @@ ...@@ -4,12 +4,15 @@
#include <migraphx/op/abnormal_ops.hpp> #include <migraphx/op/abnormal_ops.hpp>
#include <migraphx/op/abs.hpp> #include <migraphx/op/abs.hpp>
#include <migraphx/op/acos.hpp> #include <migraphx/op/acos.hpp>
#include <migraphx/op/acosh.hpp>
#include <migraphx/op/add.hpp> #include <migraphx/op/add.hpp>
#include <migraphx/op/argmax.hpp> #include <migraphx/op/argmax.hpp>
#include <migraphx/op/argmin.hpp> #include <migraphx/op/argmin.hpp>
#include <migraphx/op/asin.hpp> #include <migraphx/op/asin.hpp>
#include <migraphx/op/asinh.hpp>
#include <migraphx/op/as_shape.hpp> #include <migraphx/op/as_shape.hpp>
#include <migraphx/op/atan.hpp> #include <migraphx/op/atan.hpp>
#include <migraphx/op/atanh.hpp>
#include <migraphx/op/batch_norm.hpp> #include <migraphx/op/batch_norm.hpp>
#include <migraphx/op/binary.hpp> #include <migraphx/op/binary.hpp>
#include <migraphx/op/broadcast.hpp> #include <migraphx/op/broadcast.hpp>
......
...@@ -37,78 +37,72 @@ struct onnx_parser ...@@ -37,78 +37,72 @@ struct onnx_parser
onnx_parser() onnx_parser()
{ {
add_generic_op("Relu", op::relu{}); // sort onnx operator alphabetically through name
add_generic_op("Sigmoid", op::sigmoid{});
add_generic_op("Abs", op::abs{}); add_generic_op("Abs", op::abs{});
add_generic_op("Exp", op::exp{}); add_generic_op("Acos", op::acos{});
add_generic_op("Acosh", op::acosh{});
add_generic_op("Asin", op::asin{});
add_generic_op("Asinh", op::asinh{});
add_generic_op("Atan", op::atan{});
add_generic_op("Atanh", op::atanh{});
add_generic_op("Ceil", op::ceil{});
add_generic_op("Cos", op::cos{});
add_generic_op("Cosh", op::cosh{});
add_generic_op("Erf", op::erf{}); add_generic_op("Erf", op::erf{});
add_generic_op("Log", op::log{}); add_generic_op("Exp", op::exp{});
// disable dropout for inference
add_generic_op("Dropout", op::identity{}); add_generic_op("Dropout", op::identity{});
add_generic_op("Log", op::log{});
add_generic_op("Floor", op::floor{});
add_generic_op("Identity", op::identity{}); add_generic_op("Identity", op::identity{});
add_generic_op("Relu", op::relu{});
add_generic_op("Round", op::round{});
add_generic_op("Sigmoid", op::sigmoid{});
add_generic_op("Sign", op::sign{});
add_generic_op("Sin", op::sin{}); add_generic_op("Sin", op::sin{});
add_generic_op("Cos", op::cos{});
add_generic_op("Tan", op::tan{});
add_generic_op("Sinh", op::sinh{}); add_generic_op("Sinh", op::sinh{});
add_generic_op("Cosh", op::cosh{});
add_generic_op("Tanh", op::tanh{});
add_generic_op("Asin", op::asin{});
add_generic_op("Acos", op::acos{});
add_generic_op("Atan", op::atan{});
add_generic_op("Sqrt", op::sqrt{}); add_generic_op("Sqrt", op::sqrt{});
add_generic_op("Round", op::round{}); add_generic_op("Tan", op::tan{});
add_generic_op("Sign", op::sign{}); add_generic_op("Tanh", op::tanh{});
add_generic_op("Ceil", op::ceil{});
add_generic_op("Floor", op::floor{});
add_binary_op("Add", op::add{}); add_binary_op("Add", op::add{});
add_binary_op("Div", op::div{}); add_binary_op("Div", op::div{});
add_binary_op("Mul", op::mul{}); add_binary_op("Mul", op::mul{});
add_binary_op("Sub", op::sub{});
add_binary_op("Pow", op::pow{}); add_binary_op("Pow", op::pow{});
add_binary_op("Sub", op::sub{});
add_variadic_op("Sum", op::add{}); add_variadic_op("Sum", op::add{});
add_variadic_op("Max", op::max{}); add_variadic_op("Max", op::max{});
add_variadic_op("Min", op::min{}); add_variadic_op("Min", op::min{});
add_mem_op("AveragePool", &onnx_parser::parse_pooling);
add_mem_op("ArgMax", &onnx_parser::parse_arg_op<op::argmax>); add_mem_op("ArgMax", &onnx_parser::parse_arg_op<op::argmax>);
add_mem_op("ArgMin", &onnx_parser::parse_arg_op<op::argmin>); add_mem_op("ArgMin", &onnx_parser::parse_arg_op<op::argmin>);
add_mem_op("BatchNormalization", &onnx_parser::parse_batchnorm);
add_mem_op("Cast", &onnx_parser::parse_cast); add_mem_op("Cast", &onnx_parser::parse_cast);
add_mem_op("Clip", &onnx_parser::parse_clip); add_mem_op("Clip", &onnx_parser::parse_clip);
add_mem_op("LRN", &onnx_parser::parse_lrn); add_mem_op("Concat", &onnx_parser::parse_concat);
add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler);
add_mem_op("LeakyRelu", &onnx_parser::parse_leaky_relu);
add_mem_op("Elu", &onnx_parser::parse_elu);
add_mem_op("Expand", &onnx_parser::parse_expand);
add_mem_op("Constant", &onnx_parser::parse_constant); add_mem_op("Constant", &onnx_parser::parse_constant);
add_mem_op("Conv", &onnx_parser::parse_conv); add_mem_op("ConstantFill", &onnx_parser::parse_constant_fill);
add_mem_op("ConstantOfShape", &onnx_parser::parse_constant_of_shape);
add_mem_op("Conv", &onnx_parser::parse_conv<op::convolution>);
add_mem_op("ConvInteger", &onnx_parser::parse_conv<op::quant_convolution>);
add_mem_op("ConvTranspose", &onnx_parser::parse_conv_transpose); add_mem_op("ConvTranspose", &onnx_parser::parse_conv_transpose);
add_mem_op("MaxPool", &onnx_parser::parse_pooling); add_mem_op("Elu", &onnx_parser::parse_elu);
add_mem_op("AveragePool", &onnx_parser::parse_pooling); add_mem_op("Expand", &onnx_parser::parse_expand);
add_mem_op("GlobalMaxPool", &onnx_parser::parse_pooling);
add_mem_op("GlobalAveragePool", &onnx_parser::parse_pooling);
add_mem_op("Reshape", &onnx_parser::parse_reshape);
add_mem_op("Flatten", &onnx_parser::parse_flatten); add_mem_op("Flatten", &onnx_parser::parse_flatten);
add_mem_op("Gather", &onnx_parser::parse_gather);
add_mem_op("Gemm", &onnx_parser::parse_gemm); add_mem_op("Gemm", &onnx_parser::parse_gemm);
add_mem_op("MatMul", &onnx_parser::parse_matmul); add_mem_op("GlobalAveragePool", &onnx_parser::parse_pooling);
add_mem_op("BatchNormalization", &onnx_parser::parse_batchnorm); add_mem_op("GlobalMaxPool", &onnx_parser::parse_pooling);
add_mem_op("GRU", &onnx_parser::parse_gru);
add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler);
add_mem_op("InstanceNormalization", &onnx_parser::parse_instancenorm); add_mem_op("InstanceNormalization", &onnx_parser::parse_instancenorm);
add_mem_op("Softmax", &onnx_parser::parse_softmax<op::softmax>); add_mem_op("LeakyRelu", &onnx_parser::parse_leaky_relu);
add_mem_op("LogSoftmax", &onnx_parser::parse_softmax<op::logsoftmax>); add_mem_op("LogSoftmax", &onnx_parser::parse_softmax<op::logsoftmax>);
add_mem_op("Squeeze", &onnx_parser::parse_squeeze); add_mem_op("LRN", &onnx_parser::parse_lrn);
add_mem_op("Unsqueeze", &onnx_parser::parse_unsqueeze); add_mem_op("MatMul", &onnx_parser::parse_matmul<op::dot>);
add_mem_op("Slice", &onnx_parser::parse_slice); add_mem_op("MatMulInteger", &onnx_parser::parse_matmul<op::quant_dot>);
add_mem_op("Concat", &onnx_parser::parse_concat); add_mem_op("MaxPool", &onnx_parser::parse_pooling);
add_mem_op("Gather", &onnx_parser::parse_gather);
add_mem_op("Shape", &onnx_parser::parse_shape);
add_mem_op("ConstantFill", &onnx_parser::parse_constant_fill);
add_mem_op("ConstantOfShape", &onnx_parser::parse_constant_of_shape);
add_mem_op("Transpose", &onnx_parser::parse_transpose);
add_mem_op("RNN", &onnx_parser::parse_rnn);
add_mem_op("GRU", &onnx_parser::parse_gru);
add_mem_op("LSTM", &onnx_parser::parse_lstm);
add_mem_op("Pad", &onnx_parser::parse_pad);
add_mem_op("ReduceL1", &onnx_parser::parse_reduce_l1); add_mem_op("ReduceL1", &onnx_parser::parse_reduce_l1);
add_mem_op("ReduceL2", &onnx_parser::parse_reduce_l2); add_mem_op("ReduceL2", &onnx_parser::parse_reduce_l2);
add_mem_op("ReduceLogSum", &onnx_parser::parse_reduce_log_sum); add_mem_op("ReduceLogSum", &onnx_parser::parse_reduce_log_sum);
...@@ -119,6 +113,16 @@ struct onnx_parser ...@@ -119,6 +113,16 @@ struct onnx_parser
add_mem_op("ReduceProd", &onnx_parser::parse_reduce_oper<op::reduce_prod>); add_mem_op("ReduceProd", &onnx_parser::parse_reduce_oper<op::reduce_prod>);
add_mem_op("ReduceSum", &onnx_parser::parse_reduce_oper<op::reduce_sum>); add_mem_op("ReduceSum", &onnx_parser::parse_reduce_oper<op::reduce_sum>);
add_mem_op("ReduceSumSquare", &onnx_parser::parse_reduce_sum_square); add_mem_op("ReduceSumSquare", &onnx_parser::parse_reduce_sum_square);
add_mem_op("Reshape", &onnx_parser::parse_reshape);
add_mem_op("RNN", &onnx_parser::parse_rnn);
add_mem_op("Pad", &onnx_parser::parse_pad);
add_mem_op("Shape", &onnx_parser::parse_shape);
add_mem_op("Slice", &onnx_parser::parse_slice);
add_mem_op("Softmax", &onnx_parser::parse_softmax<op::softmax>);
add_mem_op("Squeeze", &onnx_parser::parse_squeeze);
add_mem_op("Transpose", &onnx_parser::parse_transpose);
add_mem_op("Unsqueeze", &onnx_parser::parse_unsqueeze);
add_mem_op("LSTM", &onnx_parser::parse_lstm);
// init the activation function map // init the activation function map
init_actv_func(); init_actv_func();
...@@ -414,10 +418,11 @@ struct onnx_parser ...@@ -414,10 +418,11 @@ struct onnx_parser
return ins; return ins;
} }
template <class Op>
instruction_ref instruction_ref
parse_conv(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) parse_conv(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{ {
op::convolution op; Op op;
auto l0 = args[0]; auto l0 = args[0];
if(contains(attributes, "pads")) if(contains(attributes, "pads"))
{ {
...@@ -829,6 +834,7 @@ struct onnx_parser ...@@ -829,6 +834,7 @@ struct onnx_parser
return prog.add_instruction(op::dot{alpha, beta}, l1, l2); return prog.add_instruction(op::dot{alpha, beta}, l1, l2);
} }
template <class Op>
instruction_ref instruction_ref
parse_matmul(const std::string&, const attribute_map&, std::vector<instruction_ref> args) parse_matmul(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{ {
...@@ -877,7 +883,7 @@ struct onnx_parser ...@@ -877,7 +883,7 @@ struct onnx_parser
} }
} }
auto dot_res = prog.add_instruction(op::dot{1.0f, 0.0f}, bl0, bl1); auto dot_res = prog.add_instruction(Op{1, 0}, bl0, bl1);
int64_t num_axis = static_cast<int64_t>(dot_res->get_shape().lens().size()); int64_t num_axis = static_cast<int64_t>(dot_res->get_shape().lens().size());
if(is_a_prepended) if(is_a_prepended)
{ {
......
...@@ -12,6 +12,7 @@ endif() ...@@ -12,6 +12,7 @@ endif()
add_library(migraphx_device add_library(migraphx_device
device/acos.cpp device/acos.cpp
device/acosh.cpp
device/add.cpp device/add.cpp
device/add_clip.cpp device/add_clip.cpp
device/add_relu.cpp device/add_relu.cpp
...@@ -20,7 +21,9 @@ add_library(migraphx_device ...@@ -20,7 +21,9 @@ add_library(migraphx_device
device/argmax.cpp device/argmax.cpp
device/argmin.cpp device/argmin.cpp
device/asin.cpp device/asin.cpp
device/asinh.cpp
device/atan.cpp device/atan.cpp
device/atanh.cpp
device/ceil.cpp device/ceil.cpp
device/clip.cpp device/clip.cpp
device/concat.cpp device/concat.cpp
......
#include <migraphx/gpu/device/acosh.hpp>
#include <migraphx/gpu/device/nary.hpp>
#include <migraphx/gpu/device/types.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

// Launches an elementwise acosh kernel on the given HIP stream.
// `result` receives the output; `arg` holds the input tensor.
// The nary helper handles the launch configuration; to_hip_type converts each
// element to a type accepted by the device ::acosh overload.
void acosh(hipStream_t stream, const argument& result, const argument& arg)
{
    auto op = [](auto v) { return ::acosh(to_hip_type(v)); };
    nary(stream, result, arg)(op);
}

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/gpu/device/asinh.hpp>
#include <migraphx/gpu/device/nary.hpp>
#include <migraphx/gpu/device/types.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

// Launches an elementwise asinh kernel on the given HIP stream.
// `result` receives the output; `arg` holds the input tensor.
// The nary helper handles the launch configuration; to_hip_type converts each
// element to a type accepted by the device ::asinh overload.
void asinh(hipStream_t stream, const argument& result, const argument& arg)
{
    auto op = [](auto v) { return ::asinh(to_hip_type(v)); };
    nary(stream, result, arg)(op);
}

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/gpu/device/atanh.hpp>
#include <migraphx/gpu/device/nary.hpp>
#include <migraphx/gpu/device/types.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

// Launches an elementwise atanh kernel on the given HIP stream.
// `result` receives the output; `arg` holds the input tensor.
// The nary helper handles the launch configuration; to_hip_type converts each
// element to a type accepted by the device ::atanh overload.
void atanh(hipStream_t stream, const argument& result, const argument& arg)
{
    auto op = [](auto v) { return ::atanh(to_hip_type(v)); };
    nary(stream, result, arg)(op);
}

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#ifndef MIGRAPHX_GUARD_RTGLIB_ACOSH_HPP
#define MIGRAPHX_GUARD_RTGLIB_ACOSH_HPP
#include <migraphx/gpu/oper.hpp>
#include <migraphx/gpu/device/acosh.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
// GPU acosh operation. The unary_device CRTP base wires the device::acosh
// kernel into the standard unary-op interface; no members are needed here.
struct hip_acosh : unary_device<hip_acosh, device::acosh>
{
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_ASINH_HPP
#define MIGRAPHX_GUARD_RTGLIB_ASINH_HPP
#include <migraphx/gpu/oper.hpp>
#include <migraphx/gpu/device/asinh.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
// GPU asinh operation. The unary_device CRTP base wires the device::asinh
// kernel into the standard unary-op interface; no members are needed here.
struct hip_asinh : unary_device<hip_asinh, device::asinh>
{
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_ATANH_HPP
#define MIGRAPHX_GUARD_RTGLIB_ATANH_HPP
#include <migraphx/gpu/oper.hpp>
#include <migraphx/gpu/device/atanh.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
// GPU atanh operation. The unary_device CRTP base wires the device::atanh
// kernel into the standard unary-op interface; no members are needed here.
struct hip_atanh : unary_device<hip_atanh, device::atanh>
{
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_ACOSH_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_ACOSH_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
// Elementwise acosh kernel launcher: computes acosh of `arg` into `result`
// on `stream`. Implementation in device/acosh.cpp.
void acosh(hipStream_t stream, const argument& result, const argument& arg);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_ASINH_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_ASINH_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
// Elementwise asinh kernel launcher: computes asinh of `arg` into `result`
// on `stream`. Implementation in device/asinh.cpp.
void asinh(hipStream_t stream, const argument& result, const argument& arg);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_ATANH_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_ATANH_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
// Elementwise atanh kernel launcher: computes atanh of `arg` into `result`
// on `stream`. Implementation in device/atanh.cpp.
void atanh(hipStream_t stream, const argument& result, const argument& arg);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -42,6 +42,9 @@ ...@@ -42,6 +42,9 @@
#include <migraphx/gpu/asin.hpp> #include <migraphx/gpu/asin.hpp>
#include <migraphx/gpu/acos.hpp> #include <migraphx/gpu/acos.hpp>
#include <migraphx/gpu/atan.hpp> #include <migraphx/gpu/atan.hpp>
#include <migraphx/gpu/asinh.hpp>
#include <migraphx/gpu/acosh.hpp>
#include <migraphx/gpu/atanh.hpp>
#include <migraphx/gpu/mul.hpp> #include <migraphx/gpu/mul.hpp>
#include <migraphx/gpu/max.hpp> #include <migraphx/gpu/max.hpp>
#include <migraphx/gpu/min.hpp> #include <migraphx/gpu/min.hpp>
...@@ -121,6 +124,9 @@ struct miopen_apply ...@@ -121,6 +124,9 @@ struct miopen_apply
add_generic_op<hip_asin>("asin"); add_generic_op<hip_asin>("asin");
add_generic_op<hip_acos>("acos"); add_generic_op<hip_acos>("acos");
add_generic_op<hip_atan>("atan"); add_generic_op<hip_atan>("atan");
add_generic_op<hip_asinh>("asinh");
add_generic_op<hip_acosh>("acosh");
add_generic_op<hip_atanh>("atanh");
add_generic_op<hip_sqrt>("sqrt"); add_generic_op<hip_sqrt>("sqrt");
add_generic_op<hip_mul>("mul"); add_generic_op<hip_mul>("mul");
add_generic_op<hip_div>("div"); add_generic_op<hip_div>("div");
......
...@@ -779,6 +779,50 @@ TEST_CASE(atan_test) ...@@ -779,6 +779,50 @@ TEST_CASE(atan_test)
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
// CPU-target check of op::asinh: evaluate three known inputs and compare
// against precomputed asinh values.
TEST_CASE(asinh_test)
{
    migraphx::program prog;
    migraphx::shape input_shape{migraphx::shape::float_type, {3}};
    std::vector<float> input_data{-0.5f, 0.0f, 0.9f};
    auto input = prog.add_literal(migraphx::literal{input_shape, input_data});
    prog.add_instruction(migraphx::op::asinh{}, input);
    prog.compile(migraphx::cpu::target{});
    auto result = prog.eval({});
    std::vector<float> results(3);
    result.visit([&](auto output) { results.assign(output.begin(), output.end()); });
    // Expected values: asinh(-0.5), asinh(0), asinh(0.9).
    std::vector<float> gold = {-0.481211841, 0, 0.808866858};
    EXPECT(migraphx::verify_range(results, gold));
}
// CPU-target check of op::acosh: evaluate three known inputs (all >= 1, the
// acosh domain) and compare against precomputed acosh values.
// NOTE(review): the shape is double_type while the literal data is a
// vector<float>; the literal constructor converts — confirm this mix is
// intentional.
TEST_CASE(acosh_test)
{
    migraphx::program prog;
    migraphx::shape input_shape{migraphx::shape::double_type, {3}};
    std::vector<float> input_data{1.1f, 1.2f, 2.0f};
    auto input = prog.add_literal(migraphx::literal{input_shape, input_data});
    prog.add_instruction(migraphx::op::acosh{}, input);
    prog.compile(migraphx::cpu::target{});
    auto result = prog.eval({});
    std::vector<float> results(3);
    result.visit([&](auto output) { results.assign(output.begin(), output.end()); });
    // Expected values: acosh(1.1), acosh(1.2), acosh(2.0).
    std::vector<float> gold = {0.4435683, 0.6223626, 1.316958};
    EXPECT(migraphx::verify_range(results, gold));
}
// CPU-target check of op::atanh: evaluate three inputs inside the (-1, 1)
// domain and compare against precomputed atanh values.
TEST_CASE(atanh_test)
{
    migraphx::program prog;
    migraphx::shape input_shape{migraphx::shape::double_type, {3}};
    auto input = prog.add_literal(migraphx::literal{input_shape, {0.4435683, 0.6223626, 0.316958}});
    prog.add_instruction(migraphx::op::atanh{}, input);
    prog.compile(migraphx::cpu::target{});
    auto result = prog.eval({});
    std::vector<float> results(3);
    result.visit([&](auto output) { results.assign(output.begin(), output.end()); });
    // Expected values: atanh of each input above.
    std::vector<float> gold = {0.476664424, 0.728852153, 0.328261733};
    EXPECT(migraphx::verify_range(results, gold));
}
TEST_CASE(add_test) TEST_CASE(add_test)
{ {
migraphx::program p; migraphx::program p;
......
...@@ -451,6 +451,44 @@ struct test_atan : verify_program<test_atan> ...@@ -451,6 +451,44 @@ struct test_atan : verify_program<test_atan>
} }
}; };
// GPU-vs-reference verification of op::asinh over a 16-element double tensor.
// asinh is defined for all reals, so no input clamping is needed.
struct test_asinh : verify_program<test_asinh>
{
    migraphx::program create_program() const
    {
        migraphx::program prog;
        migraphx::shape param_shape{migraphx::shape::double_type, {16}};
        auto input = prog.add_parameter("x", param_shape);
        prog.add_instruction(migraphx::op::asinh{}, input);
        return prog;
    }
};
// GPU-vs-reference verification of op::acosh over a 16-element float tensor.
// The input is clipped first so every element lies in acosh's domain (x >= 1);
// the clip arguments appear to be {max, min} = {100.0f, 1.1f} — presumably
// matching op::clip's constructor order; verify against its definition.
struct test_acosh : verify_program<test_acosh>
{
    migraphx::program create_program() const
    {
        migraphx::program prog;
        migraphx::shape param_shape{migraphx::shape::float_type, {16}};
        auto input   = prog.add_parameter("x", param_shape);
        auto clipped = prog.add_instruction(migraphx::op::clip{100.0f, 1.1f}, input);
        prog.add_instruction(migraphx::op::acosh{}, clipped);
        return prog;
    }
};
// GPU-vs-reference verification of op::atanh over a 16-element double tensor.
// The input is clipped first so every element lies strictly inside atanh's
// domain (-1, 1); the clip arguments appear to be {max, min} =
// {0.95f, -0.95f} — presumably matching op::clip's constructor order.
struct test_atanh : verify_program<test_atanh>
{
    migraphx::program create_program() const
    {
        migraphx::program prog;
        migraphx::shape param_shape{migraphx::shape::double_type, {16}};
        auto input   = prog.add_parameter("x", param_shape);
        auto clipped = prog.add_instruction(migraphx::op::clip{0.95f, -0.95f}, input);
        prog.add_instruction(migraphx::op::atanh{}, clipped);
        return prog;
    }
};
struct test_scale : verify_program<test_scale> struct test_scale : verify_program<test_scale>
{ {
migraphx::program create_program() const migraphx::program create_program() const
......

acosh_test:=
xy"Acosh
acosh_testZ
x


b
y


B
\ No newline at end of file

asinh_test:=
xy"Asinh
asinh_testZ
x


b
y


B
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment