Commit 8ea9fc25 authored by Khalique

formatting

parent 9cf50769
@@ -391,9 +391,8 @@ struct onnx_parser
         return prog.add_instruction(op, args.front());
     }

-    instruction_ref parse_elu(const std::string&,
-                              attribute_map attributes,
-                              std::vector<instruction_ref> args)
+    instruction_ref
+    parse_elu(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
     {
         float alpha = 1.0; // default alpha val for elu
         if(contains(attributes, "alpha"))
...
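The hunk is truncated before the attribute is actually read. As a rough sketch only (the parse_value helper and the exact control flow are assumptions, not taken from this commit), the remainder of parse_elu presumably decodes the optional alpha attribute and emits the elu instruction:

    if(contains(attributes, "alpha"))
        alpha = parse_value(attributes.at("alpha")).at<float>(); // assumed attribute-decoding helper
    return prog.add_instruction(op::elu{alpha}, args.front());   // assumed: single-input elu op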
@@ -19,9 +19,10 @@ T zero(const T&)
     return T(0);
 }

-template<class T>
-typename std::conditional_t<std::is_integral<T>{}, std::make_signed<T>, std::enable_if<true, T>>::type
-make_signed(T x)
+template <class T>
+typename std::conditional_t<std::is_integral<T>{}, std::make_signed<T>, std::enable_if<true, T>>::
+    type
+    make_signed(T x)
 {
     return x;
 }
@@ -617,7 +618,7 @@ struct cpu_apply
         apply_map["contiguous"] = extend_op<cpu_contiguous, op::contiguous>();
         apply_map["concat"]     = extend_op<cpu_concat, op::concat>();
         apply_map["leaky_relu"] = extend_op<cpu_unary<leaky_relu_op>, op::leaky_relu>();
-        apply_map["elu"] = extend_op<cpu_unary<elu_op>, op::elu>();
+        apply_map["elu"]        = extend_op<cpu_unary<elu_op>, op::elu>();
         apply_map["identity"]   = simple_op<cpu_unary<identity_op>>();
         apply_map["abs"]        = simple_op<cpu_unary<abs_op>>();
         apply_map["tanh"]       = simple_op<cpu_unary<tanh_op>>();
...
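The elementwise functor behind cpu_unary<elu_op> is not part of this hunk. A minimal sketch, assuming the CPU target applies a per-element lambda built from the operator's alpha (the member layout and names below are illustrative, not taken from this commit):

    struct elu_op
    {
        op::elu op;
        std::string name() const { return "cpu::elu"; }
        auto fcn() const
        {
            auto a = op.alpha;
            // ELU: identity for positive inputs, a * (exp(x) - 1) otherwise
            return [a](auto x) { return x > 0 ? x : a * std::expm1(x); };
        }
    };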
#include <migraphx/gpu/elu.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <utility>

namespace migraphx {
inline namespace MIGRAPH_INLINE_NS {
namespace gpu {

shape miopen_elu::compute_shape(const std::vector<shape>& inputs) const
{
    // two inputs: x and the preallocated output buffer; the result takes the output's shape
    check_shapes{inputs, *this}.has(2).not_broadcasted();
    return inputs.at(1);
}

argument miopen_elu::compute(context& ctx,
                             const shape& output_shape,
                             const std::vector<argument>& args) const
{
    // MIOpen blending factors: y = alpha * op(x) + beta * y
    float alpha = 1, beta = 0;
    auto x_desc = make_tensor(args[0].get_shape());
    auto y_desc = make_tensor(output_shape);
    miopenActivationForward(ctx.get_stream().get_miopen(),
                            ad.get(),
                            &alpha,
                            x_desc.get(),
                            args[0].implicit(),
                            &beta,
                            y_desc.get(),
                            args[1].implicit());
    // the activation is written into the preallocated output buffer, which is returned
    return args[1];
}

} // namespace gpu
} // namespace MIGRAPH_INLINE_NS
} // namespace migraphx
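compute() relies on the shared activation descriptor ad, but its configuration is not shown in this commit view. A hedged sketch of how such a descriptor can be set up for ELU with the MIOpen C API (miopenSetActivationDescriptor and miopenActivationELU are MIOpen names; the make_elu wrapper and the RAII helper are assumptions for illustration):

    // Sketch only: configure a MIOpen activation descriptor for ELU with the given alpha.
    inline shared<activation_descriptor> make_elu(double alpha)
    {
        auto ad = make_activation_descriptor(); // assumed RAII wrapper over miopenCreateActivationDescriptor
        // arguments after the mode are activAlpha, activBeta, activGamma
        miopenSetActivationDescriptor(ad.get(), miopenActivationELU, alpha, 0, 0);
        return ad;
    }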
#ifndef MIGRAPH_GUARD_RTGLIB_ELU_HPP
#define MIGRAPH_GUARD_RTGLIB_ELU_HPP

#include <migraphx/gpu/lowering.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/gpu/device/contiguous.hpp>
#include <migraphx/gpu/device/add.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/context.hpp>
#include <utility>

namespace migraphx {
inline namespace MIGRAPH_INLINE_NS {
namespace gpu {

struct miopen_elu
{
    shared<activation_descriptor> ad;
    std::string name() const { return "gpu::elu"; }
    shape compute_shape(const std::vector<shape>& inputs) const;
    argument
    compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
    int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; }
};

} // namespace gpu
} // namespace MIGRAPH_INLINE_NS
} // namespace migraphx

#endif
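Note that output_alias() returns the index of the last input shape: gpu::elu takes its output buffer as the final argument, which is why compute_shape() expects two inputs and compute() writes into and returns args[1].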
@@ -7,15 +7,9 @@
 #include <migraphx/verify.hpp>
 #include "test.hpp"

-float sigmoid(float x)
-{
-    return 1 / (1 + expf(-x));
-}
+float sigmoid(float x) { return 1 / (1 + expf(-x)); }

-float elu(float a, float x)
-{
-    return x > 0 ? x : a * std::expm1(x);
-}
+float elu(float a, float x) { return x > 0 ? x : a * std::expm1(x); }

 TEST_CASE(slice_test)
 {
@@ -1161,14 +1155,14 @@ TEST_CASE(elu_test)
 {
     migraphx::program p;
     migraphx::shape s{migraphx::shape::float_type, {2, 2}};
     auto l      = p.add_literal(migraphx::literal{s, {-1.0, 2.0, -3.0, 4.0}});
     float alpha = 0.5;
     p.add_instruction(migraphx::op::elu{alpha}, l);
     p.compile(migraphx::cpu::target{});
     auto result = p.eval({});
     std::vector<float> results_vector(4);
     result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
-    std::vector<float> gold{elu(alpha,-1), elu(alpha,2), elu(alpha,-3), elu(alpha,4)};
+    std::vector<float> gold{elu(alpha, -1), elu(alpha, 2), elu(alpha, -3), elu(alpha, 4)};
     EXPECT(migraphx::verify_range(results_vector, gold));
 }
...
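For reference, with alpha = 0.5 the gold values produced by the elu helper above are approximately {-0.3161, 2, -0.4751, 4}: elu(0.5, -1) = 0.5 * expm1(-1) ≈ -0.3161 and elu(0.5, -3) = 0.5 * expm1(-3) ≈ -0.4751, while the positive inputs pass through unchanged.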
@@ -449,7 +449,7 @@ struct test_sigmoid
    migraphx::program create_program() const
    {
        migraphx::program p;
        auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
        p.add_instruction(migraphx::op::sigmoid{}, x);
        return p;
    }
...
@@ -460,7 +460,7 @@ struct test_tanh
    migraphx::program create_program() const
    {
        migraphx::program p;
        auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
        p.add_instruction(migraphx::op::tanh{}, x);
        return p;
    }
...
@@ -471,7 +471,7 @@ struct test_abs
    migraphx::program create_program() const
    {
        migraphx::program p;
        auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
        p.add_instruction(migraphx::op::abs{}, x);
        return p;
    }
...