Commit 33212f8f authored by Paul's avatar Paul
Browse files

Add onnx updates from resnet branch

parent 4d0fdcd5
......@@ -29,7 +29,7 @@ struct unknown
}
argument compute(context&, shape, std::vector<argument>) const
{
MIGRAPH_THROW(name() + ": not computable");
MIGRAPH_THROW("not computable");
}
friend std::ostream& operator<<(std::ostream& os, const unknown& x)
{
......@@ -60,8 +60,11 @@ struct onnx_parser
add_mem_op("Constant", &onnx_parser::parse_constant);
add_mem_op("Conv", &onnx_parser::parse_conv);
add_mem_op("MaxPool", &onnx_parser::parse_pooling);
add_mem_op("MaxPool", &onnx_parser::parse_max_pooling);
add_mem_op("AveragePool", &onnx_parser::parse_average_pooling);
add_mem_op("Reshape", &onnx_parser::parse_reshape);
add_mem_op("Flatten", &onnx_parser::parse_flatten);
add_mem_op("Gemm", &onnx_parser::parse_gemm);
add_mem_op("BatchNormalization", &onnx_parser::parse_batchnorm);
}
......@@ -126,7 +129,7 @@ struct onnx_parser
}
instruction_ref
parse_pooling(std::string, attribute_map attributes, std::vector<instruction_ref> args)
parse_max_pooling(std::string, attribute_map attributes, std::vector<instruction_ref> args)
{
pooling op{"max"};
if(contains(attributes, "pads"))
......@@ -144,6 +147,25 @@ struct onnx_parser
return prog.add_instruction(op, args);
}
instruction_ref
parse_average_pooling(std::string, attribute_map attributes, std::vector<instruction_ref> args)
{
pooling op{"average"};
if(contains(attributes, "pads"))
{
copy(attributes["pads"].ints(), op.padding.begin());
}
if(contains(attributes, "strides"))
{
copy(attributes["strides"].ints(), op.stride.begin());
}
if(contains(attributes, "kernel_shape"))
{
copy(attributes["kernel_shape"].ints(), op.lengths.begin());
}
return prog.add_instruction(op, args);
}
instruction_ref
parse_reshape(std::string, attribute_map attributes, std::vector<instruction_ref> args)
{
......@@ -161,6 +183,17 @@ struct onnx_parser
return prog.add_instruction(op, args[0]);
}
instruction_ref
parse_flatten(std::string, attribute_map attributes, std::vector<instruction_ref> args)
{
uint64_t axis = 0;
// if(contains(attributes, "axis"))
// {
// axis = parse_value(attributes.at("axis")).at<int>();
// }
return prog.add_instruction(flatten{axis}, args[0]);
}
instruction_ref
parse_constant(std::string, attribute_map attributes, std::vector<instruction_ref>)
{
......@@ -168,6 +201,42 @@ struct onnx_parser
return prog.add_literal(v);
}
instruction_ref
parse_gemm(std::string, attribute_map attributes, std::vector<instruction_ref> args)
{
float alpha = 1.0f;
float beta = 0.0f;
bool transa = false;
bool transb = false;
if(contains(attributes, "alpha"))
{
alpha = parse_value(attributes.at("alpha")).at<float>();
}
if(contains(attributes, "beta"))
{
alpha = parse_value(attributes.at("beta")).at<float>();
}
if(contains(attributes, "transA"))
{
transa = parse_value(attributes.at("transA")).at<bool>();
}
if(contains(attributes, "transB"))
{
transb = parse_value(attributes.at("transB")).at<bool>();
}
std::vector<int64_t> perm = {1, 0};
auto l1 = (transa) ? prog.add_instruction(transpose{perm}, args[0]) : args[0];
auto l2 = (transb) ? prog.add_instruction(transpose{perm}, args[1]) : args[1];
if(args.size() == 3)
{
uint64_t axis = 1;
auto l3 = prog.add_instruction(gemm{alpha, beta}, l1, l2);
auto l4 = prog.add_instruction(broadcast{axis}, l3, args[2]);
return prog.add_instruction(add{}, l3, l4);
}
return prog.add_instruction(gemm{alpha, beta}, l1, l2);
}
instruction_ref
parse_batchnorm(std::string, attribute_map attributes, std::vector<instruction_ref> args)
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment