Commit c4ed136c authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Use template for apply_* function implementation.

parent 123f7a01
......@@ -60,10 +60,10 @@ struct miopen_apply
apply_map["leaky_relu"] = &miopen_apply::apply_leaky_relu;
apply_map["elu"] = &miopen_apply::apply_elu;
apply_map["pooling"] = &miopen_apply::apply_pooling;
apply_map["add"] = &miopen_apply::apply_add;
apply_map["sin"] = &miopen_apply::apply_sin;
apply_map["mul"] = &miopen_apply::apply_mul;
apply_map["dot"] = &miopen_apply::apply_dot;
apply_map["add"] = &miopen_apply::apply_generic_op<hip_add>;
apply_map["sin"] = &miopen_apply::apply_generic_op<hip_sin>;
apply_map["mul"] = &miopen_apply::apply_generic_op<hip_mul>;
apply_map["dot"] = &miopen_apply::apply_generic_op<miopen_gemm>;
apply_map["contiguous"] = &miopen_apply::apply_contiguous;
apply_map["concat"] = &miopen_apply::apply_concat;
apply_map["batch_norm_inference"] = &miopen_apply::apply_batch_norm_inference;
......@@ -184,26 +184,42 @@ struct miopen_apply
return prog->replace_instruction(ins, miopen_softmax{op}, ins->inputs().at(0), output);
}
/*
instruction_ref apply_add(instruction_ref ins)
{
auto output = insert_allocation(ins, ins->get_shape());
return prog->replace_instruction(
ins, hip_add{}, ins->inputs().at(0), ins->inputs().at(1), output);
}
*/
/*
instruction_ref apply_sin(instruction_ref ins)
{
auto output = insert_allocation(ins, ins->get_shape());
return prog->replace_instruction(ins, hip_sin{}, ins->inputs().at(0), output);
}
*/
template<class T>
instruction_ref apply_generic_op(instruction_ref ins) {
auto output = insert_allocation(ins, ins->get_shape());
std::vector<instruction_ref> refs = ins->inputs();
refs.push_back(output);
return prog->replace_instruction(ins, T{}, refs);
}
/*
instruction_ref apply_mul(instruction_ref ins)
{
auto output = insert_allocation(ins, ins->get_shape());
return prog->replace_instruction(
ins, hip_mul{}, ins->inputs().at(0), ins->inputs().at(1), output);
}
*/
/*
instruction_ref apply_dot(instruction_ref ins)
{
auto&& op = any_cast<op::dot>(ins->get_operator());
......@@ -211,6 +227,7 @@ struct miopen_apply
return prog->replace_instruction(
ins, miopen_gemm{op}, ins->inputs().at(0), ins->inputs().at(1), output);
}
*/
instruction_ref apply_contiguous(instruction_ref ins)
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment