Commit d14dd144 authored by Aditya Atluri's avatar Aditya Atluri
Browse files

added batch norm inference for miopen

parent d1481b13
...@@ -15,6 +15,48 @@ ...@@ -15,6 +15,48 @@
namespace migraph { namespace migraph {
namespace gpu { namespace gpu {
// MIOpen-backed lowering of the batch_norm_inference operator.
// Lowered instruction carries 6 arguments:
//   args[0]=x, args[1]=scale, args[2]=bias, args[3]=mean,
//   args[4]=variance, args[5]=output allocation.
struct miopen_batch_norm_inference
{
    batch_norm_inference op;
    std::string name() const { return "gpu::batch_norm_inference"; }
    // Validate the 6 lowered inputs (5 operator inputs + output buffer) and
    // delegate the result shape to the wrapped operator.
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(6);
        return op.compute_shape(
            {inputs.at(0), inputs.at(1), inputs.at(2), inputs.at(3), inputs.at(4)});
    }
    argument compute(context& ctx, shape output_shape, std::vector<argument> args) const
    {
        auto x_desc = make_tensor(args[0].get_shape());
        auto y_desc = make_tensor(output_shape);
        // scale/bias/mean/variance all share the per-channel bn shape, so the
        // scale-bias-mean-variance descriptor can be built from the scale arg.
        auto bn_desc = make_tensor(args[1].get_shape());
        float alpha = 1.0f, beta = 0.0f;
        // MIOpen parameter order after the descriptor is:
        // bnScale, bnBias, estimatedMean, estimatedVariance, epsilon.
        miopenBatchNormalizationForwardInference(ctx.handle.get(),
                                                 miopenBatchNormMode_t(op.bn_mode),
                                                 &alpha,
                                                 &beta,
                                                 x_desc.get(),
                                                 args[0].implicit(),
                                                 y_desc.get(),
                                                 args[5].implicit(),
                                                 bn_desc.get(),
                                                 args[1].implicit(),
                                                 args[2].implicit(),
                                                 args[3].implicit(),
                                                 args[4].implicit(),
                                                 op.epsilon);
        return args[5];
    }
};
struct miopen_convolution struct miopen_convolution
{ {
convolution op; convolution op;
...@@ -259,6 +301,12 @@ struct miopen_apply ...@@ -259,6 +301,12 @@ struct miopen_apply
{ {
apply_contiguous(it); apply_contiguous(it);
} }
// TODO: adityaatluri
// tagging to easily find where code changed
else if(it->op.name() == "batch_norm_inference")
{
apply_batch_norm_inference(it);
}
} }
} }
...@@ -332,6 +380,16 @@ struct miopen_apply ...@@ -332,6 +380,16 @@ struct miopen_apply
auto output = insert_allocation(ins, ins->result); auto output = insert_allocation(ins, ins->result);
prog->replace_instruction(ins, miopen_contiguous{op}, ins->arguments.at(0), output); prog->replace_instruction(ins, miopen_contiguous{op}, ins->arguments.at(0), output);
} }
// Replace a batch_norm_inference instruction with its MIOpen-backed GPU
// version. All five operator inputs (x, scale, bias, mean, variance) plus
// the output allocation must be forwarded so the lowered op's
// compute_shape `has(6)` check is satisfied; forwarding only arg 0 would
// throw at shape-check time.
void apply_batch_norm_inference(instruction_ref ins)
{
    auto&& op   = any_cast<batch_norm_inference>(ins->op);
    auto output = insert_allocation(ins, ins->result);
    prog->replace_instruction(ins,
                              miopen_batch_norm_inference{op},
                              ins->arguments.at(0),
                              ins->arguments.at(1),
                              ins->arguments.at(2),
                              ins->arguments.at(3),
                              ins->arguments.at(4),
                              output);
}
}; };
void lowering::apply(program& p) const { miopen_apply{&p}.apply(); } void lowering::apply(program& p) const { miopen_apply{&p}.apply(); }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment