Commit 3a3ae8b5 authored by Scott Thornton's avatar Scott Thornton
Browse files

converted previous spatial bool to enum type in ONNX

parent e76bd729
......@@ -22,8 +22,6 @@ struct batch_norm_inference
{
float epsilon = 1.0e-6f;
float momentum = 0.9f;
bool spatial = true;
bool is_test = false;
std::string name() const { return "batch_norm_inference"; }
......@@ -35,6 +33,8 @@ struct batch_norm_inference
bn_infer_mode_t bn_mode = spatial;
bool is_test = false;
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(5);
......
......@@ -171,10 +171,10 @@ struct onnx_parser
instruction_ref
parse_batchnorm(std::string, attribute_map attributes, std::vector<instruction_ref> args)
{
float epsilon = 1e-5f;
float momentum = 0.9f;
bool spatial = true;
bool is_test = false;
float epsilon = 1e-5f;
float momentum = 0.9f;
batch_norm_inference::bn_infer_mode_t bn_mode = batch_norm_inference::spatial;
bool is_test = false;
if(contains(attributes, "epsilon"))
{
epsilon = parse_value(attributes.at("epsilon")).at<float>();
......@@ -189,9 +189,11 @@ struct onnx_parser
}
if(contains(attributes, "spatial"))
{
spatial = (parse_value(attributes.at("spatial")).at<uint64_t>() > 0) ? true : false;
bn_mode = (parse_value(attributes.at("spatial")).at<uint64_t>() > 0)
? batch_norm_inference::spatial
: batch_norm_inference::per_activation;
}
batch_norm_inference op{epsilon, momentum, spatial, is_test};
batch_norm_inference op{epsilon, momentum, bn_mode, is_test};
return prog.add_instruction(op, args);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment