Commit 70320bbd authored by Shucai Xiao's avatar Shucai Xiao
Browse files

lowering fp16 lrn to fp32 lrn on the gpu target

parent 4f9a0ce7
...@@ -164,7 +164,6 @@ struct miopen_apply ...@@ -164,7 +164,6 @@ struct miopen_apply
add_extend_op("gather"); add_extend_op("gather");
add_extend_op("leaky_relu"); add_extend_op("leaky_relu");
add_extend_op("logsoftmax"); add_extend_op("logsoftmax");
add_extend_op("lrn");
add_extend_op("multinomial"); add_extend_op("multinomial");
add_extend_op("nonzero"); add_extend_op("nonzero");
add_extend_op("pad"); add_extend_op("pad");
...@@ -192,6 +191,7 @@ struct miopen_apply ...@@ -192,6 +191,7 @@ struct miopen_apply
add_gemm_op<op::quant_dot>("quant_dot"); add_gemm_op<op::quant_dot>("quant_dot");
add_if_op(); add_if_op();
add_loop_op(); add_loop_op();
add_lrn_op();
add_neg_op(); add_neg_op();
add_nms_op(); add_nms_op();
add_quant_convolution_op(); add_quant_convolution_op();
...@@ -562,6 +562,36 @@ struct miopen_apply ...@@ -562,6 +562,36 @@ struct miopen_apply
return mod->replace_instruction(ins, gpu_out); return mod->replace_instruction(ins, gpu_out);
}); });
} }
void add_lrn_op()
{
apply_map.emplace("lrn", [=](instruction_ref ins) {
auto s = ins->get_shape();
auto in = ins->inputs().front();
auto output = insert_allocation(ins, s);
auto type = s.type();
if(type == shape::half_type)
{
shape s32{shape::float_type, s.lens()};
auto cout32 = insert_allocation(ins, s32);
auto cop32 = make_op("convert", {{"target_type", shape::float_type}});
auto convert32 = mod->insert_instruction(ins, make_op("gpu::convert", cop32.to_value()), in, cout32);
auto lout32 = insert_allocation(ins, s32);
auto lrn32 = mod->insert_instruction(ins, make_op("gpu::lrn", ins->get_operator().to_value()), convert32, lout32);
auto cop16 = make_op("convert", {{"target_type", shape::half_type}});
auto lout16 = mod->insert_instruction(ins, make_op("gpu::convert", cop16.to_value()), lrn32, output);
return mod->replace_instruction(ins, lout16);
}
else
{
auto lrn16 = mod->insert_instruction(ins, make_op("gpu::lrn", ins->get_operator().to_value()), in, output);
return mod->replace_instruction(ins, lrn16);
}
});
}
}; };
void lowering::apply(module& m) const { miopen_apply{&m, this}.apply(); } void lowering::apply(module& m) const { miopen_apply{&m, this}.apply(); }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment