"...git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "1f9eb6c8f13ce3fa04a5b7c59a729848886d00c5"
Commit 00b6f5f8 authored by Khalique's avatar Khalique
Browse files

changed if statement to use apply map

parent 0ad4a708
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
#include <migraphx/gpu/gemm.hpp> #include <migraphx/gpu/gemm.hpp>
#include <migraphx/gpu/concat.hpp> #include <migraphx/gpu/concat.hpp>
#include <utility> #include <utility>
#include <functional>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPH_INLINE_NS {
...@@ -39,6 +40,7 @@ struct miopen_apply ...@@ -39,6 +40,7 @@ struct miopen_apply
{ {
program* prog = nullptr; program* prog = nullptr;
context ctx{}; context ctx{};
std::unordered_map<std::string, std::function<instruction_ref(miopen_apply&, instruction_ref)>> apply_map{};
void check_shape(shape x, instruction_ref i) void check_shape(shape x, instruction_ref i)
{ {
...@@ -47,74 +49,35 @@ struct miopen_apply ...@@ -47,74 +49,35 @@ struct miopen_apply
(void)i; (void)i;
} }
void init()
{
apply_map["convolution"] = &miopen_apply::apply_convolution;
apply_map["relu"] = &miopen_apply::apply_relu;
apply_map["sigmoid"] = &miopen_apply::apply_sigmoid;
apply_map["tanh"] = &miopen_apply::apply_tanh;
apply_map["abs"] = &miopen_apply::apply_abs;
apply_map["leaky_relu"] = &miopen_apply::apply_leaky_relu;
apply_map["elu"] = &miopen_apply::apply_elu;
apply_map["pooling"] = &miopen_apply::apply_pooling;
apply_map["add"] = &miopen_apply::apply_add;
apply_map["sin"] = &miopen_apply::apply_sin;
apply_map["mul"] = &miopen_apply::apply_mul;
apply_map["dot"] = &miopen_apply::apply_dot;
apply_map["contiguous"] = &miopen_apply::apply_contiguous;
apply_map["concat"] = &miopen_apply::apply_concat;
apply_map["batch_norm_inference"] = &miopen_apply::apply_batch_norm_inference;
apply_map["softmax"] = &miopen_apply::apply_softmax;
}
void apply() void apply()
{ {
init();
for(auto it = prog->begin(); it != prog->end(); it++) for(auto it = prog->begin(); it != prog->end(); it++)
{ {
auto s = it->get_shape(); auto s = it->get_shape();
if(it->name() == "convolution") if(apply_map.count(it->name()) > 0)
{
check_shape(s, apply_convolution(it));
}
else if(it->name() == "relu")
{
check_shape(s, apply_relu(it));
}
else if(it->name() == "sigmoid")
{
check_shape(s, apply_sigmoid(it));
}
else if(it->name() == "tanh")
{
check_shape(s, apply_tanh(it));
}
else if(it->name() == "abs")
{
check_shape(s, apply_abs(it));
}
else if(it->name() == "leaky_relu")
{
check_shape(s, apply_leaky_relu(it));
}
else if(it->name() == "elu")
{
check_shape(s, apply_elu(it));
}
else if(it->name() == "pooling")
{
check_shape(s, apply_pooling(it));
}
else if(it->name() == "add")
{
check_shape(s, apply_add(it));
}
else if(it->name() == "sin")
{
check_shape(s, apply_sin(it));
}
else if(it->name() == "mul")
{
check_shape(s, apply_mul(it));
}
else if(it->name() == "dot")
{
check_shape(s, apply_gemm(it));
}
else if(it->name() == "contiguous")
{
check_shape(s, apply_contiguous(it));
}
else if(it->name() == "concat")
{
check_shape(s, apply_concat(it));
}
else if(it->name() == "batch_norm_inference")
{
check_shape(s, apply_batch_norm_inference(it));
}
else if(it->name() == "softmax")
{ {
check_shape(s, apply_softmax(it)); check_shape(s, apply_map.at(it->name())(*this, it));
} }
} }
} }
...@@ -240,7 +203,7 @@ struct miopen_apply ...@@ -240,7 +203,7 @@ struct miopen_apply
ins, hip_mul{}, ins->inputs().at(0), ins->inputs().at(1), output); ins, hip_mul{}, ins->inputs().at(0), ins->inputs().at(1), output);
} }
instruction_ref apply_gemm(instruction_ref ins) instruction_ref apply_dot(instruction_ref ins)
{ {
auto&& op = any_cast<op::dot>(ins->get_operator()); auto&& op = any_cast<op::dot>(ins->get_operator());
auto output = insert_allocation(ins, ins->get_shape()); auto output = insert_allocation(ins, ins->get_shape());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment