"src/targets/gpu/jit/softmax.cpp" did not exist on "27e980c4058690c3ab1376d055eef42e6a5ebf0a"
Unverified Commit b4d2a740 authored by Paul Fultz II, committed by GitHub

Merge pull request #28 from ROCmSoftwarePlatform/batchnorm_onnx

Batchnorm onnx
parents 6bb6b72e a8dd3210
@@ -20,7 +20,8 @@ struct not_computable
struct batch_norm_inference
{
-double epsilon = 1.0e-6;
+float epsilon  = 1.0e-6f;
+float momentum = 0.9f;
std::string name() const { return "batch_norm_inference"; }
@@ -32,6 +33,8 @@ struct batch_norm_inference
bn_infer_mode_t bn_mode = spatial;
bool is_test = false;
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(5);
......
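For reference, batch_norm_inference at inference time computes y = gamma * (x - mean) / sqrt(variance + epsilon) + beta, applied per channel in spatial mode over the five inputs checked above (x, scale, bias, mean, variance). Below is a minimal reference sketch, not the migraph implementation; the function name batch_norm_inference_ref and the NCHW layout arguments are illustrative assumptions.

#include <cmath>
#include <cstddef>
#include <vector>

// Reference computation for spatial batch norm at inference time (illustrative only):
// x and y are NCHW tensors flattened to n*c*hw elements; gamma, beta, mean and
// variance each hold one value per channel.
void batch_norm_inference_ref(const std::vector<float>& x,
                              const std::vector<float>& gamma,
                              const std::vector<float>& beta,
                              const std::vector<float>& mean,
                              const std::vector<float>& variance,
                              std::vector<float>& y,
                              std::size_t n,
                              std::size_t c,
                              std::size_t hw,
                              float epsilon = 1.0e-6f)
{
    for(std::size_t i = 0; i < n; i++)
        for(std::size_t j = 0; j < c; j++)
            for(std::size_t k = 0; k < hw; k++)
            {
                std::size_t idx = (i * c + j) * hw + k;
                y[idx] =
                    gamma[j] * (x[idx] - mean[j]) / std::sqrt(variance[j] + epsilon) + beta[j];
            }
}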
@@ -62,6 +62,7 @@ struct onnx_parser
add_mem_op("Conv", &onnx_parser::parse_conv);
add_mem_op("MaxPool", &onnx_parser::parse_pooling);
add_mem_op("Reshape", &onnx_parser::parse_reshape);
add_mem_op("BatchNormalization", &onnx_parser::parse_batchnorm);
}
template <class F>
@@ -167,6 +168,35 @@ struct onnx_parser
return prog.add_literal(v);
}
instruction_ref
parse_batchnorm(std::string, attribute_map attributes, std::vector<instruction_ref> args)
{
float epsilon = 1e-5f;
float momentum = 0.9f;
batch_norm_inference::bn_infer_mode_t bn_mode = batch_norm_inference::spatial;
bool is_test = false;
if(contains(attributes, "epsilon"))
{
epsilon = parse_value(attributes.at("epsilon")).at<float>();
}
if(contains(attributes, "momentum"))
{
epsilon = parse_value(attributes.at("momentum")).at<float>();
}
if(contains(attributes, "is_test"))
{
is_test = parse_value(attributes.at("is_test")).at<uint64_t>() > 0;
}
if(contains(attributes, "spatial"))
{
bn_mode = (parse_value(attributes.at("spatial")).at<uint64_t>() > 0)
? batch_norm_inference::spatial
: batch_norm_inference::per_activation;
}
batch_norm_inference op{epsilon, momentum, bn_mode, is_test};
return prog.add_instruction(op, args);
}
void parse_from(std::istream& is)
{
onnx::ModelProto model;
......
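As a usage sketch: with the BatchNormalization entry registered above, a model containing such a node is handled through the same parse_onnx entry point the tests below use, and each node is lowered to a batch_norm_inference instruction by parse_batchnorm. The header path and the model file name here are assumptions.

#include <migraph/onnx.hpp> // header path assumed; parse_onnx is the entry point used by the tests

int main()
{
    // "model.onnx" is a placeholder for any ONNX file containing BatchNormalization nodes.
    migraph::program p = migraph::parse_onnx("model.onnx");
    return 0;
}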
@@ -39,6 +39,29 @@ void pytorch_conv_relu_maxpool()
EXPECT(p == prog);
}
void pytorch_conv_bn_relu_maxpool()
{
migraph::program p;
auto l0 = p.add_parameter("0", {migraph::shape::float_type, {1, 3, 32, 32}});
auto l1 = p.add_parameter("1", {migraph::shape::float_type, {1, 3, 5, 5}});
auto l2 = p.add_parameter("2", {migraph::shape::float_type, {1}});
auto p3 = p.add_parameter("3", {migraph::shape::float_type, {1}});
auto p4 = p.add_parameter("4", {migraph::shape::float_type, {1}});
auto p5 = p.add_parameter("5", {migraph::shape::float_type, {1}});
auto p6 = p.add_parameter("6", {migraph::shape::float_type, {1}});
uint64_t axis = 1;
auto l3 = p.add_instruction(migraph::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraph::broadcast{axis}, l3, l2);
auto l5 = p.add_instruction(migraph::add{}, l3, l4);
auto l6 = p.add_instruction(migraph::batch_norm_inference{}, l5, p3, p4, p5, p6);
auto l7 = p.add_instruction(migraph::activation{"relu"}, l6);
p.add_instruction(migraph::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l7);
auto prog = migraph::parse_onnx("conv_bn_relu_maxpool.onnx");
EXPECT(p == prog);
}
void pytorch_conv_relu_maxpoolX2()
{
migraph::program p;
@@ -69,5 +92,6 @@ int main()
{
pytorch_conv_bias_test();
pytorch_conv_relu_maxpool();
pytorch_conv_bn_relu_maxpool();
pytorch_conv_relu_maxpoolX2();
}