Commit 497383c8 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

code cleanup according to comments from Paul.

parents 13d2a743 8885e8ac
...@@ -60,31 +60,31 @@ struct miopen_apply ...@@ -60,31 +60,31 @@ struct miopen_apply
void init() void init()
{ {
add_miopen_simple_op("relu", miopen_relu{}, make_relu); add_miopen_simple_op<miopen_relu>("relu", make_relu);
add_miopen_simple_op("sigmoid", miopen_sigmoid{}, make_sigmoid); add_miopen_simple_op<miopen_sigmoid>("sigmoid", make_sigmoid);
add_miopen_simple_op("abs", miopen_abs{}, make_abs); add_miopen_simple_op<miopen_abs>("abs", make_abs);
add_miopen_simple_op("tanh", miopen_tanh{}, make_tanh); add_miopen_simple_op<miopen_tanh>("tanh", make_tanh);
add_miopen_extend_op("leaky_relu", miopen_leaky_relu{}, op::leaky_relu{}, make_leaky_relu); add_miopen_extend_op<miopen_leaky_relu, op::leaky_relu>("leaky_relu", make_leaky_relu);
add_miopen_extend_op("elu", miopen_elu{}, op::elu{}, make_elu); add_miopen_extend_op<miopen_elu, op::elu>("elu", make_elu);
add_generic_op("add", hip_add{}); add_generic_op<hip_add>("add");
add_generic_op("exp", hip_exp{}); add_generic_op<hip_exp>("exp");
add_generic_op("log", hip_log{}); add_generic_op<hip_log>("log");
add_generic_op("sin", hip_sin{}); add_generic_op<hip_sin>("sin");
add_generic_op("cos", hip_cos{}); add_generic_op<hip_cos>("cos");
add_generic_op("tan", hip_tan{}); add_generic_op<hip_tan>("tan");
add_generic_op("sinh", hip_sinh{}); add_generic_op<hip_sinh>("sinh");
add_generic_op("cosh", hip_cosh{}); add_generic_op<hip_cosh>("cosh");
add_generic_op("asin", hip_asin{}); add_generic_op<hip_asin>("asin");
add_generic_op("acos", hip_acos{}); add_generic_op<hip_acos>("acos");
add_generic_op("atan", hip_atan{}); add_generic_op<hip_atan>("atan");
add_generic_op("mul", hip_mul{}); add_generic_op<hip_mul>("mul");
add_extend_op("dot", miopen_gemm{}, op::dot{}); add_extend_op<miopen_gemm, op::dot>("dot");
add_extend_op("contiguous", miopen_contiguous{}, op::contiguous{}); add_extend_op<miopen_contiguous, op::contiguous>("contiguous");
add_extend_op("concat", hip_concat{}, op::concat{}); add_extend_op<hip_concat, op::concat>("concat");
add_extend_op("softmax", miopen_softmax{}, op::softmax{}); add_extend_op<miopen_softmax, op::softmax>("softmax");
add_convolution_op(); add_convolution_op();
add_pooling_op(); add_pooling_op();
...@@ -147,7 +147,7 @@ struct miopen_apply ...@@ -147,7 +147,7 @@ struct miopen_apply
} }
template <class T> template <class T>
void add_generic_op(std::string name, T x) void add_generic_op(std::string name)
{ {
apply_map.emplace(name, [=](instruction_ref ins) { apply_map.emplace(name, [=](instruction_ref ins) {
auto output = insert_allocation(ins, ins->get_shape()); auto output = insert_allocation(ins, ins->get_shape());
...@@ -156,11 +156,10 @@ struct miopen_apply ...@@ -156,11 +156,10 @@ struct miopen_apply
return prog->replace_instruction(ins, T{}, refs); return prog->replace_instruction(ins, T{}, refs);
}); });
(void)x;
} }
template <class T, class Op> template <class T, class Op>
void add_extend_op(std::string name, T x, Op o) void add_extend_op(std::string name)
{ {
apply_map.emplace(name, [=](instruction_ref ins) { apply_map.emplace(name, [=](instruction_ref ins) {
auto&& op = any_cast<Op>(ins->get_operator()); auto&& op = any_cast<Op>(ins->get_operator());
...@@ -170,12 +169,10 @@ struct miopen_apply ...@@ -170,12 +169,10 @@ struct miopen_apply
return prog->replace_instruction(ins, T{op}, refs); return prog->replace_instruction(ins, T{op}, refs);
}); });
(void)x;
(void)o;
} }
template <class T, class Op, class F> template <class T, class Op, class F>
void add_miopen_extend_op(std::string name, T x, Op o, F f) void add_miopen_extend_op(std::string name, F f)
{ {
apply_map.emplace(name, [=](instruction_ref ins) { apply_map.emplace(name, [=](instruction_ref ins) {
auto&& op = any_cast<Op>(ins->get_operator()); auto&& op = any_cast<Op>(ins->get_operator());
...@@ -184,21 +181,16 @@ struct miopen_apply ...@@ -184,21 +181,16 @@ struct miopen_apply
auto output = insert_allocation(ins, ins->get_shape()); auto output = insert_allocation(ins, ins->get_shape());
return prog->replace_instruction(ins, T{std::move(ad)}, ins->inputs().at(0), output); return prog->replace_instruction(ins, T{std::move(ad)}, ins->inputs().at(0), output);
}); });
(void)x;
(void)o;
(void)f;
} }
template <class T, class F> template <class T, class F>
void add_miopen_simple_op(std::string name, T x, F f) void add_miopen_simple_op(std::string name, F f)
{ {
apply_map.emplace(name, [=](instruction_ref ins) { apply_map.emplace(name, [=](instruction_ref ins) {
auto ad = f(); auto ad = f();
auto output = insert_allocation(ins, ins->get_shape()); auto output = insert_allocation(ins, ins->get_shape());
return prog->replace_instruction(ins, T{std::move(ad)}, ins->inputs().at(0), output); return prog->replace_instruction(ins, T{std::move(ad)}, ins->inputs().at(0), output);
}); });
(void)x;
(void)f;
} }
void add_batch_norm_inference_op() void add_batch_norm_inference_op()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment