Commit 05e8bfde authored by turneram's avatar turneram
Browse files

Formatting

parent 3ea9fe4c
...@@ -19,7 +19,7 @@ struct parse_layernorm : op_parser<parse_layernorm> ...@@ -19,7 +19,7 @@ struct parse_layernorm : op_parser<parse_layernorm>
const std::vector<instruction_ref>& args) const const std::vector<instruction_ref>& args) const
{ {
float epsilon = 1e-3f; float epsilon = 1e-3f;
int64_t axis = -1; int64_t axis = -1;
if(contains(info.attributes, "epsilon")) if(contains(info.attributes, "epsilon"))
{ {
epsilon = parser.parse_value(info.attributes.at("epsilon")).at<float>(); epsilon = parser.parse_value(info.attributes.at("epsilon")).at<float>();
...@@ -29,14 +29,15 @@ struct parse_layernorm : op_parser<parse_layernorm> ...@@ -29,14 +29,15 @@ struct parse_layernorm : op_parser<parse_layernorm>
epsilon = parser.parse_value(info.attributes.at("axis")).at<int64_t>(); epsilon = parser.parse_value(info.attributes.at("axis")).at<int64_t>();
} }
auto layernorm = info.add_instruction(make_op("layernorm", {{"epsilon", epsilon}, {"axis", axis}}), args.front()); auto layernorm = info.add_instruction(
make_op("layernorm", {{"epsilon", epsilon}, {"axis", axis}}), args.front());
if (args.size() == 3) if(args.size() == 3)
{ {
layernorm = info.add_broadcastable_binary_op("mul", layernorm, args.at(1)); layernorm = info.add_broadcastable_binary_op("mul", layernorm, args.at(1));
layernorm = info.add_broadcastable_binary_op("add", layernorm, args.at(2)); layernorm = info.add_broadcastable_binary_op("add", layernorm, args.at(2));
} }
return layernorm; return layernorm;
} }
}; };
......
...@@ -9,9 +9,10 @@ struct test_layernorm : verify_program<test_layernorm> ...@@ -9,9 +9,10 @@ struct test_layernorm : verify_program<test_layernorm>
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 384, 768}}); auto x =
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 384, 768}});
mm->add_instruction(migraphx::make_op("layernorm", {{"axis", -1}}), x); mm->add_instruction(migraphx::make_op("layernorm", {{"axis", -1}}), x);
return p; return p;
} }
}; };
\ No newline at end of file
...@@ -9,9 +9,10 @@ struct test_transposectx : verify_program<test_transposectx> ...@@ -9,9 +9,10 @@ struct test_transposectx : verify_program<test_transposectx>
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 12, 128, 64}}); auto x =
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 12, 128, 64}});
mm->add_instruction(migraphx::make_op("transposectx"), x); mm->add_instruction(migraphx::make_op("transposectx"), x);
return p; return p;
} }
}; };
...@@ -9,9 +9,10 @@ struct test_transposeqkv : verify_program<test_transposeqkv> ...@@ -9,9 +9,10 @@ struct test_transposeqkv : verify_program<test_transposeqkv>
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 384, 3, 12, 64}}); auto x = mm->add_parameter(
"x", migraphx::shape{migraphx::shape::float_type, {2, 384, 3, 12, 64}});
mm->add_instruction(migraphx::make_op("transposeqkv"), x); mm->add_instruction(migraphx::make_op("transposeqkv"), x);
return p; return p;
} }
}; };
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment