Commit 3a2469f4 authored by Scott Thornton's avatar Scott Thornton
Browse files

added ONNX parser for batchnorm

parent 415476ae
......@@ -105,7 +105,10 @@ struct not_computable
struct batch_norm_inference
{
double epsilon = 1.0e-6;
float epsilon = 1.0e-6f;
float momentum = 0.9f;
bool spatial = true;
bool is_test = false;
std::string name() const { return "batch_norm_inference"; }
......
......@@ -62,6 +62,7 @@ struct onnx_parser
add_mem_op("Conv", &onnx_parser::parse_conv);
add_mem_op("MaxPool", &onnx_parser::parse_pooling);
add_mem_op("Reshape", &onnx_parser::parse_reshape);
add_mem_op("BatchNormalization", &onnx_parser::parse_batchnorm);
}
template <class F>
......@@ -167,6 +168,33 @@ struct onnx_parser
return prog.add_literal(v);
}
instruction_ref
parse_batchnorm(std::string, attribute_map attributes, std::vector<instruction_ref> args)
{
float epsilon = 1e-5f;
float momentum = 0.9f;
bool spatial = true;
bool is_test = false;
if(contains(attributes, "epsilon"))
{
epsilon = parse_value(attributes.at("epsilon")).at<float>();
}
if(contains(attributes, "momentum"))
{
epsilon = parse_value(attributes.at("momentum")).at<float>();
}
if(contains(attributes, "is_test"))
{
is_test = (parse_value(attributes.at("is_test")).at<uint64_t>() > 0) ? true : false;
}
if(contains(attributes, "spatial"))
{
spatial = (parse_value(attributes.at("spatial")).at<uint64_t>() > 0) ? true : false;
}
batch_norm_inference op{epsilon, momentum, spatial, is_test};
return prog.add_instruction(op, args);
}
void parse_from(std::istream& is)
{
onnx::ModelProto model;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment